| # Copyright 2025 The Chromium Authors |
| # Use of this source code is governed by a BSD-style license that can be |
| # found in the LICENSE file. |
| """A promptfoo provider for the Gemini CLI.""" |
| |
| import dataclasses |
| import functools |
| import json |
| import logging |
| import os |
| import pathlib |
| import subprocess |
| import sys |
| import textwrap |
| import threading |
| import time |
| from collections.abc import Collection |
| from typing import Any |
| |
| import constants |
| |
| sys.path.append(str(constants.CHROMIUM_SRC)) |
| from agents.common import gemini_helpers |
| from agents.testing import checkout_helpers |
| |
# Default for the `timeoutSeconds` provider config option (ten minutes).
DEFAULT_TIMEOUT_SECONDS = 10 * 60
# Extensions installed for a test whose config has no `extensions` key.
DEFAULT_EXTENSIONS = [
    'build-information',
    'depot-tools',
    'landmines',
    'test-landmines',
]
| |
| |
| @dataclasses.dataclass |
| class GeminiCliArguments: |
| """Information that is relevant to starting gemini-cli for a test.""" |
| # The command to run gemini-cli. |
| command: list[str] |
| # The home directory that gemini-cli will use. |
| home_dir: pathlib.Path | None |
| # The environment that gemini-cli will be started in. |
| env: dict[str, str] |
| # The duration that gemini-cli will be allowed to run for. |
| timeout_seconds: int |
| # The system prompt that gemini-cli will be run with. |
| system_prompt: str |
| # The user prompt to pass to gemini-cli |
| user_prompt: str |
| # How wide to treat the console that gemini-cli is run in. |
| console_width: int |
| |
| |
| def _stream_reader(stream, output_list: list[str], width): |
| """Reads a stream line-by-line and appends to a list.""" |
| try: |
| for line in iter(stream.readline, ''): |
| output_list.append(line) |
| wrapped_text = '\n'.join( |
| textwrap.wrap(line.rstrip('\r\n'), width=width)) |
| sys.stderr.write(wrapped_text + '\n') |
| except OSError: |
| # Stream may be closed unexpectedly |
| pass |
| finally: |
| stream.close() |
| |
| |
def _get_sandbox_image_tag() -> str | None:
    """Builds the sandbox image reference for the installed gemini-cli.

    Returns:
        '<constants.GEMINI_SANDBOX_IMAGE_URL>:<gemini version>', or None
        (after logging an error) if the gemini version cannot be determined.
    """
    version = gemini_helpers.get_gemini_version()
    if version:
        return f'{constants.GEMINI_SANDBOX_IMAGE_URL}:{version}'
    logging.error('Failed to get gemini version.')
    return None
