blob: 85a86b93ae6e08da86e6bc63d5baf1459518f2d7 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ai/ai_manager.h"

#include <memory>

#include "base/command_line.h"
#include "base/memory/raw_ptr.h"
#include "base/task/current_thread.h"
#include "base/test/mock_callback.h"
#include "base/test/scoped_command_line.h"
#include "base/test/scoped_feature_list.h"
#include "chrome/browser/ai/ai_language_model.h"
#include "chrome/browser/ai/ai_test_utils.h"
#include "chrome/browser/optimization_guide/mock_optimization_guide_keyed_service.h"
#include "components/keyed_service/core/keyed_service.h"
#include "components/optimization_guide/core/mock_optimization_guide_model_executor.h"
#include "components/optimization_guide/core/model_execution/test/fake_model_broker.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_switches.h"
#include "components/policy/core/common/policy_pref_names.h"
#include "components/prefs/pref_service.h"
#include "content/public/browser/web_contents.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/blink/public/common/features_generated.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom-shared.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_rewriter.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_summarizer.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_writer.mojom.h"
using optimization_guide::MockSession;
using testing::_;
using testing::AtMost;
using testing::Invoke;
using testing::NiceMock;
namespace {
std::vector<blink::mojom::AILanguageCodePtr> MakeLanguageCodeVector(
const std::vector<std::string>& languages) {
std::vector<blink::mojom::AILanguageCodePtr> result;
for (const auto& language : languages) {
result.push_back(blink::mojom::AILanguageCode::New(language));
}
return result;
}
// Test fixture that creates an `AIManager` for the test's main frame. It also
// installs default mock behavior on the optimization guide keyed service, so
// that tests can create sessions and model broker clients.
class AIManagerTest : public AITestUtils::AITestBase {
protected:
// Configures `fake_broker_` with a Prompt API feature config that has
// `can_skip_text_safety` set to true.
AIManagerTest()
: fake_broker_(optimization_guide::FakeAdaptationAsset({
.config =
[] {
optimization_guide::proto::OnDeviceModelExecutionFeatureConfig
config;
config.set_can_skip_text_safety(true);
config.set_feature(
optimization_guide::proto::ModelExecutionFeature::
MODEL_EXECUTION_FEATURE_PROMPT_API);
return config;
}(),
})) {}
// Runs the base-class setup and installs the mock defaults below. It then
// creates the `AIManager` under test, bound to `main_rfh()`.
void SetUp() override {
AITestUtils::AITestBase::SetUp();
SetupMockOptimizationGuideKeyedService();
ai_manager_ =
std::make_unique<AIManager>(main_rfh()->GetBrowserContext(),
&component_update_service_, main_rfh());
}
// Destroys the `AIManager` before the base class tears down the harness.
// NOTE(review): presumably this is because it holds pointers into the
// RFH/browser context that the harness owns.
void TearDown() override {
ai_manager_.reset();
AITestUtils::AITestBase::TearDown();
}
// Extends the base mock setup with default actions (ON_CALL). Tests that set
// no expectations of their own rely on these:
//  - StartSession() returns a `NiceMock<MockSession>` constructed with a
//    pointer to `session_`.
//  - `session_` reports the fake token limits and the fake feature metadata.
//    It reports `kWritingAssistanceMaxInputTokenSize` as the token size of
//    any input.
//  - GetOnDeviceCapabilities() returns default-constructed capabilities.
//  - GetOnDeviceModelEligibility() returns kSuccess.
//  - CreateModelBrokerClient() returns a client bound to `fake_broker_`.
void SetupMockOptimizationGuideKeyedService() override {
AITestUtils::AITestBase::SetupMockOptimizationGuideKeyedService();
ON_CALL(*mock_optimization_guide_keyed_service_, StartSession(_, _))
.WillByDefault(
[&] { return std::make_unique<NiceMock<MockSession>>(&session_); });
ON_CALL(session_, GetTokenLimits())
.WillByDefault(AITestUtils::GetFakeTokenLimits);
ON_CALL(session_, GetExecutionInputSizeInTokens(_, _))
.WillByDefault(
[&](optimization_guide::MultimodalMessageReadView request_metadata,
optimization_guide::OptimizationGuideModelSizeInTokenCallback
callback) {
std::move(callback).Run(
blink::mojom::kWritingAssistanceMaxInputTokenSize);
});
ON_CALL(session_, GetOnDeviceFeatureMetadata())
.WillByDefault(AITestUtils::GetFakeFeatureMetadata);
ON_CALL(*mock_optimization_guide_keyed_service_, GetOnDeviceCapabilities())
.WillByDefault(testing::Return(on_device_model::Capabilities()));
ON_CALL(*mock_optimization_guide_keyed_service_,
GetOnDeviceModelEligibility(_))
.WillByDefault(testing::Return(
optimization_guide::OnDeviceModelEligibilityReason::kSuccess));
ON_CALL(*mock_optimization_guide_keyed_service_, CreateModelBrokerClient())
.WillByDefault([&]() {
return std::make_unique<optimization_guide::ModelBrokerClient>(
fake_broker_.BindAndPassRemote(),
optimization_guide::CreateSessionArgs(nullptr, {}));
});
}
// Writes `value` to the `kBuiltInAIAPIsEnabled` enterprise policy pref on the
// test profile.
// NOTE(review): "BuildIn" in the name looks like a typo for "BuiltIn". The
// name is kept because existing tests call it.
void SetBuildInAIAPIsEnterprisePolicy(bool value) {
profile()->GetPrefs()->SetBoolean(
policy::policy_prefs::kBuiltInAIAPIsEnabled, value);
}
private:
// Mock whose default actions are configured in
// SetupMockOptimizationGuideKeyedService(). Sessions returned by
// StartSession() are constructed with a pointer to it.
testing::NiceMock<MockSession> session_;
// Fake broker that backs the clients returned by CreateModelBrokerClient().
optimization_guide::FakeModelBroker fake_broker_;
};
// Tests that involve invalid on-device model file paths should not crash when
// the associated RFH is destroyed.
TEST_F(AIManagerTest, NoUAFWithInvalidOnDeviceModelPath) {
  // Scope the model-path override to this test. Appending the switch directly
  // to the process-wide command line would leak it into every later test that
  // runs in the same process. `ScopedCommandLine` restores the original
  // command line when it goes out of scope at the end of the test body.
  base::test::ScopedCommandLine scoped_command_line;
  scoped_command_line.GetProcessCommandLine()->AppendSwitchASCII(
      optimization_guide::switches::kOnDeviceModelExecutionOverride,
      "invalid-on-device-model-file-path");

  base::MockCallback<blink::mojom::AIManager::CanCreateLanguageModelCallback>
      callback;
  // The availability check may still be in flight when the frame goes away.
  // The callback is therefore allowed, but not required, to run.
  EXPECT_CALL(callback, Run(_)).Times(AtMost(1));
  ai_manager_->CanCreateLanguageModel(/*options=*/{}, callback.Get());

  // The callback may still be pending. Delete the WebContents and destroy the
  // associated RFH, which should not result in a UAF.
  DeleteContents();
  task_environment()->RunUntilIdle();
}
// Tests that the `AIManager`'s `AIContextBoundObjectSet` manages the lifetime
// of the `AILanguageModel`s it creates.
TEST_F(AIManagerTest, AIContextBoundObjectSet) {
  mojo::Remote<blink::mojom::AILanguageModel> language_model_remote;
  AITestUtils::MockCreateLanguageModelClient create_client;
  base::RunLoop creation_loop;

  // Keep the created model alive by binding it to `language_model_remote`,
  // then stop waiting.
  EXPECT_CALL(create_client, OnResult(_, _))
      .WillOnce([&](mojo::PendingRemote<blink::mojom::AILanguageModel> pending,
                    blink::mojom::AILanguageModelInstanceInfoPtr info) {
        EXPECT_TRUE(pending);
        language_model_remote =
            mojo::Remote<blink::mojom::AILanguageModel>(std::move(pending));
        creation_loop.Quit();
      });

  mojo::Remote<blink::mojom::AIManager> manager_remote = GetAIManagerRemote();

  // Nothing has been created yet, so the set starts out empty.
  ASSERT_EQ(0u, GetAIManagerContextBoundObjectSetSize());

  // Creating exactly one language model adds exactly one entry.
  auto create_options = blink::mojom::AILanguageModelCreateOptions::New(
      /*sampling_params=*/nullptr,
      /*initial_prompts=*/
      std::vector<blink::mojom::AILanguageModelPromptPtr>(),
      /*expected_inputs=*/std::nullopt,
      /*expected_outputs=*/std::nullopt);
  manager_remote->CreateLanguageModel(create_client.BindNewPipeAndPassRemote(),
                                      std::move(create_options));
  creation_loop.Run();
  ASSERT_EQ(1u, GetAIManagerContextBoundObjectSetSize());

  // Dropping our only remote to the model should eventually empty the set.
  language_model_remote.reset();
  ASSERT_TRUE(base::test::RunUntil(
      [&]() { return GetAIManagerContextBoundObjectSetSize() == 0u; }));
}
TEST_F(AIManagerTest, CanCreate) {
  using AvailabilityCallback =
      base::OnceCallback<void(blink::mojom::ModelAvailabilityCheckResult)>;
  constexpr auto kExpectedResult =
      blink::mojom::ModelAvailabilityCheckResult::kAvailable;

  // With only the fixture's default mocks, all four APIs report availability.
  base::MockCallback<AvailabilityCallback> on_checked;
  EXPECT_CALL(on_checked, Run(kExpectedResult)).Times(4);

  ai_manager_->CanCreateLanguageModel(/*options=*/{}, on_checked.Get());
  ai_manager_->CanCreateWriter(/*options=*/{}, on_checked.Get());
  ai_manager_->CanCreateSummarizer(/*options=*/{}, on_checked.Get());
  ai_manager_->CanCreateRewriter(/*options=*/{}, on_checked.Get());
}
TEST_F(AIManagerTest, CanCreateNotEnabled) {
  // Each of the four eligibility queries reports that the feature is off.
  EXPECT_CALL(*mock_optimization_guide_keyed_service_,
              GetOnDeviceModelEligibilityAsync(_, _, _))
      .Times(4)
      .WillRepeatedly([](auto /*feature*/, auto /*capabilities*/, auto done) {
        std::move(done).Run(optimization_guide::OnDeviceModelEligibilityReason::
                                kFeatureNotEnabled);
      });

  constexpr auto kExpectedResult = blink::mojom::ModelAvailabilityCheckResult::
      kUnavailableFeatureNotEnabled;
  base::MockCallback<
      base::OnceCallback<void(blink::mojom::ModelAvailabilityCheckResult)>>
      on_checked;
  EXPECT_CALL(on_checked, Run(kExpectedResult)).Times(4);

  ai_manager_->CanCreateLanguageModel(/*options=*/{}, on_checked.Get());
  ai_manager_->CanCreateWriter(/*options=*/{}, on_checked.Get());
  ai_manager_->CanCreateSummarizer(/*options=*/{}, on_checked.Get());
  ai_manager_->CanCreateRewriter(/*options=*/{}, on_checked.Get());
}
TEST_F(AIManagerTest, CanCreateSessionWithTextInputCapabilities) {
  const optimization_guide::ModelBasedCapabilityKey prompt_api_key =
      optimization_guide::ModelBasedCapabilityKey::kPromptApi;
  base::MockCallback<blink::mojom::AIManager::CanCreateLanguageModelCallback>
      on_checked;

  // A request with no extra input capabilities succeeds. The image-only and
  // audio-only requests are both rejected. NOTE(review): presumably this is
  // because the fixture's default device capabilities include neither.
  EXPECT_CALL(on_checked,
              Run(blink::mojom::ModelAvailabilityCheckResult::kAvailable))
      .Times(1);
  EXPECT_CALL(on_checked, Run(blink::mojom::ModelAvailabilityCheckResult::
                                  kUnavailableModelAdaptationNotAvailable))
      .Times(2);

  on_device_model::Capabilities requested;
  ai_manager_->CanCreateSession(prompt_api_key, requested, on_checked.Get());

  requested.Put(on_device_model::CapabilityFlags::kImageInput);
  ai_manager_->CanCreateSession(prompt_api_key, requested, on_checked.Get());

  requested.Clear();
  requested.Put(on_device_model::CapabilityFlags::kAudioInput);
  ai_manager_->CanCreateSession(prompt_api_key, requested, on_checked.Get());
}
TEST_F(AIManagerTest, CanCreateSessionWithImageAndAudioInputCapabilities) {
  base::test::ScopedFeatureList multimodal_feature(
      blink::features::kAIPromptAPIMultimodalInput);

  // The device claims support for both image and audio input. NOTE(review):
  // the query is expected only twice, presumably for the two requests below
  // that ask for non-text input.
  const on_device_model::Capabilities device_capabilities(
      {on_device_model::CapabilityFlags::kImageInput,
       on_device_model::CapabilityFlags::kAudioInput});
  EXPECT_CALL(*mock_optimization_guide_keyed_service_,
              GetOnDeviceCapabilities())
      .Times(2)
      .WillRepeatedly(testing::Return(device_capabilities));

  const optimization_guide::ModelBasedCapabilityKey prompt_api_key =
      optimization_guide::ModelBasedCapabilityKey::kPromptApi;
  base::MockCallback<blink::mojom::AIManager::CanCreateLanguageModelCallback>
      on_checked;
  // No extra input, image input, and image plus audio input are all available.
  EXPECT_CALL(on_checked,
              Run(blink::mojom::ModelAvailabilityCheckResult::kAvailable))
      .Times(3);

  on_device_model::Capabilities requested;
  ai_manager_->CanCreateSession(prompt_api_key, requested, on_checked.Get());
  requested.Put(on_device_model::CapabilityFlags::kImageInput);
  ai_manager_->CanCreateSession(prompt_api_key, requested, on_checked.Get());
  requested.Put(on_device_model::CapabilityFlags::kAudioInput);
  ai_manager_->CanCreateSession(prompt_api_key, requested, on_checked.Get());
}
TEST_F(AIManagerTest, CanCreateEnterprisePolicyDisabled) {
  // Turn the built-in AI APIs off through enterprise policy. Every
  // availability check should then blame the policy.
  SetBuildInAIAPIsEnterprisePolicy(false);

  constexpr auto kExpectedResult = blink::mojom::ModelAvailabilityCheckResult::
      kUnavailableEnterprisePolicyDisabled;
  base::MockCallback<
      base::OnceCallback<void(blink::mojom::ModelAvailabilityCheckResult)>>
      on_checked;
  EXPECT_CALL(on_checked, Run(kExpectedResult)).Times(4);

  ai_manager_->CanCreateLanguageModel(/*options=*/{}, on_checked.Get());
  ai_manager_->CanCreateWriter(/*options=*/{}, on_checked.Get());
  ai_manager_->CanCreateSummarizer(/*options=*/{}, on_checked.Get());
  ai_manager_->CanCreateRewriter(/*options=*/{}, on_checked.Get());

  // Turn the APIs back on before the test ends.
  SetBuildInAIAPIsEnterprisePolicy(true);
}
// Test CheckAndFixLanguages templates for LanguageModel.
TEST_F(AIManagerTest, CheckAndFixLanguagesLanguageModel) {
  base::flat_set<std::string_view> supported = {"en", "es", "ja"};

  // Wraps `codes` in an `AILanguageModelExpected`. Its language list is always
  // present, possibly empty.
  auto to_expected = [](const base::flat_set<std::string>& codes) {
    auto expected = blink::mojom::AILanguageModelExpected::New();
    expected->languages.emplace();
    for (const std::string& code : codes) {
      expected->languages->push_back(blink::mojom::AILanguageCode::New(code));
    }
    return expected;
  };

  // Options that specify no expected inputs or outputs at all are accepted.
  auto empty_options = blink::mojom::AILanguageModelCreateOptions::New();
  EXPECT_TRUE(
      ai_manager_->CheckAndFixLanguages(empty_options, "API", supported));

  // Each case builds one expected-input entry and one expected-output entry.
  // Regional variants of supported languages pass; any other language fails.
  struct Case {
    base::flat_set<std::string> inputs;
    base::flat_set<std::string> outputs;
    bool expect_supported;
  };
  const Case test_cases[] = {
      {{"en", "es-MX"}, {}, true},
      {{}, {"en-UK", "es-SP", "ja-JP"}, true},
      {{"en", "fr"}, {}, false},
      {{"en"}, {"hi"}, false},
  };
  for (const Case& test_case : test_cases) {
    auto options = blink::mojom::AILanguageModelCreateOptions::New();
    options->expected_inputs.emplace();
    options->expected_inputs->push_back(to_expected(test_case.inputs));
    options->expected_outputs.emplace();
    options->expected_outputs->push_back(to_expected(test_case.outputs));
    EXPECT_EQ(test_case.expect_supported,
              ai_manager_->CheckAndFixLanguages(options, "API", supported));
  }
}
// Tests CheckAndFixLanguages() for the writing-assistance APIs (Summarizer,
// Writer, and Rewriter). Only `AISummarizerCreateOptions` is exercised here.
TEST_F(AIManagerTest, CheckAndFixLanguagesWritingAssistance) {
  base::flat_set<std::string_view> supported = {"en", "es", "ja"};

  // Builds summarizer options from the expected input languages, the expected
  // context languages and the requested output language.
  auto build_options = [](const std::vector<std::string>& input_languages,
                          const std::vector<std::string>& context_languages,
                          const std::string& output_language) {
    auto result = blink::mojom::AISummarizerCreateOptions::New();
    result->expected_input_languages = MakeLanguageCodeVector(input_languages);
    result->expected_context_languages =
        MakeLanguageCodeVector(context_languages);
    result->output_language = blink::mojom::AILanguageCode::New(output_language);
    return result;
  };

  // Default-constructed options are accepted.
  {
    auto options = blink::mojom::AISummarizerCreateOptions::New();
    EXPECT_TRUE(ai_manager_->CheckAndFixLanguages(options, "API", supported));
  }
  // With no languages given at all, the output language stays empty.
  {
    auto options = build_options({}, {}, "");
    EXPECT_TRUE(ai_manager_->CheckAndFixLanguages(options, "API", supported));
    EXPECT_TRUE(options->output_language->code.empty());
  }
  // Regional variants of supported languages are accepted.
  {
    auto options = build_options({"en", "es-MX"}, {"ja"}, "en-US");
    EXPECT_TRUE(ai_manager_->CheckAndFixLanguages(options, "API", supported));
  }
  // A missing output language is filled in with the first expected input.
  {
    auto options = build_options({"en-UK", "en-US"}, {"en"}, "");
    EXPECT_TRUE(ai_manager_->CheckAndFixLanguages(options, "API", supported));
    EXPECT_EQ(options->output_language->code, "en-UK");
  }
  // Unsupported input and output languages cause rejection.
  {
    auto options = build_options({"en", "fr"}, {}, "hi");
    EXPECT_FALSE(ai_manager_->CheckAndFixLanguages(options, "API", supported));
  }
}
// Test CheckAndFixLanguages templates for Proofreader.
TEST_F(AIManagerTest, CheckAndFixLanguagesProofreader) {
  base::flat_set<std::string_view> supported = {"en", "es", "ja"};

  // Builds proofreader options from the expected input languages and the
  // language used for correction explanations.
  auto build_options = [](const std::vector<std::string>& input_languages,
                          const std::string& explanation_language) {
    auto result = blink::mojom::AIProofreaderCreateOptions::New();
    result->expected_input_languages = MakeLanguageCodeVector(input_languages);
    result->correction_explanation_language =
        blink::mojom::AILanguageCode::New(explanation_language);
    return result;
  };

  // Default-constructed options are accepted.
  {
    auto options = blink::mojom::AIProofreaderCreateOptions::New();
    EXPECT_TRUE(ai_manager_->CheckAndFixLanguages(options, "API", supported));
  }
  // With no languages given at all, the explanation language stays empty.
  {
    auto options = build_options({}, "");
    EXPECT_TRUE(ai_manager_->CheckAndFixLanguages(options, "API", supported));
    EXPECT_TRUE(options->correction_explanation_language->code.empty());
  }
  // Regional variants of supported languages are accepted.
  {
    auto options = build_options({"en", "es-MX", "ja"}, "en-US");
    EXPECT_TRUE(ai_manager_->CheckAndFixLanguages(options, "API", supported));
  }
  // A missing explanation language is filled in with the first expected input.
  {
    auto options = build_options({"en-UK", "en-US", "en"}, "");
    EXPECT_TRUE(ai_manager_->CheckAndFixLanguages(options, "API", supported));
    EXPECT_EQ(options->correction_explanation_language->code, "en-UK");
  }
  // Unsupported input and explanation languages cause rejection.
  {
    auto options = build_options({"en", "fr"}, "hi");
    EXPECT_FALSE(ai_manager_->CheckAndFixLanguages(options, "API", supported));
  }
}
} // namespace