blob: a408ac9445fc1c6941a51dea62f1959953b12cc8 [file] [log] [blame]
# Copyright 2025 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""A promptfoo provider for the Gemini CLI."""
import json
import logging
import pathlib
import subprocess
import sys
import textwrap
import threading
import time
from collections.abc import Collection
from typing import Any
# How long (in seconds) call_api lets the command run before killing it;
# the provider config's `timeoutSeconds` overrides this.
DEFAULT_TIMEOUT_SECONDS = 10 * 60
# Argv used when the provider config does not supply `command`. `-y` is
# passed straight to the gemini CLI (presumably its auto-approve flag --
# confirm against the CLI's help).
DEFAULT_COMMAND = ['gemini', '-y']
# Extensions installed before each run unless the config sets `extensions`.
DEFAULT_EXTENSIONS = [
    'build_information',
    'depot_tools',
    'landmines',
    'test_landmines',
]
def _stream_reader(stream, output_list: list[str], width):
"""Reads a stream line-by-line and appends to a list."""
try:
for line in iter(stream.readline, ''):
output_list.append(line)
wrapped_text = '\n'.join(
textwrap.wrap(line.rstrip('\r\n'), width=width))
sys.stderr.write(wrapped_text + '\n')
except OSError:
# Stream may be closed unexpectedly
pass
finally:
stream.close()
def _install_extensions(extensions: Collection[str] | None = None, ) -> None:
# The installation script should identify the working tree as the "repo
# root", so use the copy in the working tree with the CWD set
# appropriately for subprocesses like `git`.
if not extensions:
return
logging.info('Installing extensions: %s', extensions)
command = [
sys.executable,
pathlib.Path('agents', 'extensions', 'install.py'),
'--extra-extensions-dir',
pathlib.Path('agents', 'testing', 'extensions'),
'add',
'--copy',
*extensions,
]
subprocess.check_call(command)
def _load_templates(templates: list[str]) -> str:
"""Loads and combines system prompt templates."""
if not templates:
return ''
logging.info('Loading templates: %s', templates)
prompt_parts = []
for template in templates:
with open(template, encoding='utf-8') as t:
prompt_parts.append(t.read())
return '\n\n'.join(prompt_parts)
def _apply_changes(changes: list[dict[str, str]]) -> None:
"""Applies changes to the repository."""
if not changes:
return
logging.info('Applying changes: %s', changes)
for change in changes:
if len(change) != 1:
raise ValueError(
'Invalid change object: must have exactly one key.')
if 'apply' in change:
subprocess.check_call(['git', 'apply', change['apply']])
elif 'stage' in change:
subprocess.check_call(['git', 'add', change['stage']])
else:
raise ValueError(
'Invalid change object: key must be "apply" or "stage".')
def _as_int(value: Any, default: int) -> int:
    """Returns ``int(value)``, or ``default`` if it cannot be converted."""
    try:
        return int(value)
    except (ValueError, TypeError):
        return default


def _provider_metrics(system_prompt: str, prompt: str,
                      **extra: Any) -> dict[str, Any]:
    """Builds the metrics dict attached to every call_api result."""
    return {'system_prompt': system_prompt, 'user_prompt': prompt, **extra}


def _kill_and_reap(process: subprocess.Popen | None,
                   output_thread: threading.Thread | None) -> None:
    """Kills ``process`` (if started), reaps it, and waits for the reader."""
    if process is not None:
        # kill() is a no-op if the child has already exited and been reaped.
        process.kill()
        try:
            # Reap the killed child so it does not linger as a zombie.
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            logging.warning('Process %s did not exit after being killed.',
                            process.pid)
    if output_thread is not None:
        output_thread.join(timeout=5)