| |
| |
| @functools.cache |
| def _get_container_path(sandbox_image: str | None) -> str | None: |
| """Gets the default PATH from the sandbox container.""" |
| if not sandbox_image: |
| return None |
| |
| # This is a Go template that iterates over all environment variables in the |
| # image's configuration and prints each one on a new line. |
| command = [ |
| 'docker', 'inspect', |
| r'--format={{range .Config.Env}}{{printf "%s\n" .}}{{end}}', |
| sandbox_image |
| ] |
| try: |
| result = subprocess.run(command, |
| stdout=subprocess.PIPE, |
| stderr=subprocess.PIPE, |
| text=True, |
| check=True) |
| logging.debug('docker inspect output:\n%s', result.stdout) |
| for line in result.stdout.splitlines(): |
| if line.startswith('PATH='): |
| return line.split('=', 1)[1] |
| |
| logging.warning('PATH not found in environment of %s', sandbox_image) |
| return None |
| except (subprocess.CalledProcessError, FileNotFoundError) as e: |
| error_message = f'Failed to get container PATH for {sandbox_image}: {e}' |
| if hasattr(e, 'stderr') and e.stderr: |
| error_message += f'\nstderr:\n{e.stderr}' |
| logging.error(error_message) |
| return None |
| |
| |
| def _get_env_with_overrides( |
| home: pathlib.Path | None = None, |
| sandbox_flags: list[str] | None = None, |
| sandbox_image: str | None = None) -> dict[str, str]: |
| """Returns a copy of the environment with the given overrides.""" |
| env = os.environ.copy() |
| if home: |
| env['HOME'] = str(home) |
| logging.debug('HOME: %s', env.get('HOME')) |
| if sandbox_flags: |
| env['SANDBOX_FLAGS'] = ' '.join(sandbox_flags) |
| logging.debug('SANDBOX_FLAGS: %s', env.get('SANDBOX_FLAGS')) |
| if sandbox_image: |
| env['GEMINI_SANDBOX_IMAGE'] = sandbox_image |
| logging.debug('GEMINI_SANDBOX_IMAGE: %s', |
| env.get('GEMINI_SANDBOX_IMAGE')) |
| return env |
| |
| |
| def _install_extensions(extensions: Collection[str] | None = None, |
| home_dir: pathlib.Path | None = None) -> None: |
| # The installation script should identify the working tree as the "repo |
| # root", so use the copy in the working tree with the CWD set |
| # appropriately for subprocesses like `git`. |
| if not extensions: |
| return |
| |
| logging.info('Installing extensions: %s', extensions) |
| command = [ |
| sys.executable, |
| pathlib.Path('agents', 'extensions', 'install.py'), |
| '--extra-extensions-dir', |
| pathlib.Path('agents', 'testing', 'extensions'), |
| 'add', |
| '--copy', |
| '--skip-prompt', |
| *extensions, |
| ] |
| result = subprocess.run(command, |
| env=_get_env_with_overrides(home=home_dir), |
| stdout=subprocess.PIPE, |
| stderr=subprocess.STDOUT, |
| text=True, |
| check=False) |
| logging.debug('Extension install output:\n%s', result.stdout) |
| result.check_returncode() |
| logging.debug('Installed extensions:\n%s', |
| _get_installed_extensions(home_dir)) |
| |
| |
| def _load_templates(templates: list[str]) -> str: |
| """Loads and combines system prompt templates.""" |
| if not templates: |
| return '' |
| |
| logging.info('Loading templates: %s', templates) |
| prompt_parts = [] |
| for template in templates: |
| with open(template, encoding='utf-8') as t: |
| prompt_parts.append(t.read()) |
| return '\n\n'.join(prompt_parts) |
| |
| |
| def _apply_changes(changes: list[dict[str, str]]) -> None: |
| """Applies changes to the repository.""" |
| if not changes: |
| return |
| |
| logging.info('Applying changes: %s', changes) |
| for change in changes: |
| if len(change) != 1: |
| raise ValueError( |
| 'Invalid change object: must have exactly one key.') |
| |
| if 'apply' in change: |
| subprocess.check_call(['git', 'apply', change['apply']]) |
| elif 'stage' in change: |
| subprocess.check_call(['git', 'add', change['stage']]) |
| else: |
| raise ValueError( |
| 'Invalid change object: key must be "apply" or "stage".') |
| |
| |
def _get_installed_extensions(home_dir: pathlib.Path | None) -> str:
    """Lists installed extensions using agents/extensions/install.py.

    Args:
        home_dir: HOME whose extensions should be listed, or None to keep the
            inherited HOME.

    Returns:
        The text printed to stdout by `install.py list`.

    Raises:
        subprocess.CalledProcessError: If the script exits non-zero.
    """
    list_command = [
        sys.executable,
        pathlib.Path('agents', 'extensions', 'install.py'),
        'list',
    ]
    return subprocess.check_output(list_command,
                                   env=_get_env_with_overrides(home=home_dir),
                                   text=True)
| |
| |
def _get_sandbox_flags() -> tuple[list[str], str]:
    """Builds docker flags that expose depot_tools inside the sandbox.

    depot_tools is bind-mounted at /depot_tools, and /depot_tools is
    prepended to the PATH read from the sandbox image.

    Returns:
        A tuple (flags, error). On success |flags| holds the docker flags and
        |error| is ''. On failure |flags| is empty and |error| is the message
        that should be surfaced to promptfoo.
    """
    depot_tools_path = checkout_helpers.get_depot_tools_path()
    if not depot_tools_path:
        return ([],
                'Sandbox requires depot_tools, but it could not be located.')
    volume_flag = f'-v {depot_tools_path.as_posix()}:/depot_tools'

    container_path = _get_container_path(_get_sandbox_image_tag())
    if not container_path:
        # NOTE(review): despite "PATH will not be overridden", this error is
        # treated as fatal by _get_gemini_cli_arguments, which aborts the run.
        return ([], 'Could not determine container PATH. PATH will not be '
                'overridden.')

    return [volume_flag, f'-e PATH=/depot_tools:{container_path}'], ''
| |
| |
def _parse_positive_int(value: Any, description: str) -> tuple[int | None, str]:
    """Converts |value| to a positive int, returning (number, error)."""
    try:
        number = int(value)
    except (ValueError, TypeError):
        return None, f'Failed to parse {description} from {value}'
    if number <= 0:
        return None, f'Expected a positive {description}, got {number}'
    return number, ''


