blob: 3f7d5f6de85456bf380b55cbbacae0ec6ecf1bfc [file] [log] [blame]
#!/usr/bin/env vpython3
# Copyright 2025 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Linter for promptfoo.yaml files."""
import argparse
import os
import sys
import yaml
def _get_chromium_src_path():
"""Returns the path to the Chromium src directory."""
# This script is in chromium/agents/testing, so three levels up is the src
# root.
return os.path.dirname(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
def _check_extension_reference(extension_name, test_case_path):
"""Checks a single extension reference."""
errors = []
if not isinstance(extension_name, str):
errors.append(f'{test_case_path} contains a non-string extension '
f'reference: {extension_name}')
return errors
if extension_name.startswith('file://'):
msg = (f'{test_case_path} contains a file path for an extension. '
'Please use the extension name instead: '
f'{extension_name}')
errors.append(msg)
return errors
chromium_src_path = _get_chromium_src_path()
extension_path = os.path.join(chromium_src_path, 'agents', 'extensions',
extension_name)
if not os.path.exists(extension_path):
msg = (f'{test_case_path} refers to a non-existent extension: '
f'{extension_name}')
errors.append(msg)
return errors
def _check_file_reference(file_url, test_case_path):
"""Checks a single file reference."""
errors = []
if not isinstance(file_url, str):
errors.append(f'{test_case_path} contains a non-string file '
f'reference: {file_url}')
return errors
# The file URL should be a path relative to chromium/src.
file_dir = file_url.removeprefix('file://')
chromium_src_path = _get_chromium_src_path()
# The file path from yaml may contain forward slashes, which is not the
# native path separator on Windows. We need to split the path and join it
# back to get a path with the correct separators.
abs_path = os.path.join(chromium_src_path, *file_dir.split('/'))
if not os.path.exists(abs_path):
msg = (f'{test_case_path} refers to a non-existent file: '
f'{file_dir}')
errors.append(msg)
return errors
def _check_provider_config(config, test_case_path):
    """Returns the problems found in a single provider's "config" mapping."""
    found = []

    # providers.config.changes: list of dicts; each "apply" names a file.
    changes = config.get('changes')
    if changes is not None:
        if not isinstance(changes, list):
            found.append(f'{test_case_path} "changes" field must be a '
                         'list.')
        else:
            for change in changes:
                if not isinstance(change, dict):
                    found.append(f'{test_case_path} "changes" items '
                                 'must be dicts.')
                elif 'apply' in change:
                    found.extend(
                        _check_file_reference(change['apply'],
                                              test_case_path))

    # providers.config.templates: list of file references.
    templates = config.get('templates')
    if templates is not None:
        if not isinstance(templates, list):
            found.append(f'{test_case_path} "templates" field must be '
                         'a list.')
        else:
            for template in templates:
                found.extend(_check_file_reference(template, test_case_path))

    # providers.config.extensions: list of names under //agents/extensions/.
    extensions = config.get('extensions')
    if extensions is not None:
        if not isinstance(extensions, list):
            found.append(f'{test_case_path} "extensions" field must '
                         'be a list.')
        else:
            for extension in extensions:
                found.extend(
                    _check_extension_reference(extension, test_case_path))

    return found


def check_test_case(data, test_case_path):
    """Validates the parsed contents of a promptfoo.yaml file.

    For every provider that carries a "config" mapping, this verifies that:
      1. each providers.config.changes[*].apply names an existing file;
      2. providers.config.templates is a list of existing files;
      3. providers.config.extensions is a list of extensions found in
         //agents/extensions/.

    Args:
      data: The parsed YAML document.
      test_case_path: Path of the file, used to prefix every message.

    Returns:
      A list of error strings; empty when nothing is wrong.
    """
    if not isinstance(data, dict):
        return [f'{test_case_path} must be a dictionary.']

    providers = data.get('providers')
    if not providers:
        return [f'{test_case_path} must contain at least one provider.']
    if not isinstance(providers, list):
        return [f'{test_case_path} "providers" field must be a list.']

    problems = []
    for entry in providers:
        if not isinstance(entry, dict):
            problems.append(f'{test_case_path} "providers" field must be a '
                            'list of dicts.')
            continue
        cfg = entry.get('config')
        if cfg is None:
            # Providers without a config have nothing further to validate.
            continue
        if isinstance(cfg, dict):
            problems.extend(_check_provider_config(cfg, test_case_path))
        else:
            problems.append(f'{test_case_path} "providers" field must have a '
                            'dict "config" field.')
    return problems
def main(argv):
    """Entrypoint for the linter script.

    Args:
      argv: The full command line, with the program name at argv[0] followed
        by one or more promptfoo.yaml paths.

    Returns:
      0 if every file was read, parsed and passed all checks; otherwise 1.
      Every error found is printed to stderr, one per line.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('files',
                        nargs='+',
                        help='promptfoo.yaml files to lint.')
    args = parser.parse_args(argv[1:])

    all_errors = []
    for f_path in args.files:
        try:
            with open(f_path, 'r', encoding='utf-8') as f:
                data = yaml.safe_load(f)
        except yaml.YAMLError as e:
            all_errors.append(f'Invalid YAML in {f_path}: {e}')
            continue
        except (OSError, UnicodeDecodeError) as e:
            # The file is opened in text mode, so bytes are decoded while
            # yaml.safe_load() reads it. Invalid UTF-8 raises
            # UnicodeDecodeError, which is neither an OSError nor a
            # yaml.YAMLError; report it instead of crashing the linter.
            all_errors.append(f'Could not read file {f_path}: {e}')
            continue
        try:
            all_errors.extend(check_test_case(data, f_path))
        except Exception as e:  # pylint: disable=broad-except
            # Broad exception for unexpected data structures.
            all_errors.append(
                f'Error linting {f_path}: {e}. This may be from a '
                'malformed file.')

    if all_errors:
        for error in all_errors:
            # The presubmit wrapper will show this to the user.
            print(error, file=sys.stderr)
        return 1
    return 0
if __name__ == '__main__':
    # Use main()'s return value (0 when clean, 1 when errors were reported)
    # as the process exit status.
    exit_status = main(sys.argv)
    sys.exit(exit_status)