#!/usr/bin/env python3
# Copyright 2025 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""A script to evaluate prompts using promptfoo."""

import argparse
import logging
import os
import pathlib
import subprocess
import sys
import tempfile

import checkout_helpers
import constants
import promptfoo_installation
import workers

TESTCASE_EXTENSION = '.promptfoo.yaml'
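
# These match the GTest sharding protocol environment variables, which
# swarming sets for each shard.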
_SHARD_INDEX_ENV_VAR = 'GTEST_SHARD_INDEX'
_TOTAL_SHARDS_ENV_VAR = 'GTEST_TOTAL_SHARDS'


def _check_uncommitted_changes(cwd):
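  """Warns about uncommitted changes and unexpected out/ subdirectories."""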
out_dir = pathlib.Path(cwd) / 'out'
if out_dir.is_dir():
subdirs = [d.name for d in out_dir.iterdir() if d.is_dir()]
other_dirs = [d for d in subdirs if d != 'Default']
if other_dirs:
logging.warning(
'Warning: The out directory contains unexpected directories: '
'%s. These will get copied into the workdirs and can affect '
'tests.', ', '.join(other_dirs))
result = subprocess.run(['git', 'status', '--porcelain'],
capture_output=True,
text=True,
check=True,
cwd=cwd)
if result.stdout:
logging.warning(
'Warning: There are uncommitted changes in the repository. This '
'can cause some tests to unexpectedly fail or pass. Please '
'commit or stash them before running the evaluation.')


def _build_chromium(cwd):
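  """Builds out/Default in the checkout at cwd."""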
logging.info('Running `gn gen out/Default`')
subprocess.check_call(['gn', 'gen', 'out/Default'], cwd=cwd)
logging.info('Running `autoninja -C out/Default`')
subprocess.check_call(['autoninja', '-C', 'out/Default'], cwd=cwd)
logging.info('Finished building')


def _discover_testcase_files() -> list[pathlib.Path]:
"""Discovers all testcase files that can be run by this test runner.

  Returns:
A list of Paths, each path pointing to a .yaml file containing a
promptfoo test case. No specific ordering is guaranteed.
"""
extensions_path = constants.CHROMIUM_SRC / 'agents' / 'extensions'
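  # Each extension can ship its own test cases under
  # agents/extensions/<extension>/tests/.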
all_tests = list(extensions_path.glob(f'*/tests/**/*{TESTCASE_EXTENSION}'))
prompts_path = constants.CHROMIUM_SRC / 'agents' / 'prompts' / 'eval'
all_tests.extend(list(prompts_path.glob(f'**/*{TESTCASE_EXTENSION}')))
return all_tests


def _determine_shard_values(
parsed_shard_index: int | None,
parsed_total_shards: int | None) -> tuple[int, int]:
"""Determines the values that should be used for sharding.
If shard information is set both via command line arguments and environment
variables, the values from the command line are used. If no sharding
information is explicitly provided, a single shard is assumed.
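
  For example, with --shard-index=1 --total-shards=3 on the command line and
  GTEST_SHARD_INDEX=2 in the environment, this returns (1, 3) and logs a
  warning about the conflicting environment value.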

  Args:
    parsed_shard_index: The shard index parsed from the command line
      arguments.
    parsed_total_shards: The total shards parsed from the command line
      arguments.

  Returns:
    A tuple (shard_index, total_shards).
  """
env_shard_index = os.environ.get(_SHARD_INDEX_ENV_VAR)
if env_shard_index is not None:
env_shard_index = int(env_shard_index)
env_total_shards = os.environ.get(_TOTAL_SHARDS_ENV_VAR)
if env_total_shards is not None:
env_total_shards = int(env_total_shards)
shard_index_set = (parsed_shard_index is not None
or env_shard_index is not None)
total_shards_set = (parsed_total_shards is not None
or env_total_shards is not None)
if shard_index_set != total_shards_set:
raise ValueError(
'Only one of shard index or total shards was set. Either both or '
'neither must be set.')
shard_index = 0
if parsed_shard_index is not None:
shard_index = parsed_shard_index
if env_shard_index is not None:
logging.warning(
'Shard index set by both arguments and environment variable. '
'Using value provided by arguments.')
elif env_shard_index is not None:
shard_index = env_shard_index
total_shards = 1
if parsed_total_shards is not None:
total_shards = parsed_total_shards
if env_total_shards is not None:
logging.warning(
'Total shards set by both arguments and environment variable. '
'Using value provided by arguments.')
elif env_total_shards is not None:
total_shards = env_total_shards
if shard_index < 0:
raise ValueError('Shard index must be non-negative')
if total_shards < 1:
raise ValueError('Total shards must be positive')
if shard_index >= total_shards:
raise ValueError('Shard index must be < total shards')
return shard_index, total_shards


def _get_tests_to_run(
shard_index: int | None,
total_shards: int | None,
test_filter: str | None,
) -> list[pathlib.Path]:
"""Retrieves which tests should be run for this invocation.
Automatically discovers any valid tests on disk and filters them based on
sharding and test filter information.
Args:
shard_index: The swarming shard index parsed from arguments.
total_shards: The swarming shard total parsed from arguments.
test_filter: The test filter parsed from arguments.
Returns:
A potentially empty list of paths, each path pointing to a valid test
to be run.
"""
shard_index, total_shards = _determine_shard_values(
shard_index, total_shards)
configs_to_run = _discover_testcase_files()
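  # Sort so that all shards agree on the ordering used by the shard slice
  # below.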
configs_to_run.sort()
if test_filter:
configs_to_run = [c for c in configs_to_run if test_filter in str(c)]
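  # GTest-style sharding: shard i runs tests i, i + total_shards,
  # i + 2 * total_shards, and so on.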
configs_to_run = configs_to_run[shard_index::total_shards]
return configs_to_run


def _perform_chromium_setup(force: bool, build: bool) -> None:
"""Performs setup steps related to the Chromium checkout.
Args:
force: Whether to force execution.
build: Whether to build Chromium as part of setup.
"""
root_path = checkout_helpers.get_gclient_root()
is_btrfs = checkout_helpers.check_btrfs(root_path)
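  # `sudo -v` refreshes the cached sudo credentials; on btrfs checkouts the
  # workers are expected to need root (for example, for snapshot management),
  # so validate credentials up front instead of prompting mid-run.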
if is_btrfs and not force:
subprocess.run(['sudo', '-v'], check=True)
src_path = root_path / 'src'
_check_uncommitted_changes(src_path)
if build:
_build_chromium(src_path)


def _fetch_sandbox_image() -> bool:
"""Pre-fetches the sandbox image.
Returns:
True on success, False on failure.
"""
logging.info('Pre-fetching sandbox image. This may take a minute...')
# Use a simple, non-destructive prompt to trigger the one-time
# sandbox image download.
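  # Running from an empty temporary directory keeps the probe away from the
  # current checkout.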
with tempfile.TemporaryDirectory() as tmpdir:
try:
subprocess.run(
['gemini', '--sandbox', 'no-op'],
text=True,
check=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
cwd=tmpdir,
)
return True
except subprocess.CalledProcessError as e:
output = ''
if e.stdout:
output += f'\noutput:\n{e.stdout}'
logging.error(
'Failed to pre-fetch sandbox image: %s. This may be '
'because you are in an environment that does not support '
'sandboxing. Try running with --no-sandbox.%s', e, output)
return False


def _run_prompt_eval_tests(args: argparse.Namespace) -> int:
"""Performs all the necessary steps to run prompt evaluation tests.
Args:
args: The parsed command line args.
Returns:
0 on success, a non-zero value on failure.
"""
configs_to_run = _get_tests_to_run(args.shard_index, args.total_shards,
args.filter)
  if not configs_to_run:
    logging.warning('No tests to run after filtering and sharding')
    return 1
_perform_chromium_setup(force=args.force, build=not args.no_build)
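  # Keep the promptfoo installation in a fixed location under the system temp
  # directory so that later runs can find it again.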
promptfoo_dir = pathlib.Path(tempfile.gettempdir()) / 'promptfoo'
promptfoo = promptfoo_installation.setup_promptfoo(promptfoo_dir,
args.promptfoo_revision,
args.promptfoo_version)
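  # Trigger the one-time sandbox image download up front so that individual
  # tests do not each have to pay for it.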
if args.sandbox and not _fetch_sandbox_image():
return 1
worker_options = workers.WorkerOptions(clean=not args.no_clean,
verbose=args.verbose,
force=args.force,
sandbox=args.sandbox)
worker_pool = workers.WorkerPool(args.parallel_workers, promptfoo,
worker_options,
args.print_output_on_success)
configs_for_current_iteration = configs_to_run
failed_test_results = []
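  # Iteration 0 runs the full set of tests; each later iteration re-queues
  # only the tests that failed in the previous one.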
for iteration in range(args.retries + 1):
if iteration != 0:
logging.info('Re-running %d failed tests',
len(configs_for_current_iteration))
worker_pool.queue_tests(configs_for_current_iteration)
configs_for_current_iteration = []
failed_test_results = worker_pool.wait_for_all_queued_tests()
if not failed_test_results:
break
configs_for_current_iteration = [
tr.test_file for tr in failed_test_results
]
worker_pool.shutdown_blocking()
returncode = 0
if failed_test_results:
returncode = 1
    logging.warning(
        '%d tests passed and %d failed after %d retries',
        len(configs_to_run) - len(failed_test_results),
        len(failed_test_results), args.retries)
logging.warning('Failed tests:')
for ftr in failed_test_results:
logging.warning(' %s', ftr.test_file)
else:
logging.info('Successfully ran %d tests', len(configs_to_run))
return returncode


def _parse_args() -> argparse.Namespace:
"""Parses command line args.
Returns:
An argparse.Namespace containing all parsed known arguments.
"""
parser = argparse.ArgumentParser()
parser.add_argument('--no-clean',
action='store_true',
help='Do not clean up the workdir after evaluation.')
parser.add_argument(
'--sandbox',
default=False,
action=argparse.BooleanOptionalAction,
help='Use a sandbox for running gemini-cli. This should only be '
'disabled for local testing.',
)
parser.add_argument('--force',
'-f',
action='store_true',
help='Force execution, deleting existing workdirs if '
'they exist.')
parser.add_argument('--verbose',
'-v',
action='store_true',
help='Print debug information.')
parser.add_argument(
'--print-output-on-success',
action='store_true',
help=('Print test output even when a test succeeds. By default, '
'output is only surfaced when a test fails.'))
parser.add_argument('--filter',
help='Only run configs that contain this substring.')
parser.add_argument(
'--parallel-workers',
type=int,
default=1,
help=('The number of parallel workers to run tests in. Changing this '
'is not recommended if the Chromium checkout being used is not '
'on btrfs.'))
parser.add_argument(
'--shard-index',
type=int,
help=(f'The index of the current shard. If set, --total-shards must '
f'also be set. Can also be set via {_SHARD_INDEX_ENV_VAR}.'))
parser.add_argument(
'--total-shards',
type=int,
help=(f'The total number of shards used to run these tests. If set, '
f'--shard-index must also be set. Can also be set via '
f'{_TOTAL_SHARDS_ENV_VAR}.'))
parser.add_argument('--no-build',
action='store_true',
help='Do not build out/Default.')
parser.add_argument('--retries',
type=int,
default=0,
help='Number of times to retry a failed test.')
promptfoo_install_group = parser.add_mutually_exclusive_group()
promptfoo_install_group.add_argument(
'--install-promptfoo-from-npm',
metavar='VERSION',
nargs='?',
dest='promptfoo_version',
const='latest',
help=('Install promptfoo through npm. If no release version is given, '
'latest will be used.'))
promptfoo_install_group.add_argument(
'--install-promptfoo-from-src',
metavar='REVISION',
nargs='?',
dest='promptfoo_revision',
const='main',
help=('Build promptfoo from the given source revision. If no revision '
'is specified, ToT will be used.'))
return parser.parse_args()


def main() -> int:
"""Evaluates prompts using promptfoo.
This will get a copy of promptfoo and create clean checkouts before running
tests.
"""
args = _parse_args()
logging.basicConfig(
level=logging.DEBUG if args.verbose else logging.INFO,
format='%(message)s',
)
return _run_prompt_eval_tests(args)


if __name__ == '__main__':
sys.exit(main())