def _get_gemini_cli_arguments(
        provider_vars: dict[str, Any], provider_config: dict[str, Any],
        user_prompt: str) -> tuple[GeminiCliArguments | None, str]:
    """Collects arguments relevant to starting/running gemini-cli.

    Args:
        provider_vars: The key/value variables given to the provider. Reads
            'gemini_cli_bin', 'sandbox', 'home_dir' and 'console_width'.
        provider_config: The config parsed from the test's YAML config file.
            Reads 'timeoutSeconds' plus the system prompt keys
            ('system_prompt', 'templates').
        user_prompt: The user prompt to pass to gemini-cli.

    Returns:
        A tuple (arguments, error). On success, |arguments| will be a
        GeminiCliArguments instance with all fields filled and |error| will be
        an empty string. On failure, |arguments| will be None and |error| will
        be a non-empty string containing the error message.
    """
    unparsed_timeout = provider_config.get('timeoutSeconds',
                                           DEFAULT_TIMEOUT_SECONDS)
    timeout_seconds, error = _parse_positive_int(unparsed_timeout, 'timeout')
    if error:
        return None, error

    # Vars may arrive as strings (e.g. from the CLI). textwrap.wrap() in the
    # output reader thread needs a positive int, and failing there would
    # silently kill that thread, so validate up front.
    console_width, error = _parse_positive_int(
        provider_vars.get('console_width', 80), 'console width')
    if error:
        return None, error

    gemini_cli_bin = provider_vars.get('gemini_cli_bin', 'gemini')
    command = [gemini_cli_bin, '-y']

    sandbox_flags = []
    if provider_vars.get('sandbox', False):
        command.append('--sandbox')
        sandbox_flags, error = _get_sandbox_flags()
        if error:
            return None, error

    home_dir_str = provider_vars.get('home_dir')
    home_dir = pathlib.Path(home_dir_str) if home_dir_str else None

    # Missing or undecodable template files become a reported error rather
    # than an exception escaping the provider.
    try:
        system_prompt = _get_system_prompt(provider_config)
    except (OSError, UnicodeDecodeError) as e:
        return None, f'Failed to load system prompt templates: {e}'

    return GeminiCliArguments(
        command=command,
        home_dir=home_dir,
        env=_get_env_with_overrides(
            home=home_dir,
            sandbox_flags=sandbox_flags,
            sandbox_image=_get_sandbox_image_tag(),
        ),
        timeout_seconds=timeout_seconds,
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        console_width=console_width), ''
| |
| |
def _get_system_prompt(provider_config: dict[str, Any]) -> str:
    """Combines the inline system prompt with any template files.

    Args:
        provider_config: The config parsed from the test's YAML config file.
            Its optional 'system_prompt' value comes first, followed by the
            contents of the files listed under 'templates'.

    Returns:
        Both parts separated by a blank line when both are non-empty;
        otherwise whichever part is non-empty, or the inline prompt ('' by
        default) when neither is.
    """
    inline_prompt = provider_config.get('system_prompt', '')
    template_prompt = _load_templates(provider_config.get('templates', []))
    if not template_prompt:
        return inline_prompt
    if not inline_prompt:
        return template_prompt
    return f'{inline_prompt}\n\n{template_prompt}'
| |
| |
def _run_gemini_cli_with_output_streaming(
        arguments: GeminiCliArguments,
        combined_output: list[str] | None = None,
) -> tuple[subprocess.Popen, list[str]]:
    """Runs gemini-cli with its output streamed to the console.

    The caller is responsible for handling any exceptions that may arise from
    running gemini-cli, e.g. subprocess.TimeoutExpired if the timeout elapses
    or FileNotFoundError if the binary cannot be found.

    Args:
        arguments: The GeminiCliArguments to run gemini-cli with.
        combined_output: Optional list that output lines are appended to as
            they are read. Supplying one lets the caller inspect partial
            output even when this function raises (e.g. on timeout). A new
            list is used when omitted.

    Returns:
        A tuple (process, combined_output). |process| is the subprocess used
        to run gemini-cli. |combined_output| is the list of all stdout and
        stderr output collected from |process|. |process| will have
        terminated by the time this function returns.
    """
    if combined_output is None:
        combined_output = []
    output_thread = None
    process = None
    try:
        process = subprocess.Popen(  # pylint: disable=consider-using-with
            arguments.command,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            env=arguments.env,
        )
        logging.info('--- Streaming Output (Timeout: %ss) ---',
                     arguments.timeout_seconds)
        # Start draining stdout before writing the prompt. Otherwise, if
        # gemini-cli filled the stdout pipe while the prompt was still being
        # written, both sides could block, and the timeout below (which only
        # starts after the write) would never fire.
        output_thread = threading.Thread(
            target=_stream_reader,
            args=(process.stdout, combined_output, arguments.console_width),
            daemon=True,
        )
        output_thread.start()
        process.stdin.write(f'{arguments.system_prompt}\n\n'
                            f'{arguments.user_prompt}')
        process.stdin.close()
        process.wait(timeout=arguments.timeout_seconds)
        output_thread.join(timeout=5)
        logging.info('\n--- End of Stream ---')
        return process, combined_output
    finally:
        if process and process.poll() is None:
            process.kill()
            # Reap the killed child so it does not linger as a zombie.
            process.wait()
        if output_thread:
            output_thread.join(timeout=5)
            if output_thread.is_alive():
                logging.warning('Output thread did not cleanly terminate.')


