| #!/usr/bin/env vpython3 |
| # Copyright 2025 The Chromium Authors |
| # Use of this source code is governed by a BSD-style license that can be |
| # found in the LICENSE file. |
| """Tests for workers.""" |
| |
| import json |
| import pathlib |
| import queue |
| import shutil |
| import subprocess |
| import time |
| import unittest |
| from unittest import mock |
| |
| from pyfakefs import fake_filesystem_unittest |
| |
import eval_config
import promptfoo_installation
import results
import workers
| |
| # pylint: disable=protected-access |
| |
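# Short interval substituted for the workers module's polling sleeps to keep
# busy-wait loops in these tests fast.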
| _POLLING_INTERVAL = 0.001 |
| |
| |
| class WorkDirUnittest(fake_filesystem_unittest.TestCase): |
| """Unit tests for the WorkDir class.""" |
| |
| def setUp(self): |
| self.setUpPyfakefs() |
| self.fs.create_dir('/tmp/src') |
| self._setUpPatches() |
| |
| def _setUpPatches(self): |
| """Set up patches for the tests.""" |
| rmtree_patcher = mock.patch('shutil.rmtree') |
| self.mock_rmtree = rmtree_patcher.start() |
| self.addCleanup(rmtree_patcher.stop) |
| |
| call_patcher = mock.patch('subprocess.call') |
| self.mock_call = call_patcher.start() |
| self.addCleanup(call_patcher.stop) |
| |
| check_call_patcher = mock.patch('subprocess.check_call') |
| self.mock_check_call = check_call_patcher.start() |
| self.addCleanup(check_call_patcher.stop) |
| |
| check_btrfs_patcher = mock.patch('checkout_helpers.check_btrfs') |
| self.mock_check_btrfs = check_btrfs_patcher.start() |
| self.addCleanup(check_btrfs_patcher.stop) |
| |
| def test_enter_btrfs(self): |
| """Tests that a btrfs snapshot is created when btrfs is true.""" |
| self.mock_check_btrfs.return_value = True |
| workdir = workers.WorkDir('workdir', |
| pathlib.Path('/tmp/src'), |
| clean=False, |
| verbose=False, |
| force=False) |
| with workdir as w: |
| self.assertEqual(w, workdir) |
| |
| self.mock_check_call.assert_called_once_with( |
| [ |
| 'btrfs', |
| 'subvol', |
| 'snapshot', |
| pathlib.Path('/tmp/src'), |
| pathlib.Path('/tmp/workdir'), |
| ], |
| stdout=subprocess.DEVNULL, |
| stderr=subprocess.STDOUT, |
| ) |
| |
| def test_enter_no_btrfs(self): |
| """Tests that gclient-new-workdir is called when btrfs is false.""" |
| self.mock_check_btrfs.return_value = False |
| workdir = workers.WorkDir('workdir', |
| pathlib.Path('/tmp/src'), |
| clean=False, |
| verbose=False, |
| force=False) |
| with workdir as w: |
| self.assertEqual(w, workdir) |
| |
| self.mock_check_call.assert_called_once_with( |
| [ |
| 'gclient-new-workdir.py', |
| pathlib.Path('/tmp/src'), |
| pathlib.Path('/tmp/workdir'), |
| ], |
| stdout=subprocess.DEVNULL, |
| stderr=subprocess.STDOUT, |
| ) |
| |
| def test_enter_verbose(self): |
| """Tests that verbose logging is enabled when verbose is true.""" |
| self.mock_check_btrfs.return_value = False |
| workdir = workers.WorkDir('workdir', |
| pathlib.Path('/tmp/src'), |
| clean=False, |
| verbose=True, |
| force=False) |
| with workdir as w: |
| self.assertEqual(w, workdir) |
| |
| self.mock_check_call.assert_called_once_with( |
| [ |
| 'gclient-new-workdir.py', |
| pathlib.Path('/tmp/src'), |
| pathlib.Path('/tmp/workdir'), |
| ], |
| stdout=None, # Output surfaced instead of going to DEVNULL. |
| stderr=subprocess.STDOUT, |
| ) |
| |
| def test_enter_exists(self): |
| """Tests that the workdir is removed if it exists.""" |
| self.fs.create_dir('/tmp/workdir') |
| self.mock_check_btrfs.return_value = True |
| workdir = workers.WorkDir('workdir', |
| pathlib.Path('/tmp/src'), |
| clean=False, |
| verbose=False, |
| force=True) |
| with workdir: |
| pass |
| |
| self.mock_call.assert_called_once_with( |
| [ |
| 'sudo', |
| '-n', |
| 'btrfs', |
| 'subvolume', |
| 'delete', |
| pathlib.Path('/tmp/workdir'), |
| ], |
| stdout=subprocess.DEVNULL, |
| stderr=subprocess.STDOUT, |
| ) |
| |
| def test_exit_clean_btrfs(self): |
| """Tests that the workdir is removed when clean is true w/ btrfs .""" |
| self.mock_check_btrfs.return_value = True |
| workdir = workers.WorkDir('workdir', |
| pathlib.Path('/tmp/src'), |
| clean=True, |
| verbose=False, |
| force=False) |
| with workdir: |
| pass |
| |
| self.mock_call.assert_called_once_with( |
| [ |
| 'sudo', |
| 'btrfs', |
| 'subvolume', |
| 'delete', |
| pathlib.Path('/tmp/workdir'), |
| ], |
| stdout=subprocess.DEVNULL, |
| stderr=subprocess.STDOUT, |
| ) |
| |
| def test_exit_clean_no_btrfs(self): |
| """Tests that the workdir is removed when clean is True w/o btrfs.""" |
| self.mock_check_btrfs.return_value = False |
| workdir = workers.WorkDir('workdir', |
| pathlib.Path('/tmp/src'), |
| clean=True, |
| verbose=False, |
| force=False) |
| with workdir: |
| pass |
| |
| self.mock_rmtree.assert_called_once_with(pathlib.Path('/tmp/workdir')) |
| |
| def test_exit_no_clean(self): |
| """Tests that the workdir is not cleaned up when clean is False.""" |
| self.mock_check_btrfs.return_value = False |
| workdir = workers.WorkDir('workdir', |
| pathlib.Path('/tmp/src'), |
| clean=False, |
| verbose=False, |
| force=False) |
| with workdir: |
| pass |
| |
| self.mock_call.assert_not_called() |
        self.mock_rmtree.assert_not_called()

    def test_exit_clean_btrfs_fallback(self):
| """Tests that shutil is used when btrfs subvolume delete fails.""" |
| self.mock_check_btrfs.return_value = True |
| self.mock_call.return_value = 1 |
| workdir = workers.WorkDir('workdir', |
| pathlib.Path('/tmp/src'), |
| clean=True, |
| verbose=False, |
| force=False) |
| with workdir: |
| pass |
| |
| self.mock_call.assert_called_once_with( |
| [ |
| 'sudo', |
| 'btrfs', |
| 'subvolume', |
| 'delete', |
| pathlib.Path('/tmp/workdir'), |
| ], |
| stdout=subprocess.DEVNULL, |
| stderr=subprocess.STDOUT, |
| ) |
| self.mock_rmtree.assert_called_once_with(pathlib.Path('/tmp/workdir')) |
| |
| |
| class ExtractMetricsUnittest(fake_filesystem_unittest.TestCase): |
| """Unit tests for the _extract_metrics_from_promptfoo_results.""" |
| |
| def setUp(self): |
| self.setUpPyfakefs() |
| |
| def test_success(self): |
| """Tests a successful extraction.""" |
| results_data = { |
| 'results': { |
| 'results': [ |
| { |
| 'score': 0.5, |
| 'response': { |
| 'metrics': { |
| 'gemini_cli_token_usage': { |
| 'total_tokens': 10, |
| }, |
| }, |
| }, |
| }, |
| ], |
| }, |
| } |
| metrics = workers._extract_metrics_from_promptfoo_results(results_data) |
| self.assertEqual(metrics, { |
| 'token_usage': { |
| 'total_tokens': 10 |
| }, |
| 'score': 0.5 |
| }) |
| |
| def test_no_score(self): |
| """Tests when the score is missing.""" |
| results_data = { |
| 'results': { |
| 'results': [ |
| { |
| 'response': { |
| 'metrics': { |
| 'gemini_cli_token_usage': { |
| 'total_tokens': 10, |
| }, |
| }, |
| }, |
| }, |
| ], |
| }, |
| } |
| metrics = workers._extract_metrics_from_promptfoo_results(results_data) |
| self.assertEqual(metrics, {'token_usage': {'total_tokens': 10}}) |
| |
| def test_no_token_usage(self): |
| """Tests when token usage is missing.""" |
| results_data = { |
| 'results': { |
| 'results': [ |
| { |
| 'score': 0.5, |
| 'response': { |
| 'metrics': {}, |
| }, |
| }, |
| ], |
| }, |
| } |
| metrics = workers._extract_metrics_from_promptfoo_results(results_data) |
| self.assertEqual(metrics, {'token_usage': {}, 'score': 0.5}) |
| |
| def test_empty_results(self): |
| """Tests when the results file is empty.""" |
| metrics = workers._extract_metrics_from_promptfoo_results({}) |
| self.assertEqual(metrics, {}) |
| |
| |
| class ParseTestLogResultsTest(unittest.TestCase): |
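    """Unit tests for _parse_test_log_results."""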
| |
| def test_empty_json(self): |
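        """Tests None input and an empty JSON dict."""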
| self.assertEqual(workers._parse_test_log_results(None), '') |
| self.assertEqual(workers._parse_test_log_results({}), |
| 'No results found in promptfoo output.') |
| |
| def test_empty_results_list(self): |
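        """Tests when the nested results list is empty."""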
| json_data = {'results': {'results': []}} |
| self.assertEqual(workers._parse_test_log_results(json_data), |
| 'No results found in promptfoo output.') |
| |
| def test_missing_keys(self): |
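        """Tests results with missing or empty keys."""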
| json_data = {'results': {'results': [{}]}} |
| expected = ('Input prompt: None\n' |
| 'Response: None\n' |
| 'Assertion results:\n') |
| self.assertEqual(workers._parse_test_log_results(json_data), expected) |
| |
| json_data = {'results': {'results': [{'gradingResult': {}}]}} |
| self.assertEqual(workers._parse_test_log_results(json_data), expected) |
| |
| json_data = { |
| 'results': { |
| 'results': [{ |
| 'gradingResult': { |
| 'componentResults': [] |
| } |
| }] |
| } |
| } |
| self.assertEqual(workers._parse_test_log_results(json_data), expected) |
| |
| def test_full_json(self): |
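        """Tests a fully populated result with assertion details."""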
        json_data = {
            'results': {
                'results': [{
                    'gradingResult': {
                        'componentResults': [{
                            'pass': True,
                            'reason': 'Looks good',
                            'score': 1.0
                        }, {
                            'pass': False,
                            'reason': 'Not so good',
                            'score': 0.0
                        }]
                    },
                    'response': {
                        'metrics': {
                            'user_prompt': 'This is the prompt.',
                            'full_output': 'This is the output.'
                        }
                    }
                }]
            }
        }
        expected_output = ('Input prompt: This is the prompt.\n'
                           'Response: This is the output.\n'
                           'Assertion results:\n'
                           'pass: True\nreason: Looks good\nscore: 1.0\n\n'
                           'pass: False\nreason: Not so good\nscore: 0.0\n\n')
        self.assertEqual(workers._parse_test_log_results(json_data),
                         expected_output)
| |
| def test_no_component_results(self): |
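        """Tests a result whose componentResults list is empty."""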
        json_data = {
            'results': {
                'results': [{
                    'gradingResult': {
                        'componentResults': []
                    },
                    'response': {
                        'metrics': {
                            'user_prompt': 'This is the prompt.',
                            'full_output': 'This is the output.'
                        }
                    }
                }]
            }
        }
        expected_output = ('Input prompt: This is the prompt.\n'
                           'Response: This is the output.\n'
                           'Assertion results:\n')
        self.assertEqual(workers._parse_test_log_results(json_data),
                         expected_output)
| |
| |
| class LoadPromptfooResultsUnittest(fake_filesystem_unittest.TestCase): |
| """Unit tests for the _load_promptfoo_results.""" |
| |
| def setUp(self): |
| self.setUpPyfakefs() |
| |
| def test_success(self): |
| """Tests a successful load.""" |
| results_data = { |
| 'results': { |
| 'results': [], |
| }, |
| } |
| results_content = json.dumps(results_data) |
| results_file = pathlib.Path('/results.json') |
| self.fs.create_file(results_file, contents=results_content) |
| data = workers._load_promptfoo_results(results_file) |
| self.assertEqual(data, results_data) |
| |
| def test_invalid_json(self): |
| """Tests with invalid JSON content.""" |
| results_file = pathlib.Path('/results.json') |
| self.fs.create_file(results_file, contents='{invalid json') |
| with self.assertLogs(level='ERROR') as cm: |
| data = workers._load_promptfoo_results(results_file) |
| self.assertIn('Error when parsing promptfoo results', cm.output[0]) |
| self.assertEqual(data, {}) |
| |
| def test_unicode_error(self): |
| """Tests with invalid unicode content.""" |
| results_file = pathlib.Path('/results.json') |
| with open(results_file, 'wb') as f: |
| f.write(b'\x80') |
| with self.assertLogs(level='ERROR') as cm: |
| data = workers._load_promptfoo_results(results_file) |
| self.assertIn('Error when parsing promptfoo results', cm.output[0]) |
| self.assertEqual(data, {}) |
| |
| |
| class ExtractTokenUsageUnittest(unittest.TestCase): |
| """Unit tests for the _extract_token_usage_from_promptfoo_results.""" |
| |
| def test_success(self): |
| """Tests a successful extraction.""" |
| results_data = { |
| 'results': { |
| 'results': [ |
| { |
| 'response': { |
| 'metrics': { |
| 'gemini_cli_token_usage': { |
| 'total_tokens': 10, |
| 'prompt_tokens': 5, |
| 'completion_tokens': 5, |
| }, |
| }, |
| }, |
| }, |
| ], |
| }, |
| } |
| token_usage = workers._extract_token_usage_from_promptfoo_results( |
| results_data) |
| self.assertEqual(token_usage, { |
| 'total_tokens': 10, |
| 'prompt_tokens': 5, |
| 'completion_tokens': 5, |
| }) |
| |
| def test_no_results_key(self): |
| """Tests when the top-level 'results' key is missing.""" |
| with self.assertLogs(level='ERROR') as cm: |
| token_usage = workers._extract_token_usage_from_promptfoo_results( |
| {}) |
| self.assertIn('Did not find promptfoo result information', |
| cm.output[0]) |
| self.assertEqual(token_usage, {}) |
| |
| def test_no_nested_results_key(self): |
| """Tests when the nested 'results' key is missing.""" |
| with self.assertLogs(level='ERROR') as cm: |
| token_usage = workers._extract_token_usage_from_promptfoo_results({ |
| 'results': {}, |
| }) |
| self.assertIn('Did not find promptfoo result information', |
| cm.output[0]) |
| self.assertEqual(token_usage, {}) |
| |
| def test_empty_results_list(self): |
| """Tests when the results list is empty.""" |
| results_data = { |
| 'results': { |
| 'results': [], |
| }, |
| } |
| with self.assertLogs(level='ERROR') as cm: |
| token_usage = workers._extract_token_usage_from_promptfoo_results( |
| results_data) |
| self.assertIn('Did not find promptfoo result information', |
| cm.output[0]) |
| self.assertEqual(token_usage, {}) |
| |
| def test_multiple_results(self): |
| """Tests that only the first result is used when there are many.""" |
| results_data = { |
| 'results': { |
| 'results': [ |
| { |
| 'response': { |
| 'metrics': { |
| 'gemini_cli_token_usage': { |
| 'total_tokens': 10, |
| }, |
| }, |
| }, |
| }, |
| { |
| 'response': { |
| 'metrics': { |
| 'gemini_cli_token_usage': { |
| 'total_tokens': 20, |
| }, |
| }, |
| }, |
| }, |
| ], |
| }, |
| } |
| with self.assertLogs(level='WARNING') as cm: |
| token_usage = workers._extract_token_usage_from_promptfoo_results( |
| results_data) |
| self.assertIn('Unexpectedly got 2 results', cm.output[0]) |
| self.assertEqual(token_usage, {'total_tokens': 10}) |
| |
| def test_no_response_key(self): |
| """Tests when the 'response' key is missing.""" |
| results_data = { |
| 'results': { |
| 'results': [ |
| {}, |
| ], |
| }, |
| } |
| with self.assertLogs(level='WARNING') as cm: |
| token_usage = workers._extract_token_usage_from_promptfoo_results( |
| results_data) |
| self.assertIn('Did not find gemini-cli token usage', cm.output[0]) |
| self.assertEqual(token_usage, {}) |
| |
| def test_no_metrics_key(self): |
| """Tests when the 'metrics' key is missing.""" |
| results_data = { |
| 'results': { |
| 'results': [ |
| { |
| 'response': {}, |
| }, |
| ], |
| }, |
| } |
| with self.assertLogs(level='WARNING') as cm: |
| token_usage = workers._extract_token_usage_from_promptfoo_results( |
| results_data) |
| self.assertIn('Did not find gemini-cli token usage', cm.output[0]) |
| self.assertEqual(token_usage, {}) |
| |
| def test_no_token_usage_key(self): |
| """Tests when the 'gemini_cli_token_usage' key is missing.""" |
| results_data = { |
| 'results': { |
| 'results': [ |
| { |
| 'response': { |
| 'metrics': {}, |
| }, |
| }, |
| ], |
| }, |
| } |
| with self.assertLogs(level='WARNING') as cm: |
| token_usage = workers._extract_token_usage_from_promptfoo_results( |
| results_data) |
| self.assertIn('Did not find gemini-cli token usage', cm.output[0]) |
| self.assertEqual(token_usage, {}) |
| |
| def test_empty_token_usage_dict(self): |
| """Tests when the token usage dict is empty.""" |
| results_data = { |
| 'results': { |
| 'results': [ |
| { |
| 'response': { |
| 'metrics': { |
| 'gemini_cli_token_usage': {}, |
| }, |
| }, |
| }, |
| ], |
| }, |
| } |
| with self.assertLogs(level='WARNING') as cm: |
| token_usage = workers._extract_token_usage_from_promptfoo_results( |
| results_data) |
| self.assertIn('Did not find gemini-cli token usage', cm.output[0]) |
| self.assertEqual(token_usage, {}) |
| |
| |
| class ExtractScoreUnittest(unittest.TestCase): |
| """Unit tests for the _extract_score_from_promptfoo_results.""" |
| |
| def test_success(self): |
| """Tests a successful extraction.""" |
| results_data = { |
| 'results': { |
| 'results': [ |
| { |
| 'score': 0.5, |
| }, |
| ], |
| }, |
| } |
| score = workers._extract_score_from_promptfoo_results(results_data) |
| self.assertEqual(score, 0.5) |
| |
| def test_no_score(self): |
| """Tests when the score is missing.""" |
| results_data = { |
| 'results': { |
| 'results': [ |
| {}, |
| ], |
| }, |
| } |
| with self.assertLogs(level='WARNING') as cm: |
| score = workers._extract_score_from_promptfoo_results(results_data) |
| self.assertIn('Did not find reported score', cm.output[0]) |
| self.assertIsNone(score) |
| |
| def test_multiple_results(self): |
| """Tests that only the first result is used when there are many.""" |
| results_data = { |
| 'results': { |
| 'results': [ |
| { |
| 'score': 0.5, |
| }, |
| { |
| 'score': 1.0, |
| }, |
| ], |
| }, |
| } |
| with self.assertLogs(level='WARNING') as cm: |
| score = workers._extract_score_from_promptfoo_results(results_data) |
| self.assertIn('Unexpectedly got 2 results', cm.output[0]) |
| self.assertEqual(score, 0.5) |
| |
| def test_empty_results(self): |
| """Tests when the results list is empty.""" |
| results_data = { |
| 'results': { |
| 'results': [], |
| }, |
| } |
| with self.assertLogs(level='ERROR') as cm: |
| score = workers._extract_score_from_promptfoo_results(results_data) |
| self.assertIn('Did not find promptfoo result information', |
| cm.output[0]) |
| self.assertIsNone(score) |
| |
| |
| class WorkerThreadUnittest(unittest.TestCase): |
| """Unit tests for the WorkerThread class.""" |
| |
| def setUp(self): |
| self._setUpMocks() |
| self._setUpPatches() |
| |
| def _setUpMocks(self): |
| """Set up mocks for the tests.""" |
| self.mock_promptfoo = mock.Mock( |
| spec=promptfoo_installation.PromptfooInstallation) |
| self.mock_promptfoo.run.return_value = subprocess.CompletedProcess( |
| args=[], returncode=0, stdout='Success') |
| self.worker_options = workers.WorkerOptions( |
| clean=True, |
| verbose=False, |
| force=False, |
| sandbox=False, |
| gemini_cli_bin=None, |
| ) |
| self.test_input_queue = queue.Queue() |
| self.test_result_queue = queue.Queue() |
| |
| def _setUpPatches(self): |
| """Set up patches for the tests.""" |
| workdir_patcher = mock.patch('workers.WorkDir') |
| self.mock_workdir = workdir_patcher.start() |
| mock_workdir_instance = ( |
| self.mock_workdir.return_value.__enter__.return_value) |
| mock_workdir_instance.path = pathlib.Path('/workdir') |
| self.addCleanup(workdir_patcher.stop) |
| |
| polling_patcher = mock.patch( |
| 'workers._AVAILABLE_TEST_POLLING_SLEEP_DURATION', |
| _POLLING_INTERVAL) |
| polling_patcher.start() |
| self.addCleanup(polling_patcher.stop) |
| |
| get_gclient_root_patcher = mock.patch( |
| 'checkout_helpers.get_gclient_root') |
| self.mock_get_gclient_root = get_gclient_root_patcher.start() |
| self.mock_get_gclient_root.return_value = pathlib.Path('/root') |
| self.addCleanup(get_gclient_root_patcher.stop) |
| |
| def _create_and_run_worker(self, configs): |
| """Helper to create and run a worker thread.""" |
| for config in configs: |
| self.test_input_queue.put(config) |
| |
| worker = workers.WorkerThread( |
| worker_index=0, |
| promptfoo=self.mock_promptfoo, |
| worker_options=self.worker_options, |
| test_input_queue=self.test_input_queue, |
| test_result_queue=self.test_result_queue, |
| ) |
| worker.start() |
| |
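        # Busy-wait until every queued config has produced a result,
        # re-raising any fatal exception hit by the worker thread.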
| while self.test_result_queue.qsize() < len(configs): |
| worker.maybe_reraise_fatal_exception() |
| time.sleep(_POLLING_INTERVAL) |
| |
| worker.shutdown() |
| worker.join(1) |
| return worker |
| |
| def test_run_one_test_pass(self): |
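        """Tests running a single passing test."""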
| config = eval_config.TestConfig(test_file=pathlib.Path('/test/a.yaml')) |
| self._create_and_run_worker([config]) |
| |
| self.mock_workdir.assert_called_once_with('workdir-0', |
| pathlib.Path('/root'), True, |
| False, False) |
| self.mock_promptfoo.run.assert_called_once() |
| self.assertEqual(self.test_result_queue.qsize(), 1) |
| result = self.test_result_queue.get() |
| self.assertEqual(result.config.test_file, config.test_file) |
| self.assertTrue(result.success) |
| |
| def test_run_one_test_fail(self): |
| """Tests running a single failing test.""" |
| self.mock_promptfoo.run.return_value = subprocess.CompletedProcess( |
| args=[], returncode=1, stdout='Failure') |
| config = eval_config.TestConfig(test_file=pathlib.Path('/test/a.yaml')) |
| self._create_and_run_worker([config]) |
| |
| self.assertEqual(self.test_result_queue.qsize(), 1) |
| result = self.test_result_queue.get() |
| self.assertEqual(result.config.test_file, config.test_file) |
| self.assertFalse(result.success) |
| |
| def test_run_multiple_tests(self): |
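        """Tests running multiple tests on a single worker."""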
| configs = [ |
| eval_config.TestConfig(test_file=pathlib.Path('/test/a.yaml')), |
| eval_config.TestConfig(test_file=pathlib.Path('/test/b.yaml')) |
| ] |
| self._create_and_run_worker(configs) |
| |
| self.assertEqual(self.mock_workdir.call_count, 2) |
| self.assertEqual(self.mock_promptfoo.run.call_count, 2) |
| self.assertEqual(self.test_result_queue.qsize(), 2) |
| |
| def test_shutdown(self): |
| """Tests that the worker thread shuts down gracefully.""" |
| worker = self._create_and_run_worker([]) |
| self.assertFalse(worker.is_alive()) |
| |
| def test_fatal_exception(self): |
| """Tests that fatal exceptions are propagated.""" |
| worker = self._create_and_run_worker([]) |
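        # Re-run the already-finished worker with a failing main loop so
        # run() records the error as the thread's fatal exception.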
| with mock.patch.object(worker, |
| '_run_incoming_tests_until_shutdown', |
| side_effect=ValueError('Test Error')): |
| worker.run() |
| |
| with self.assertRaisesRegex(ValueError, 'Test Error'): |
| worker.maybe_reraise_fatal_exception() |
| |
| def test_no_fatal_exception(self): |
| """Tests that no exception is raised when there is no fatal error.""" |
| worker = self._create_and_run_worker([]) |
| # Should be a no-op. |
| worker.maybe_reraise_fatal_exception() |
| |
| def test_sandbox_and_verbose(self): |
| """Tests that sandbox and verbose flags are passed to promptfoo.""" |
| self.worker_options.sandbox = True |
| self.worker_options.verbose = True |
| self._create_and_run_worker( |
| [eval_config.TestConfig(test_file=pathlib.Path('/test/a.yaml'))]) |
| |
| self.mock_promptfoo.run.assert_called_once() |
| command = self.mock_promptfoo.run.call_args[0][0] |
| self.assertIn('--var', command) |
| self.assertIn('sandbox=True', command) |
| self.assertIn('verbose=True', command) |
| self.assertIn(f'console_width={shutil.get_terminal_size().columns}', |
| command) |
| |
| def test_gemini_cli_bin(self): |
| """Tests that gemini_cli_bin is passed to promptfoo.""" |
| gemini_cli_bin = pathlib.Path('/', 'custom', 'gemini') |
| self.worker_options.gemini_cli_bin = gemini_cli_bin |
| self._create_and_run_worker( |
| [eval_config.TestConfig(test_file=pathlib.Path('/test/a.yaml'))]) |
| |
| self.mock_promptfoo.run.assert_called_once() |
| command = self.mock_promptfoo.run.call_args[0][0] |
| self.assertIn('--var', command) |
| self.assertIn(f'gemini_cli_bin={gemini_cli_bin}', command) |
| |
| |
class RunOneConfigTest(WorkerThreadUnittest):
    """Tests for the _run_one_config method in workers.py."""
| |
| def test_aggregation(self): |
| """Tests that all metrics are aggregated correctly.""" |
| config = eval_config.TestConfig(test_file=pathlib.Path('/test/a.yaml'), |
| runs_per_test=3, |
| pass_k_threshold=2) |
| results_to_return = [ |
| results.IterationResult( |
| success=True, |
| duration=1.0, |
| test_log='log1', |
| metrics={ |
| 'token_usage': { |
| 'total': 10 |
| }, |
| }, |
| prompt=None, |
| response=None, |
| ), |
| results.IterationResult( |
| success=False, |
| duration=1.5, |
| test_log='log2', |
| metrics={ |
| 'token_usage': { |
| 'total': 5 |
| }, |
| }, |
| prompt=None, |
| response=None, |
| ), |
| results.IterationResult( |
| success=True, |
| duration=2.0, |
| test_log='log3', |
| metrics={ |
| 'token_usage': { |
| 'total': 15 |
| }, |
| }, |
| prompt=None, |
| response=None, |
| ), |
| ] |
| with mock.patch.object(workers.WorkerThread, |
| '_run_single_iteration', |
| side_effect=results_to_return): |
| self._create_and_run_worker([config]) |
| |
| self.assertEqual(self.test_result_queue.qsize(), 1) |
| result = self.test_result_queue.get() |
| self.assertTrue(result.success) |
| self.assertEqual(result.successful_runs, 2) |
| self.assertEqual(result.total_duration, 4.5) |
| self.assertEqual(result.average_duration, 1.5) |
| self.assertEqual( |
| result.combined_logs, 'Iteration #0:\nlog1\n' |
| 'Iteration #1:\nlog2\n' |
| 'Iteration #2:\nlog3') |
| |
| def test_success_criteria_pass(self): |
| """Tests that a test is marked as successful when it passes.""" |
| config = eval_config.TestConfig(test_file=pathlib.Path('/test/a.yaml'), |
| runs_per_test=3, |
| pass_k_threshold=2) |
| results_to_return = [ |
| results.IterationResult( |
| success=True, |
| duration=1, |
| test_log='', |
| metrics={}, |
| prompt=None, |
| response=None, |
| ), |
| results.IterationResult( |
| success=False, |
| duration=1, |
| test_log='', |
| metrics={}, |
| prompt=None, |
| response=None, |
| ), |
| results.IterationResult( |
| success=True, |
| duration=1, |
| test_log='', |
| metrics={}, |
| prompt=None, |
| response=None, |
| ), |
| ] |
| with mock.patch.object(workers.WorkerThread, |
| '_run_single_iteration', |
| side_effect=results_to_return): |
| self._create_and_run_worker([config]) |
| |
| self.assertEqual(self.test_result_queue.qsize(), 1) |
| result = self.test_result_queue.get() |
| self.assertTrue(result.success) |
| self.assertEqual(result.successful_runs, 2) |
| |
| def test_success_criteria_fail(self): |
| """Tests that a test is marked as failed when it does not pass.""" |
| config = eval_config.TestConfig(test_file=pathlib.Path('/test/a.yaml'), |
| runs_per_test=3, |
| pass_k_threshold=3) |
| results_to_return = [ |
| results.IterationResult( |
| success=True, |
| duration=1, |
| test_log='', |
| metrics={}, |
| prompt=None, |
| response=None, |
| ), |
| results.IterationResult( |
| success=True, |
| duration=1, |
| test_log='', |
| metrics={}, |
| prompt=None, |
| response=None, |
| ), |
| results.IterationResult( |
| success=False, |
| duration=1, |
| test_log='', |
| metrics={}, |
| prompt=None, |
| response=None, |
| ), |
| ] |
| with mock.patch.object(workers.WorkerThread, |
| '_run_single_iteration', |
| side_effect=results_to_return): |
| self._create_and_run_worker([config]) |
| |
| self.assertEqual(self.test_result_queue.qsize(), 1) |
| result = self.test_result_queue.get() |
| self.assertFalse(result.success) |
| self.assertEqual(result.successful_runs, 2) |
| |
| def test_early_exit_on_pass(self): |
| """Tests that the test exits early when the pass threshold is met.""" |
| config = eval_config.TestConfig(test_file=pathlib.Path('/test/a.yaml'), |
| runs_per_test=5, |
| pass_k_threshold=2) |
| results_to_return = [ |
| results.IterationResult( |
| success=True, |
| duration=1, |
| test_log='', |
| metrics={}, |
| prompt=None, |
| response=None, |
| ), |
| results.IterationResult( |
| success=True, |
| duration=1, |
| test_log='', |
| metrics={}, |
| prompt=None, |
| response=None, |
| ), |
| ] |
| with mock.patch.object(workers.WorkerThread, |
| '_run_single_iteration', |
| side_effect=results_to_return) as mock_run: |
| self._create_and_run_worker([config]) |
| self.assertEqual(mock_run.call_count, 2) |
| |
| self.assertEqual(self.test_result_queue.qsize(), 1) |
| result = self.test_result_queue.get() |
| self.assertTrue(result.success) |
| |
| def test_early_exit_on_fail(self): |
| """Tests that the test exits early when it can no longer pass.""" |
| config = eval_config.TestConfig(test_file=pathlib.Path('/test/a.yaml'), |
| runs_per_test=5, |
| pass_k_threshold=3) |
| results_to_return = [ |
| results.IterationResult( |
| success=False, |
| duration=1, |
| test_log='', |
| metrics={}, |
| prompt=None, |
| response=None, |
| ), |
| results.IterationResult( |
| success=False, |
| duration=1, |
| test_log='', |
| metrics={}, |
| prompt=None, |
| response=None, |
| ), |
| results.IterationResult( |
| success=False, |
| duration=1, |
| test_log='', |
| metrics={}, |
| prompt=None, |
| response=None, |
| ), |
| ] |
| with mock.patch.object(workers.WorkerThread, |
| '_run_single_iteration', |
| side_effect=results_to_return) as mock_run: |
| self._create_and_run_worker([config]) |
| self.assertEqual(mock_run.call_count, 3) |
| |
| self.assertEqual(self.test_result_queue.qsize(), 1) |
| result = self.test_result_queue.get() |
| self.assertFalse(result.success) |
| |
| |
| class WorkerPoolUnittest(unittest.TestCase): |
| """Unit tests for the WorkerPool class.""" |
| |
| def setUp(self): |
| self._setUpMocks() |
| self._setUpPatches() |
| |
| def _setUpMocks(self): |
| """Set up mocks for the tests.""" |
| self.mock_promptfoo = mock.Mock( |
| spec=promptfoo_installation.PromptfooInstallation) |
| self.worker_options = workers.WorkerOptions( |
| clean=True, |
| verbose=False, |
| force=False, |
| sandbox=False, |
| ) |
| self.result_options = results.ResultOptions( |
| print_output_on_success=False, |
| result_handlers=[], |
| ) |
| |
| def _setUpPatches(self): |
| """Set up patches for the tests.""" |
| |
| def create_thread_join_side_effect(mock_thread): |
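            """Returns a join() that marks the given mock thread as dead."""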
| |
| def thread_join_side_effect(*args, **kwargs): |
| # pylint: disable=unused-argument |
| mock_thread.is_alive.return_value = False |
| |
| return thread_join_side_effect |
| |
| atomic_counter_patcher = mock.patch('workers.results.AtomicCounter') |
| self.mock_atomic_counter = atomic_counter_patcher.start() |
| self.addCleanup(atomic_counter_patcher.stop) |
| |
| result_thread_patcher = mock.patch('workers.results.ResultThread') |
| self.mock_result_thread = result_thread_patcher.start() |
| mock_result_thread_instance = self.mock_result_thread.return_value |
| mock_result_thread_instance.is_alive.return_value = True |
| mock_result_thread_instance.join.side_effect = ( |
| create_thread_join_side_effect(mock_result_thread_instance)) |
| mock_result_thread_instance.total_results_reported = ( |
| self.mock_atomic_counter.return_value) |
| mock_result_thread_instance.failed_result_output_queue = mock.Mock( |
| spec=queue.Queue) |
| self.addCleanup(result_thread_patcher.stop) |
| |
| worker_thread_patcher = mock.patch('workers.WorkerThread') |
| self.mock_worker_thread = worker_thread_patcher.start() |
| mock_worker_thread_instance = self.mock_worker_thread.return_value |
| mock_worker_thread_instance.is_alive.return_value = True |
| mock_worker_thread_instance.join.side_effect = ( |
| create_thread_join_side_effect(mock_worker_thread_instance)) |
| self.addCleanup(worker_thread_patcher.stop) |
| |
| polling_patcher = mock.patch( |
| 'workers._ALL_QUEUED_TESTS_RUN_POLLING_SLEEP_DURATION', |
| _POLLING_INTERVAL) |
| polling_patcher.start() |
| self.addCleanup(polling_patcher.stop) |
| |
| def test_create_pool(self): |
| """Tests that the pool is created with the correct number of workers.""" |
| pool = workers.WorkerPool( |
| num_workers=3, |
| promptfoo=self.mock_promptfoo, |
| worker_options=self.worker_options, |
| result_options=self.result_options, |
| ) |
| self.assertEqual(self.mock_worker_thread.call_count, 3) |
| self.mock_result_thread.assert_called_once() |
| pool.shutdown_blocking(1) |
| |
| def test_queue_tests(self): |
| """Tests that tests are queued correctly.""" |
| pool = workers.WorkerPool( |
| num_workers=1, |
| promptfoo=self.mock_promptfoo, |
| worker_options=self.worker_options, |
| result_options=self.result_options, |
| ) |
| configs = [ |
| eval_config.TestConfig(test_file=pathlib.Path('/test/a.yaml')), |
| eval_config.TestConfig(test_file=pathlib.Path('/test/b.yaml')) |
| ] |
| pool.queue_tests(configs) |
| self.assertEqual(pool._test_input_queue.qsize(), 2) |
| pool.shutdown_blocking(1) |
| |
| def test_wait_for_all_queued_tests(self): |
| """Tests that the pool waits for all tests to complete.""" |
| self.mock_atomic_counter.return_value.get.side_effect = [0, 1, 2] |
| pool = workers.WorkerPool( |
| num_workers=1, |
| promptfoo=self.mock_promptfoo, |
| worker_options=self.worker_options, |
| result_options=self.result_options, |
| ) |
| configs = [ |
| eval_config.TestConfig(test_file=pathlib.Path('/test/a.yaml')), |
| eval_config.TestConfig(test_file=pathlib.Path('/test/b.yaml')) |
| ] |
| pool.queue_tests(configs) |
| failed_tests = pool.wait_for_all_queued_tests() |
| self.assertEqual(len(failed_tests), 0) |
| self.assertEqual(self.mock_atomic_counter.return_value.get.call_count, |
| 3) |
| pool.shutdown_blocking(1) |
| |
| def test_wait_for_all_queued_tests_with_failures(self): |
| """Tests that failed tests are returned.""" |
| self.mock_atomic_counter.return_value.get.return_value = 1 |
        config = eval_config.TestConfig(test_file=pathlib.Path('fail.yaml'))
| failed_test = results.TestResult(config=config, |
| success=False, |
| iteration_results=[ |
| results.IterationResult( |
| success=False, |
| duration=1, |
| test_log='', |
| metrics={}, |
| prompt=None, |
| response=None, |
| ) |
| ]) |
| mock_failed_queue = ( |
| self.mock_result_thread.return_value.failed_result_output_queue) |
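        # The failed-result queue yields a single result before reporting
        # empty.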
| mock_failed_queue.empty.side_effect = [False, True] |
| mock_failed_queue.get.return_value = failed_test |
| |
| pool = workers.WorkerPool( |
| num_workers=1, |
| promptfoo=self.mock_promptfoo, |
| worker_options=self.worker_options, |
| result_options=self.result_options, |
| ) |
| pool.queue_tests( |
| [eval_config.TestConfig(test_file=pathlib.Path('fail.yaml'))]) |
| failed_tests = pool.wait_for_all_queued_tests() |
| self.assertEqual(len(failed_tests), 1) |
| self.assertEqual(failed_tests[0], failed_test) |
        pool.shutdown_blocking(1)

    def test_shutdown_blocking(self):
| """Tests that shutdown_blocking shuts down all threads.""" |
| pool = workers.WorkerPool( |
| num_workers=2, |
| promptfoo=self.mock_promptfoo, |
| worker_options=self.worker_options, |
| result_options=self.result_options, |
| ) |
| mock_workers = self.mock_worker_thread.return_value |
| mock_result = self.mock_result_thread.return_value |
| |
| pool.shutdown_blocking(1) |
| |
| self.assertEqual(mock_workers.shutdown.call_count, 2) |
| mock_result.shutdown.assert_called_once() |
| self.assertEqual(mock_workers.join.call_count, 2) |
| mock_result.join.assert_called_once() |
| |
| def test_shutdown_blocking_timeout(self): |
| """Tests that shutdown_blocking logs an error on timeout.""" |
| pool = workers.WorkerPool( |
| num_workers=1, |
| promptfoo=self.mock_promptfoo, |
| worker_options=self.worker_options, |
| result_options=self.result_options, |
| ) |
| self.mock_worker_thread.return_value.join.side_effect = None |
| self.mock_worker_thread.return_value.is_alive.return_value = True |
| with self.assertLogs(level='ERROR') as cm: |
| pool.shutdown_blocking(0.01) |
| self.assertIn('Failed to gracefully shut down thread', |
| cm.output[0]) |
| |
| def test_wait_for_all_queued_tests_with_multiple_workers(self): |
| """Tests that the pool waits for all tests with multiple workers.""" |
| self.mock_atomic_counter.return_value.get.side_effect = [0, 0, 1, 1, 2] |
| pool = workers.WorkerPool( |
| num_workers=2, |
| promptfoo=self.mock_promptfoo, |
| worker_options=self.worker_options, |
| result_options=self.result_options, |
| ) |
| test_paths = [ |
| pathlib.Path('/test/a.yaml'), |
| pathlib.Path('/test/b.yaml') |
| ] |
| configs = [eval_config.TestConfig(test_file=p) for p in test_paths] |
| pool.queue_tests(configs) |
| failed_tests = pool.wait_for_all_queued_tests() |
| self.assertEqual(len(failed_tests), 0) |
| self.assertEqual(self.mock_atomic_counter.return_value.get.call_count, |
| 5) |
| pool.shutdown_blocking(1) |
| |
| def test_worker_thread_fatal_exception(self): |
| """Tests that a fatal exception in a worker thread is propagated.""" |
        mock_worker = self.mock_worker_thread.return_value
        mock_worker.maybe_reraise_fatal_exception.side_effect = ValueError(
            'Worker Error')
| pool = workers.WorkerPool( |
| num_workers=1, |
| promptfoo=self.mock_promptfoo, |
| worker_options=self.worker_options, |
| result_options=self.result_options, |
| ) |
| pool.queue_tests( |
| [eval_config.TestConfig(test_file=pathlib.Path('/test/a.yaml'))]) |
| with self.assertRaisesRegex(ValueError, 'Worker Error'): |
| pool.wait_for_all_queued_tests() |
| pool.shutdown_blocking(1) |
| |
| def test_result_thread_fatal_exception(self): |
| """Tests that a fatal exception in the result thread is propagated.""" |
        mock_result = self.mock_result_thread.return_value
        mock_result.maybe_reraise_fatal_exception.side_effect = ValueError(
            'Result Error')
| pool = workers.WorkerPool( |
| num_workers=1, |
| promptfoo=self.mock_promptfoo, |
| worker_options=self.worker_options, |
| result_options=self.result_options, |
| ) |
| pool.queue_tests( |
| [eval_config.TestConfig(test_file=pathlib.Path('/test/a.yaml'))]) |
| with self.assertRaisesRegex(ValueError, 'Result Error'): |
| pool.wait_for_all_queued_tests() |
| pool.shutdown_blocking(1) |
| |
| def test_del(self): |
| """Tests that the destructor calls shutdown_blocking.""" |
| pool = workers.WorkerPool( |
| num_workers=1, |
| promptfoo=self.mock_promptfoo, |
| worker_options=self.worker_options, |
| result_options=self.result_options, |
| ) |
| shutdown_mock = mock.Mock() |
| pool.shutdown_blocking = shutdown_mock |
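        # Dropping the last reference finalizes the pool immediately under
        # CPython's reference counting, invoking __del__.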
| del pool |
| shutdown_mock.assert_called_once_with(2) |
| |
| |
| if __name__ == '__main__': |
| unittest.main() |