blob: 449439097abc0a9d750daa370280e11ea9057108 [file] [log] [blame]
#!/usr/bin/env vpython3
# Copyright 2025 The Chromium Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Unit tests for install.py."""
import io
from pathlib import Path
import unittest
import unittest.mock
import install
from pyfakefs import fake_filesystem_unittest
# pylint: disable=protected-access
class InstallTest(fake_filesystem_unittest.TestCase):
    """Tests for the extension installation script."""

    def setUp(self):
        """Builds a fake checkout and patches out external side effects.

        The fake filesystem holds a source tree at /tmp/test/src with the
        public, testing and internal extension directories, one public
        extension ('sample_1') and one testing extension ('test_sample').
        install._run_command, install.check_gemini_version,
        install.get_project_root and subprocess.run are all mocked.
        """
        self.setUpPyfakefs(additional_skip_names=['subprocess'])
        self.tmpdir = '/tmp/test'
        self.project_root = Path(self.tmpdir) / 'src'
        self.fs.create_dir(self.project_root)

        self.source_extensions_dir = (self.project_root / 'agents' /
                                      'extensions')
        self.fs.create_dir(self.source_extensions_dir)
        self.install_script_path = self.source_extensions_dir / 'install.py'
        self.fs.create_file(self.install_script_path)
        self.testing_extensions_dir = (self.project_root / 'agents' /
                                       'testing' / 'extensions')
        self.fs.create_dir(self.testing_extensions_dir)
        self.internal_extensions_dir = (self.project_root / 'internal' /
                                        'agents' / 'extensions')
        self.fs.create_dir(self.internal_extensions_dir)

        # Create sample extensions.
        self.extension1_dir = self.source_extensions_dir / 'sample_1'
        self.fs.create_dir(self.extension1_dir)
        self.fs.create_file(
            self.extension1_dir / 'gemini-extension.json',
            contents='{"name": "sample_1", "version": "1.0.0"}',
        )
        self.test_extension_dir = self.testing_extensions_dir / 'test_sample'
        self.fs.create_dir(self.test_extension_dir)
        self.fs.create_file(
            self.test_extension_dir / 'gemini-extension.json',
            contents='{"name": "test_sample", "version": "1.0.0"}',
        )

        self.mock_run_command_patcher = unittest.mock.patch(
            'install._run_command')
        self.mock_run_command = self.mock_run_command_patcher.start()
        self.addCleanup(self.mock_run_command_patcher.stop)

        # Patched only so the real version check never runs; no test asserts
        # on it.
        self.mock_check_version_patcher = unittest.mock.patch(
            'install.check_gemini_version')
        self.mock_check_version = self.mock_check_version_patcher.start()
        self.addCleanup(self.mock_check_version_patcher.stop)

        self.mock_get_project_root_patcher = unittest.mock.patch(
            'install.get_project_root')
        self.mock_get_project_root = self.mock_get_project_root_patcher.start()
        self.addCleanup(self.mock_get_project_root_patcher.stop)
        self.mock_get_project_root.return_value = self.project_root

        self.mock_subprocess_run_patcher = unittest.mock.patch(
            'subprocess.run')
        self.mock_subprocess_run = self.mock_subprocess_run_patcher.start()
        self.addCleanup(self.mock_subprocess_run_patcher.stop)

    @staticmethod
    def _render_table(extensions):
        """Returns the text install._print_extensions_table writes to stdout.

        Args:
            extensions: Dict mapping extension name to install.ExtensionInfo.
        """
        with unittest.mock.patch('sys.stdout',
                                 new_callable=io.StringIO) as captured:
            install._print_extensions_table(extensions)
        return captured.getvalue()

    def test_find_extensions_dir_for_extension(self):
        """Tests locating the extensions dir for source/internal/test exts."""
        extensions_dirs = install.get_extensions_dirs(self.project_root)

        # Extension in source directory.
        ext_dir = install.find_extensions_dir_for_extension(
            'sample_1', extensions_dirs)
        self.assertEqual(ext_dir, self.source_extensions_dir)

        # Extension in internal directory.
        internal_extension_dir = self.internal_extensions_dir / 'internal_ext'
        self.fs.create_dir(internal_extension_dir)
        self.fs.create_file(
            internal_extension_dir / 'gemini-extension.json',
            contents='{"name": "internal_ext", "version": "1.0.0"}',
        )
        extensions_dirs = install.get_extensions_dirs(self.project_root)
        ext_dir = install.find_extensions_dir_for_extension(
            'internal_ext', extensions_dirs)
        self.assertEqual(ext_dir, self.internal_extensions_dir)

        # Extension in testing directory (only visible via extra dirs).
        extensions_dirs = install.get_extensions_dirs(
            self.project_root,
            extra_extensions_dirs=[self.testing_extensions_dir])
        ext_dir = install.find_extensions_dir_for_extension(
            'test_sample', extensions_dirs)
        self.assertEqual(ext_dir, self.testing_extensions_dir)

    def test_get_extensions_dirs(self):
        """Tests that get_extensions_dirs returns correct directories."""
        # By default, test extensions should not be included.
        dirs = install.get_extensions_dirs(self.project_root)
        self.assertIn(self.source_extensions_dir, dirs)
        self.assertIn(self.internal_extensions_dir, dirs)
        self.assertNotIn(self.testing_extensions_dir, dirs)

        dirs = install.get_extensions_dirs(
            self.project_root,
            extra_extensions_dirs=[self.testing_extensions_dir])
        self.assertIn(self.source_extensions_dir, dirs)
        self.assertIn(self.internal_extensions_dir, dirs)
        self.assertIn(self.testing_extensions_dir, dirs)

    def test_get_extensions_dirs_no_project_root(self):
        """Tests get_extensions_dirs() when no project root is found."""
        extensions_dirs = install.get_extensions_dirs(None)
        self.assertEqual(extensions_dirs, [])

    @unittest.mock.patch('install.find_extensions_dir_for_extension')
    def test_add_extension_copy(self, mock_find_dir):
        """Tests that `add --copy` installs a copy of the extension."""
        mock_find_dir.return_value = self.source_extensions_dir
        with unittest.mock.patch('sys.argv',
                                 ['install.py', 'add', '--copy', 'sample_1']):
            install.main()
        self.mock_run_command.assert_called_once_with(
            [
                'gemini', 'extensions', 'install',
                str(self.source_extensions_dir / 'sample_1')
            ],
            skip_prompt=False)

    @unittest.mock.patch('install.find_extensions_dir_for_extension')
    def test_add_extension_link(self, mock_find_dir):
        """Tests that `add` links the extension by default."""
        mock_find_dir.return_value = self.source_extensions_dir
        with unittest.mock.patch('sys.argv',
                                 ['install.py', 'add', 'sample_1']):
            install.main()
        self.mock_run_command.assert_called_once_with(
            [
                'gemini', 'extensions', 'link',
                str(self.source_extensions_dir / 'sample_1')
            ],
            skip_prompt=False)

    @unittest.mock.patch('install.find_extensions_dir_for_extension')
    def test_add_extension_skip_prompt(self, mock_find_dir):
        """Tests that --skip-prompt is forwarded to _run_command."""
        mock_find_dir.return_value = self.source_extensions_dir
        with unittest.mock.patch(
                'sys.argv',
            ['install.py', 'add', '--skip-prompt', 'sample_1']):
            install.main()
        self.mock_run_command.assert_called_once_with(
            [
                'gemini', 'extensions', 'link',
                str(self.source_extensions_dir / 'sample_1')
            ],
            skip_prompt=True)

    def test_add_test_extension(self):
        """Tests adding a test extension via --extra-extensions-dir."""
        with unittest.mock.patch('sys.argv', [
                'install.py', '--extra-extensions-dir',
                str(self.testing_extensions_dir), 'add', 'test_sample'
        ]):
            install.main()
        self.mock_run_command.assert_called_once_with(
            [
                'gemini', 'extensions', 'link',
                str(self.testing_extensions_dir / 'test_sample')
            ],
            skip_prompt=False)

    def test_add_test_extension_without_flag_fails(self):
        """Tests that a test extension is not found without the extra dir."""
        with unittest.mock.patch('sys.argv',
                                 ['install.py', 'add', 'test_sample']):
            # Capture stderr so the expected error does not pollute output.
            with unittest.mock.patch('sys.stderr', new_callable=io.StringIO):
                with self.assertRaises(SystemExit):
                    install.main()
        self.mock_run_command.assert_not_called()

    def test_add_invalid_extension(self):
        """Tests add command with an invalid extension."""
        with unittest.mock.patch('sys.argv',
                                 ['install.py', 'add', 'nonexistent']):
            with unittest.mock.patch('sys.stderr',
                                     new_callable=io.StringIO) as mock_stderr:
                with self.assertRaises(SystemExit) as e:
                    install.main()
        self.assertEqual(e.exception.code, 1)
        self.assertIn("Extension 'nonexistent' not found.",
                      mock_stderr.getvalue())
        self.mock_run_command.assert_not_called()

    def test_update_extension(self):
        """Tests update command."""
        with unittest.mock.patch('sys.argv',
                                 ['install.py', 'update', 'sample_1']):
            install.main()
        self.mock_run_command.assert_called_once_with(
            ['gemini', 'extensions', 'update', 'sample_1'], skip_prompt=False)

    def test_update_all_extensions(self):
        """Tests update command with no extension specified."""
        with unittest.mock.patch('sys.argv', ['install.py', 'update']):
            install.main()
        self.mock_run_command.assert_called_once_with(
            ['gemini', 'extensions', 'update', '--all'])

    def test_remove_extension(self):
        """Tests remove command."""
        with unittest.mock.patch('sys.argv',
                                 ['install.py', 'remove', 'sample-1']):
            install.main()
        self.mock_run_command.assert_called_once_with(
            ['gemini', 'extensions', 'uninstall', 'sample-1'])

    @unittest.mock.patch('pathlib.Path.home')
    def test_remove_legacy_extension(self, mock_home):
        """Tests remove command for legacy extensions with underscores."""
        fake_home = Path(self.tmpdir) / 'home'
        mock_home.return_value = fake_home

        # Set up a legacy extension.
        legacy_extension_dir = (install.get_global_extension_dir() /
                                'my_legacy_ext')
        self.fs.create_dir(legacy_extension_dir)
        self.assertTrue(legacy_extension_dir.exists())

        with unittest.mock.patch('sys.argv',
                                 ['install.py', 'remove', 'my_legacy_ext']):
            install.main()
        # Legacy extensions are deleted directly rather than via gemini.
        self.mock_run_command.assert_not_called()
        self.assertFalse(legacy_extension_dir.exists())

    def test_list_extensions(self):
        """Tests the list command, showing all extensions."""
        # Canned stdout for the mocked subprocess.run call made by `list`.
        self.mock_subprocess_run.return_value.stdout = """
✓ user-enabled (1.0.0)
ID: abc
Path: /path/to/user-enabled
Source: /path/to/source/user-enabled (Type: link)
Enabled (User): true
Enabled (Workspace): false
✓ workspace-enabled (2.0.0)
ID: def
Path: /path/to/workspace-enabled
Source: /path/to/source/workspace-enabled (Type: local)
Enabled (User): false
Enabled (Workspace): true
✓ both-enabled (3.0.0)
ID: ghi
Path: /path/to/both-enabled
Source: /path/to/source/both-enabled (Type: local)
Enabled (User): true
Enabled (Workspace): true
"""
        self.mock_subprocess_run.return_value.returncode = 0
        with unittest.mock.patch('sys.argv', ['install.py', 'list']):
            with unittest.mock.patch('sys.stdout',
                                     new_callable=io.StringIO) as mock_stdout:
                install.main()
        output = mock_stdout.getvalue()

        expected_extensions = {
            'workspace-enabled':
            install.ExtensionInfo(name='workspace-enabled',
                                  installed='2.0.0',
                                  linked=False,
                                  enabled_for_workspace=True),
            'user-enabled':
            install.ExtensionInfo(name='user-enabled',
                                  installed='1.0.0',
                                  linked=True,
                                  enabled_for_workspace=False),
            'both-enabled':
            install.ExtensionInfo(name='both-enabled',
                                  installed='3.0.0',
                                  linked=False,
                                  enabled_for_workspace=True),
            'sample_1':
            install.ExtensionInfo(name='sample_1', available='1.0.0'),
        }
        self.assertEqual(output, self._render_table(expected_extensions))

    def test_list_extensions_no_installed(self):
        """Tests the list command with no installed extensions."""
        self.mock_subprocess_run.return_value.stdout = ''
        self.mock_subprocess_run.return_value.returncode = 0
        with unittest.mock.patch('sys.argv', ['install.py', 'list']):
            with unittest.mock.patch('sys.stdout',
                                     new_callable=io.StringIO) as mock_stdout:
                install.main()
        output = mock_stdout.getvalue()

        expected_extensions = {
            'sample_1': install.ExtensionInfo(name='sample_1',
                                              available='1.0.0'),
        }
        self.assertEqual(output, self._render_table(expected_extensions))

    def test_list_extensions_empty_table(self):
        """Tests the list command with no available or installed extensions."""
        self.mock_subprocess_run.return_value.stdout = ''
        self.mock_subprocess_run.return_value.returncode = 0
        # Remove the sample extension created in setUp.
        self.fs.remove_object(str(self.extension1_dir))
        with unittest.mock.patch('sys.argv', ['install.py', 'list']):
            with unittest.mock.patch('sys.stdout',
                                     new_callable=io.StringIO) as mock_stdout:
                install.main()
        output = mock_stdout.getvalue()
        expected_output = (
            'EXTENSION AVAILABLE INSTALLED LINKED ENABLED\n'
            '--------- --------- --------- ------ -------\n')
        self.assertEqual(output, expected_output)

    def test_print_extensions_table_formatting(self):
        """Tests the formatting of the extensions table."""
        extensions_data = {
            'ext_a':
            install.ExtensionInfo(name='ext_a',
                                  available='1.0.0',
                                  installed='1.0.0',
                                  linked=True,
                                  enabled_for_workspace=True),
            'another_extension':
            install.ExtensionInfo(name='another_extension',
                                  available='2.0.0',
                                  installed='-',
                                  linked=False,
                                  enabled_for_workspace=False),
            'third_ext':
            install.ExtensionInfo(name='third_ext',
                                  available='-',
                                  installed='3.0.0',
                                  linked=False,
                                  enabled_for_workspace=True),
        }
        # NOTE(review): the dash row implies a 17-wide first column and a
        # 9-wide ENABLED column, but the header and data cells below are
        # separated by single spaces. Confirm the padding in these literals
        # matches what _print_extensions_table actually emits.
        expected_output = (
            'EXTENSION AVAILABLE INSTALLED LINKED ENABLED \n'
            '----------------- --------- --------- ------ ---------\n'
            'another_extension 2.0.0 - no - \n'
            'ext_a 1.0.0 1.0.0 yes workspace\n'
            'third_ext - 3.0.0 no workspace\n')
        self.assertEqual(self._render_table(extensions_data), expected_output)

    def test_find_extensions_dir_for_nonexistent_extension(self):
        """Tests finding a non-existent extension."""
        extensions_dirs = install.get_extensions_dirs(self.project_root)
        ext_dir = install.find_extensions_dir_for_extension(
            'nonexistent', extensions_dirs)
        self.assertIsNone(ext_dir)

    @unittest.mock.patch('install.find_extensions_dir_for_extension')
    def test_fix_extensions(self, mock_find_dir):
        """Tests fix command relinks and re-scopes project extensions."""
        mock_find_dir.return_value = self.source_extensions_dir
        project_extensions_dir = self.project_root / '.gemini' / 'extensions'
        self.fs.create_dir(project_extensions_dir)
        self.fs.create_file(
            project_extensions_dir / 'sample_1' / 'gemini-extension.json',
            contents='{"name": "sample_1", "version": "1.0.0"}',
        )
        with unittest.mock.patch('sys.argv', ['install.py', 'fix']):
            install.main()
        calls = [
            unittest.mock.call([
                'gemini', 'extensions', 'link',
                str(self.source_extensions_dir / 'sample_1')
            ]),
            unittest.mock.call([
                'gemini', 'extensions', 'disable', 'sample_1', '--scope=User'
            ]),
            unittest.mock.call([
                'gemini', 'extensions', 'enable', 'sample_1',
                '--scope=Workspace'
            ]),
        ]
        self.mock_run_command.assert_has_calls(calls)
        self.assertFalse(project_extensions_dir.exists())

    def test_fix_extensions_no_project_dir(self):
        """Tests fix command when no project-level directory exists."""
        with unittest.mock.patch('sys.stdout',
                                 new_callable=io.StringIO) as mock_stdout:
            with unittest.mock.patch('sys.argv', ['install.py', 'fix']):
                install.main()
        self.assertIn('No project-level extensions found to fix.',
                      mock_stdout.getvalue())
        self.mock_run_command.assert_not_called()

    def test_fix_extensions_no_extensions(self):
        """Tests fix command when no project-level extensions are found."""
        project_extensions_dir = self.project_root / '.gemini' / 'extensions'
        self.fs.create_dir(project_extensions_dir)
        with unittest.mock.patch('sys.stdout',
                                 new_callable=io.StringIO) as mock_stdout:
            with unittest.mock.patch('sys.argv', ['install.py', 'fix']):
                install.main()
        self.assertIn(
            'No valid project-level extensions found.',
            mock_stdout.getvalue(),
        )
        self.mock_run_command.assert_not_called()
        self.assertFalse(project_extensions_dir.exists())

    @unittest.mock.patch('pathlib.Path.home')
    def test_fix_skips_existing_user_extension(self, mock_home):
        """Tests that fix skips extensions that already exist for the user."""
        fake_home = Path(self.tmpdir) / 'home'
        mock_home.return_value = fake_home

        # Set up a user-level extension.
        (install.get_global_extension_dir() / 'sample_1').mkdir(parents=True)

        # Create a project-level extension with the same name.
        project_extensions_dir = self.project_root / '.gemini' / 'extensions'
        self.fs.create_dir(project_extensions_dir)
        self.fs.create_file(
            project_extensions_dir / 'sample_1' / 'gemini-extension.json',
            contents='{"name": "sample_1", "version": "1.0.0"}',
        )
        with unittest.mock.patch('sys.stderr',
                                 new_callable=io.StringIO) as mock_stderr:
            with unittest.mock.patch('sys.argv', ['install.py', 'fix']):
                install.main()
        self.assertIn(
            'Warning: User extension "sample_1" already exists.',
            mock_stderr.getvalue(),
        )
        self.mock_run_command.assert_not_called()
        self.assertFalse(project_extensions_dir.exists())

    def test_prompt_for_fix(self):
        """Tests that the user is prompted to run fix."""
        project_extensions_dir = self.project_root / '.gemini' / 'extensions'
        self.fs.create_dir(project_extensions_dir)
        # Give `list` a well-defined (empty) gemini listing instead of an
        # unconfigured MagicMock, matching the other list tests.
        self.mock_subprocess_run.return_value.stdout = ''
        self.mock_subprocess_run.return_value.returncode = 0
        with unittest.mock.patch('sys.stderr',
                                 new_callable=io.StringIO) as mock_stderr:
            # Swallow the table `list` prints; only stderr is asserted on.
            with unittest.mock.patch('sys.stdout', new_callable=io.StringIO):
                with unittest.mock.patch('sys.argv', ['install.py', 'list']):
                    install.main()
        self.assertIn('WARNING: Project-level extensions are deprecated.',
                      mock_stderr.getvalue())

    def test_get_project_root(self):
        """Tests the real get_project_root function.

        setUp replaces install.get_project_root with a mock, so without
        stopping that patch this test would only exercise the mock. Stopping
        it twice is harmless: the cleanup registered in setUp becomes a no-op.
        """
        self.mock_get_project_root_patcher.stop()
        with unittest.mock.patch('install._PROJECT_ROOT', self.project_root):
            project_root = install.get_project_root()
        self.assertEqual(project_root, self.project_root)
if __name__ == '__main__':
    # Discover and run every TestCase defined in this module when the file is
    # executed directly (e.g. `vpython3 install_test.py`).
    unittest.main(module='__main__')