def call_api(prompt: str, options: dict[str, Any],
             context: dict[str, Any]) -> dict[str, Any]:
    """A flexible promptfoo provider that runs a command-line tool.

    This provider streams the tool's output and captures artifacts with a
    reliable timeout.

    Args:
        prompt: The user prompt to pass to gemini-cli.
        options: Provider options; the provider config is read from 'config'.
        context: Call context; the test variables are read from 'vars'.

    Returns:
        On success, {'output': ..., 'metrics': ...}. On failure, a dict with
        an 'error' message, plus 'metrics' for failures that occur while
        running gemini-cli.
    """
    # `or {}` also covers keys that are present but null.
    provider_config = options.get('config') or {}
    provider_vars = context.get('vars') or {}
    logging.basicConfig(
        level=logging.DEBUG
        if provider_vars.get('verbose', False) else logging.INFO,
        format='%(message)s',
    )
    logging.debug('options: %s', json.dumps(options, indent=2))
    logging.debug('context: %s', json.dumps(context, indent=2))

    gcli_arguments, error = _get_gemini_cli_arguments(provider_vars,
                                                      provider_config, prompt)
    if error:
        return {'error': error}

    # Report setup failures as provider errors, the same way failures while
    # running gemini-cli are reported, instead of letting them escape.
    try:
        _install_extensions(provider_config.get('extensions',
                                                DEFAULT_EXTENSIONS),
                            home_dir=gcli_arguments.home_dir)
        _apply_changes(provider_config.get('changes', []))
    except (subprocess.CalledProcessError, OSError, ValueError) as e:
        error_message = f'Failed to set up the test environment: {e}'
        # A CalledProcessError raised by CompletedProcess.check_returncode()
        # carries the captured installer output.
        setup_output = getattr(e, 'output', None)
        if setup_output:
            error_message += f'\nOutput:\n{setup_output}'
        return {'error': error_message}

    # Passed into the runner so partial output survives timeouts and errors.
    combined_output: list[str] = []
    metrics = {
        'system_prompt': gcli_arguments.system_prompt,
        'user_prompt': gcli_arguments.user_prompt,
    }
    try:
        # monotonic() is unaffected by wall-clock adjustments during the run.
        start_time = time.monotonic()
        process, _ = _run_gemini_cli_with_output_streaming(
            gcli_arguments, combined_output)
        elapsed_time = time.monotonic() - start_time

        full_output = ''.join(combined_output)
        metrics['full_output'] = full_output
        metrics['duration'] = elapsed_time
        if process.returncode != 0:
            error_message = (
                f"Command '{' '.join(gcli_arguments.command)}' failed with "
                f'return code {process.returncode}.\n'
                f'Output:\n{full_output}')
            return {'error': error_message, 'metrics': metrics}
        return {
            'output': full_output.strip(),
            'metrics': metrics,
        }
    except subprocess.TimeoutExpired:
        metrics['full_output'] = ''.join(combined_output)
        return {
            'error': (f'Command timed out after '
                      f'{gcli_arguments.timeout_seconds} seconds.'),
            'metrics':
            metrics,
        }
    except FileNotFoundError:
        return {
            'error': (f"Command not found: '{gcli_arguments.command[0]}'. "
                      f'Please ensure it is in your PATH.'),
            'metrics':
            metrics,
        }
    except Exception as e:
        metrics['full_output'] = ''.join(combined_output)
        return {
            'error': f'An unexpected error occurred: {e}',
            'metrics': metrics,
        }