def call_api(prompt: str, options: dict[str, Any],
             context: dict[str, Any]) -> dict[str, Any]:
    """A flexible promptfoo provider that runs a command-line tool.

    Before launching the command, this provider:
      * installs any configured extensions;
      * appends any configured template files to the system prompt;
      * applies any configured git changes.

    The command then receives ``"<system prompt>\\n\\n<prompt>"`` on stdin.
    Its combined stdout/stderr is echoed to this process's stderr while it
    runs and captured for the result. The command is killed if it exceeds
    the timeout.

    Args:
        prompt: The user prompt supplied by promptfoo.
        options: Provider options. ``options['config']`` may set
            ``command`` (argv list; defaults to DEFAULT_COMMAND),
            ``system_prompt``, ``timeoutSeconds``, ``extensions``,
            ``templates`` and ``changes``.
        context: Test context. ``context['vars']`` may set ``verbose``,
            ``sandbox`` (appends ``--sandbox`` to the command) and
            ``console_width`` (wrap width for echoed output, default 80).

    Returns:
        ``{'output': ..., 'metrics': ...}`` when the command exits 0.
        ``{'error': ..., 'metrics': ...}`` when it exits non-zero, times
        out, is not found, or an unexpected error occurs.
        ``{'error': ...}`` alone when ``command`` is not a list.
    """
    provider_config = options.get('config') or {}
    provider_vars = context.get('vars') or {}
    logging.basicConfig(
        level=logging.DEBUG
        if provider_vars.get('verbose', False) else logging.INFO,
        format='%(message)s',
    )
    command = provider_config.get('command', DEFAULT_COMMAND)
    if not isinstance(command, list):
        return {
            'error': f"'command' must be a list of strings, but got: {command}"
        }
    # Work on a copy. Appending to the original list would mutate
    # DEFAULT_COMMAND (or the caller's config), so '--sandbox' would pile up
    # across calls.
    command = list(command)
    if provider_vars.get('sandbox', False):
        command.append('--sandbox')
    system_prompt = provider_config.get('system_prompt', '')
    timeout_seconds = _as_int(
        provider_config.get('timeoutSeconds', DEFAULT_TIMEOUT_SECONDS),
        DEFAULT_TIMEOUT_SECONDS)
    # Parse before launching anything. textwrap rejects widths < 1, which
    # would otherwise crash the reader thread mid-run.
    console_width = _as_int(provider_vars.get('console_width', 80), 80)
    if console_width < 1:
        console_width = 80
    if logging.getLogger().isEnabledFor(logging.DEBUG):
        logging.debug('options: %s', json.dumps(options, indent=2))
        logging.debug('context: %s', json.dumps(context, indent=2))

    _install_extensions(provider_config.get('extensions', DEFAULT_EXTENSIONS))
    template_prompt = _load_templates(provider_config.get('templates', []))
    if template_prompt:
        system_prompt = (f'{system_prompt}\n\n{template_prompt}'
                         if system_prompt else template_prompt)
    _apply_changes(provider_config.get('changes', []))

    process = None
    output_thread = None
    combined_output: list[str] = []
    try:
        start_time = time.monotonic()
        process = subprocess.Popen(  # pylint: disable=consider-using-with
            command,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            # An undecodable byte must not kill the reader thread; a dead
            # reader stops draining the pipe and stalls the child.
            errors='replace',
        )
        logging.info('--- Streaming Output (Timeout: %ss) ---',
                     timeout_seconds)
        # Start draining stdout *before* feeding stdin. Otherwise a child
        # that writes output while still reading its input can fill the
        # pipe and deadlock against the stdin write.
        output_thread = threading.Thread(
            target=_stream_reader,
            args=(process.stdout, combined_output, console_width),
            daemon=True,
        )
        output_thread.start()
        if process.stdin:
            process.stdin.write(f'{system_prompt}\n\n{prompt}')
            process.stdin.close()
        process.wait(timeout=timeout_seconds)
        output_thread.join(timeout=5)
        elapsed_time = time.monotonic() - start_time
        logging.info('\n--- End of Stream ---')
        full_output = ''.join(combined_output)
        metrics = _provider_metrics(system_prompt,
                                    prompt,
                                    full_output=full_output,
                                    duration=elapsed_time)
        if process.returncode != 0:
            command_line = ' '.join(str(arg) for arg in command)
            error_message = (
                f"Command '{command_line}' failed with return code "
                f'{process.returncode}.\n'
                f'Output:\n{full_output}')
            return {'error': error_message, 'metrics': metrics}
        return {
            'output': full_output.strip(),
            'metrics': metrics,
        }
    except subprocess.TimeoutExpired:
        _kill_and_reap(process, output_thread)
        return {
            'error': f'Command timed out after {timeout_seconds} seconds.',
            'metrics': _provider_metrics(
                system_prompt, prompt, full_output=''.join(combined_output)),
        }
    except FileNotFoundError:
        return {
            'error': f"Command not found: '{command[0]}'. Please ensure it is "
            'in your PATH.',
            'metrics': _provider_metrics(system_prompt, prompt),
        }
    except Exception as e:  # pylint: disable=broad-exception-caught
        logging.exception('Unexpected error while running %s', command)
        _kill_and_reap(process, output_thread)
        return {
            'error': f'An unexpected error occurred: {e}',
            'metrics': _provider_metrics(
                system_prompt, prompt, full_output=''.join(combined_output)),
        }