| // Copyright 2024 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "chrome/browser/ai/ai_language_model.h" |
| |
| #include <cstdint> |
| #include <initializer_list> |
| #include <optional> |
| #include <sstream> |
| #include <string> |
| #include <vector> |
| |
| #include "base/functional/bind.h" |
| #include "base/functional/callback_forward.h" |
| #include "base/functional/callback_helpers.h" |
| #include "base/notreached.h" |
| #include "base/strings/strcat.h" |
| #include "base/strings/stringprintf.h" |
| #include "base/task/current_thread.h" |
| #include "base/test/mock_callback.h" |
| #include "base/test/scoped_feature_list.h" |
| #include "base/test/test_future.h" |
| #include "chrome/browser/ai/ai_test_utils.h" |
| #include "chrome/browser/ai/ai_utils.h" |
| #include "chrome/browser/ai/features.h" |
| #include "components/optimization_guide/core/mock_optimization_guide_model_executor.h" |
| #include "components/optimization_guide/core/model_execution/multimodal_message.h" |
| #include "components/optimization_guide/core/optimization_guide_features.h" |
| #include "components/optimization_guide/core/optimization_guide_model_executor.h" |
| #include "components/optimization_guide/core/optimization_guide_proto_util.h" |
| #include "components/optimization_guide/proto/common_types.pb.h" |
| #include "components/optimization_guide/proto/descriptors.pb.h" |
| #include "components/optimization_guide/proto/features/prompt_api.pb.h" |
| #include "components/optimization_guide/proto/on_device_model_execution_config.pb.h" |
| #include "components/optimization_guide/proto/string_value.pb.h" |
| #include "content/public/browser/render_widget_host_view.h" |
| #include "services/on_device_model/public/cpp/capabilities.h" |
| #include "services/on_device_model/public/mojom/on_device_model.mojom.h" |
| #include "testing/gmock/include/gmock/gmock.h" |
| #include "testing/gtest/include/gtest/gtest.h" |
| #include "third_party/blink/public/common/features_generated.h" |
| #include "third_party/blink/public/mojom/ai/ai_common.mojom.h" |
| #include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h" |
| #include "third_party/blink/public/mojom/ai/ai_manager.mojom-shared.h" |
| #include "third_party/blink/public/mojom/ai/model_download_progress_observer.mojom-forward.h" |
| #include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom-shared.h" |
| |
| namespace { |
| |
| using ::optimization_guide::MultimodalMessage; |
| using ::optimization_guide::MultimodalMessageReadView; |
| using ::optimization_guide::proto::PromptApiPrompt; |
| using ::optimization_guide::proto::PromptApiRequest; |
| using ::optimization_guide::proto::PromptApiRole; |
| using ::optimization_guide::proto::ProtoField; |
| using ::testing::_; |
| using ::testing::Return; |
| using ::testing::ReturnRef; |
| using Role = ::blink::mojom::AILanguageModelPromptRole; |
| using SetInputCallback = ::optimization_guide::OptimizationGuideModelExecutor:: |
| Session::SetInputCallback; |
| |
| constexpr uint32_t kTestMaxContextToken = 10u; |
| constexpr uint32_t kTestInitialPromptsToken = 5u; |
| constexpr uint32_t kTestDefaultTopK = 1u; |
| constexpr float kTestDefaultTemperature = 0.3; |
| constexpr uint32_t kTestMaxTopK = 5u; |
| constexpr float kTestMaxTemperature = 1.5; |
| constexpr uint64_t kTestModelDownloadSize = 572u; |
| static_assert(kTestDefaultTopK <= kTestMaxTopK); |
| static_assert(kTestDefaultTemperature <= kTestMaxTemperature); |
| |
| const char kTestPrompt[] = "Test prompt"; |
| const char kExpectedFormattedTestPrompt[] = "U: Test prompt\nM: "; |
| const char kTestSystemPrompts[] = "Test system prompt"; |
| const char kExpectedFormattedSystemPrompts[] = "S: Test system prompt\n"; |
| const char kTestResponse[] = "Test response"; |
| |
| const char kTestInitialPromptsUser1[] = "How are you?"; |
| const char kTestInitialPromptsAssistant1[] = "I'm fine, thank you, and you?"; |
| const char kTestInitialPromptsUser2[] = "I'm fine too."; |
| const char kExpectedFormattedInitialPrompts[] = |
| ("U: How are you?\n" |
| "M: I'm fine, thank you, and you?\n" |
| "U: I'm fine too.\n"); |
| const char kExpectedFormattedSystemPromptAndInitialPrompts[] = |
| ("S: Test system prompt\n" |
| "U: How are you?\n" |
| "M: I'm fine, thank you, and you?\n" |
| "U: I'm fine too.\n"); |
| |
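| // Create a small solid-red bitmap used as image input in multimodal prompts. |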
| SkBitmap CreateTestBitmap(int width, int height) { |
| SkBitmap bitmap; |
| bitmap.allocN32Pixels(width, height); |
| bitmap.eraseColor(SK_ColorRED); |
| return bitmap; |
| } |
| |
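| // Create an empty audio payload used as audio input in multimodal prompts. |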
| on_device_model::mojom::AudioDataPtr CreateTestAudio() { |
| return on_device_model::mojom::AudioData::New(); |
| } |
| |
| // Build a mojo prompt struct with the specified `role` and `text`. |
| blink::mojom::AILanguageModelPromptPtr MakePrompt(Role role, |
| const std::string& text) { |
| return blink::mojom::AILanguageModelPrompt::New( |
| role, blink::mojom::AILanguageModelPromptContent::NewText(text)); |
| } |
| |
| // Build a mojo prompt struct array holding a single piece of text. |
| std::vector<blink::mojom::AILanguageModelPromptPtr> MakeInput( |
| const std::string& text) { |
| std::vector<blink::mojom::AILanguageModelPromptPtr> prompts; |
| prompts.push_back(MakePrompt(Role::kUser, text)); |
| return prompts; |
| } |
| |
| // Build a mojo prompt struct array with a simple set of initial prompts. |
| std::vector<blink::mojom::AILanguageModelPromptPtr> GetTestInitialPrompts() { |
| std::vector<blink::mojom::AILanguageModelPromptPtr> prompts; |
| prompts.push_back(MakePrompt(Role::kUser, kTestInitialPromptsUser1)); |
| prompts.push_back(MakePrompt(Role::kAssistant, kTestInitialPromptsAssistant1)); |
| prompts.push_back(MakePrompt(Role::kUser, kTestInitialPromptsUser2)); |
| return prompts; |
| } |
| |
| // Construct a ContextItem with system prompt text. |
| AILanguageModel::Context::ContextItem SimpleContextItem(std::string text, |
| uint32_t size) { |
| auto item = AILanguageModel::Context::ContextItem(); |
| item.tokens = size; |
| item.prompts.emplace_back( |
| MakePrompt(blink::mojom::AILanguageModelPromptRole::kSystem, text)); |
| return item; |
| } |
| |
| // Convert a PromptApiRole to a string for expectation matching. |
| const char* FormatPromptRole(PromptApiRole role) { |
| switch (role) { |
| case PromptApiRole::PROMPT_API_ROLE_SYSTEM: |
| return "S: "; |
| case PromptApiRole::PROMPT_API_ROLE_USER: |
| return "U: "; |
| case PromptApiRole::PROMPT_API_ROLE_ASSISTANT: |
| return "M: "; |
| default: |
| NOTREACHED(); |
| } |
| } |
| |
| // Construct a ProtoField message that selects a field from its tag path. |
| ProtoField FieldWithTags(std::initializer_list<int32_t> tags) { |
| ProtoField result; |
| for (int32_t tag : tags) { |
| result.add_proto_descriptors()->set_tag_number(tag); |
| } |
| return result; |
| } |
| |
| // Convert a MultimodalMessageReadView of PromptApiPrompt to string for |
| // expectation matching. |
| void FormatPrompt(std::ostringstream& oss, MultimodalMessageReadView view) { |
| PromptApiRole role = static_cast<PromptApiRole>( |
| view.GetValue(FieldWithTags({PromptApiPrompt::kRoleFieldNumber})) |
| ->int32_value()); |
| oss << FormatPromptRole(role); |
| oss << view.GetValue(FieldWithTags({PromptApiPrompt::kTextFieldNumber})) |
| ->string_value(); |
| if (view.GetImage(FieldWithTags({PromptApiPrompt::kMediaFieldNumber}))) { |
| oss << "<image>"; |
| } |
| if (view.GetAudio(FieldWithTags({PromptApiPrompt::kMediaFieldNumber}))) { |
| oss << "<audio>"; |
| } |
| oss << "\n"; |
| } |
| |
| // Convert a RepeatedMultimodalMessageReadView of PromptApiPrompts to string for |
| // expectation matching. |
| void FormatPrompts(std::ostringstream& oss, |
| optimization_guide::RepeatedMultimodalMessageReadView view) { |
| int size = view.Size(); |
| for (int i = 0; i < size; i++) { |
| FormatPrompt(oss, view.Get(i)); |
| } |
| } |
| |
| // Convert a MultimodalMessageReadView of PromptApiRequest to string for |
| // expectation matching. |
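| // For example, a request whose history holds one turn ("Hi" -> "Hello") and |
| // whose current prompt is "Test prompt" formats as: |
| // "U: Hi\nM: Hello\nU: Test prompt\nM: " |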
| void FormatRequest(std::ostringstream& oss, MultimodalMessageReadView view) { |
| FormatPrompts(oss, *view.GetRepeated(FieldWithTags( |
| {PromptApiRequest::kInitialPromptsFieldNumber}))); |
| FormatPrompts(oss, *view.GetRepeated(FieldWithTags( |
| {PromptApiRequest::kPromptHistoryFieldNumber}))); |
| auto current_prompts = view.GetRepeated( |
| FieldWithTags({PromptApiRequest::kCurrentPromptsFieldNumber})); |
| FormatPrompts(oss, *current_prompts); |
| if (current_prompts->Size() > 0) { |
| oss << FormatPromptRole(PromptApiRole::PROMPT_API_ROLE_ASSISTANT); |
| } |
| } |
| |
| // Convert a MultimodalMessage to string for expectation matching. |
| std::string ToString(const optimization_guide::MultimodalMessage& request) { |
| if (request.GetTypeName() == "optimization_guide.proto.PromptApiRequest") { |
| std::ostringstream oss; |
| FormatRequest(oss, request.read()); |
| return oss.str(); |
| } |
| return "unexpected type"; |
| } |
| |
| // Convert a Context to string for expectation matching. |
| std::string GetContextString(AILanguageModel::Context& ctx) { |
| return ToString(ctx.MakeRequest(on_device_model::Capabilities())); |
| } |
| |
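| // Return shared `PromptApiMetadata` whose version is the minimum version |
| // that uses the proto-based request format. |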
| const optimization_guide::proto::Any& GetPromptApiMetadata() { |
| static base::NoDestructor<optimization_guide::proto::Any> data([]() { |
| optimization_guide::proto::PromptApiMetadata metadata; |
| metadata.set_version(AILanguageModel::kMinVersionUsingProto); |
| return optimization_guide::AnyWrapProto(metadata); |
| }()); |
| return *data; |
| } |
| |
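| // Build an on-device streaming execution result that wraps `output` in a |
| // `StringValue` response. |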
| optimization_guide::OptimizationGuideModelStreamingExecutionResult |
| CreateExecutionResult(const std::string& output, |
| bool is_complete, |
| uint32_t input_token_count, |
| uint32_t output_token_count) { |
| optimization_guide::proto::StringValue response; |
| response.set_value(output); |
| |
| return optimization_guide::OptimizationGuideModelStreamingExecutionResult( |
| optimization_guide::StreamingResponse{ |
| .response = optimization_guide::AnyWrapProto(response), |
| .is_complete = is_complete, |
| .input_token_count = input_token_count, |
| .output_token_count = output_token_count}, |
| /*provided_by_on_device=*/true); |
| } |
| |
| class AILanguageModelTest : public AITestUtils::AITestBase { |
| public: |
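| // Options controlling a single `RunPromptTest()` run and its expectations. |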
| struct Options { |
| blink::mojom::AILanguageModelSamplingParamsPtr sampling_params = nullptr; |
| std::optional<std::string> system_prompt = std::nullopt; |
| std::vector<blink::mojom::AILanguageModelPromptPtr> initial_prompts; |
| std::string prompt_input = kTestPrompt; |
| std::string expected_context = ""; |
| std::string expected_cloned_context = |
| base::StrCat({kExpectedFormattedTestPrompt, kTestResponse, "\n"}); |
| std::string expected_prompt = kExpectedFormattedTestPrompt; |
| bool should_overflow_context = false; |
| bool should_use_supported_language = true; |
| }; |
| |
| protected: |
| void SetupMockOptimizationGuideKeyedService() override { |
| AITestUtils::AITestBase::SetupMockOptimizationGuideKeyedService(); |
| ON_CALL(*mock_optimization_guide_keyed_service_, GetSamplingParamsConfig(_)) |
| .WillByDefault([](optimization_guide::ModelBasedCapabilityKey feature) { |
| return optimization_guide::SamplingParamsConfig{ |
| .default_top_k = kTestDefaultTopK, |
| .default_temperature = kTestDefaultTemperature}; |
| }); |
| |
| ON_CALL(*mock_optimization_guide_keyed_service_, GetFeatureMetadata(_)) |
| .WillByDefault([](optimization_guide::ModelBasedCapabilityKey feature) { |
| optimization_guide::proto::SamplingParams sampling_params; |
| sampling_params.set_top_k(kTestMaxTopK); |
| sampling_params.set_temperature(kTestMaxTemperature); |
| optimization_guide::proto::PromptApiMetadata metadata; |
| *metadata.mutable_max_sampling_params() = sampling_params; |
| optimization_guide::proto::Any any; |
| any.set_value(metadata.SerializeAsString()); |
| any.set_type_url( |
| base::StrCat({"type.googleapis.com/", metadata.GetTypeName()})); |
| return any; |
| }); |
| } |
| |
| // Helper that creates an `AILanguageModel`, executes the prompt, then forks |
| // the session and runs the same prompt on the clone. |
| void RunPromptTest(Options options) { |
| blink::mojom::AILanguageModelSamplingParamsPtr sampling_params_copy; |
| if (options.sampling_params) { |
| sampling_params_copy = options.sampling_params->Clone(); |
| } |
| |
| // Set up mock service. |
| SetupMockOptimizationGuideKeyedService(); |
| |
| if (options.should_use_supported_language) { |
| // `StartSession()` runs twice: once when creating the session and once |
| // when cloning it. |
| EXPECT_CALL(*mock_optimization_guide_keyed_service_, StartSession(_, _)) |
| .Times(2) |
| .WillOnce([&](optimization_guide::ModelBasedCapabilityKey feature, |
| const std::optional< |
| optimization_guide::SessionConfigParams>& |
| config_params) { |
| auto session = std::make_unique< |
| testing::NiceMock<optimization_guide::MockSession>>(); |
| if (sampling_params_copy) { |
| EXPECT_EQ(config_params->sampling_params->top_k, |
| std::min(kTestMaxTopK, sampling_params_copy->top_k)); |
| EXPECT_EQ(config_params->sampling_params->temperature, |
| std::min(kTestMaxTemperature, |
| sampling_params_copy->temperature)); |
| } |
| |
| SetUpMockSession(*session); |
| |
| ON_CALL(*session, GetContextSizeInTokens(_, _)) |
| .WillByDefault([&](MultimodalMessageReadView request_metadata, |
| optimization_guide:: |
| OptimizationGuideModelSizeInTokenCallback |
| callback) { |
| std::move(callback).Run( |
| options.should_overflow_context |
| ? AITestUtils::GetFakeTokenLimits() |
| .max_context_tokens + |
| 1 |
| : 1); |
| }); |
| ON_CALL(*session, SetInput(_, _)) |
| .WillByDefault([&, initial = true]( |
| MultimodalMessage request_metadata, |
| SetInputCallback callback) mutable { |
| if (initial && !options.expected_context.empty()) { |
| initial = false; |
| EXPECT_THAT(ToString(request_metadata), |
| options.expected_context); |
| } else { |
| EXPECT_THAT( |
| ToString(request_metadata), |
| options.expected_context + options.expected_prompt); |
| } |
| }); |
| |
| EXPECT_CALL(*session, ExecuteModelWithResponseJsonSchema(_, _, _)) |
| .WillOnce( |
| [&](const google::protobuf::MessageLite& request_metadata, |
| const std::optional<std::string>& response_json_schema, |
| optimization_guide:: |
| OptimizationGuideModelExecutionResultStreamingCallback |
| callback) { |
| EXPECT_THAT(request_metadata.ByteSizeLong(), 0); |
| StreamResponse(callback); |
| }); |
| return session; |
| }) |
| .WillOnce([&](optimization_guide::ModelBasedCapabilityKey feature, |
| const std::optional< |
| optimization_guide::SessionConfigParams>& |
| config_params) { |
| auto session = std::make_unique< |
| testing::NiceMock<optimization_guide::MockSession>>(); |
| |
| SetUpMockSession(*session); |
| |
| ON_CALL(*session, SetInput(_, _)) |
| .WillByDefault([&](MultimodalMessage request_metadata, |
| SetInputCallback callback) { |
| EXPECT_THAT(ToString(request_metadata), |
| options.expected_cloned_context + |
| options.expected_prompt); |
| }); |
| EXPECT_CALL(*session, ExecuteModelWithResponseJsonSchema(_, _, _)) |
| .WillOnce( |
| [&](const google::protobuf::MessageLite& request_metadata, |
| const std::optional<std::string>& response_json_schema, |
| optimization_guide:: |
| OptimizationGuideModelExecutionResultStreamingCallback |
| callback) { |
| EXPECT_THAT(request_metadata.ByteSizeLong(), 0); |
| StreamResponse(callback); |
| }); |
| return session; |
| }); |
| } |
| |
| // Test session creation. |
| mojo::Remote<blink::mojom::AILanguageModel> mock_session; |
| AITestUtils::MockCreateLanguageModelClient |
| mock_create_language_model_client; |
| base::RunLoop creation_run_loop; |
| bool is_initial_prompts_or_system_prompt_set = |
| !options.initial_prompts.empty() || |
| (options.system_prompt.has_value() && !options.system_prompt->empty()); |
| if (options.should_use_supported_language) { |
| EXPECT_CALL(mock_create_language_model_client, OnResult(_, _)) |
| .WillOnce([&](mojo::PendingRemote<blink::mojom::AILanguageModel> |
| language_model, |
| blink::mojom::AILanguageModelInstanceInfoPtr info) { |
| EXPECT_TRUE(language_model); |
| EXPECT_EQ(info->input_quota, |
| AITestUtils::GetFakeTokenLimits().max_context_tokens); |
| if (is_initial_prompts_or_system_prompt_set) { |
| EXPECT_GT(info->input_usage, 0ul); |
| } else { |
| EXPECT_EQ(info->input_usage, 0ul); |
| } |
| mock_session = mojo::Remote<blink::mojom::AILanguageModel>( |
| std::move(language_model)); |
| creation_run_loop.Quit(); |
| }); |
| } else { |
| EXPECT_CALL(mock_create_language_model_client, OnError(_)) |
| .WillOnce([&](blink::mojom::AIManagerCreateClientError error) { |
| EXPECT_EQ( |
| error, |
| blink::mojom::AIManagerCreateClientError::kUnsupportedLanguage); |
| creation_run_loop.Quit(); |
| }); |
| } |
| |
| mojo::Remote<blink::mojom::AIManager> mock_remote = GetAIManagerRemote(); |
| |
| if (options.should_use_supported_language) { |
| EXPECT_EQ(GetAIManagerDownloadProgressObserversSize(), 0u); |
| AITestUtils::MockModelDownloadProgressMonitor mock_monitor; |
| base::RunLoop download_progress_run_loop; |
| EXPECT_CALL(mock_monitor, OnDownloadProgressUpdate(_, _)) |
| .WillOnce(testing::Invoke([&](uint64_t downloaded_bytes, |
| uint64_t total_bytes) { |
| EXPECT_EQ( |
| downloaded_bytes, |
| static_cast<uint64_t>(AIUtils::kNormalizedDownloadProgressMax)); |
| EXPECT_EQ( |
| total_bytes, |
| static_cast<uint64_t>(AIUtils::kNormalizedDownloadProgressMax)); |
| download_progress_run_loop.Quit(); |
| })); |
| |
| mock_remote->AddModelDownloadProgressObserver( |
| mock_monitor.BindNewPipeAndPassRemote()); |
| ASSERT_TRUE(base::test::RunUntil([this] { |
| return GetAIManagerDownloadProgressObserversSize() == 1u; |
| })); |
| |
| MockDownloadProgressUpdate(kTestModelDownloadSize, |
| kTestModelDownloadSize); |
| download_progress_run_loop.Run(); |
| } |
| |
| std::vector<blink::mojom::AILanguageModelExpectedInputPtr> expected_inputs; |
| if (!options.should_use_supported_language) { |
| expected_inputs.push_back(blink::mojom::AILanguageModelExpectedInput::New( |
| blink::mojom::AILanguageModelPromptType::kText, |
| AITestUtils::ToMojoLanguageCodes({"ja"}))); |
| } |
| |
| mock_remote->CreateLanguageModel( |
| mock_create_language_model_client.BindNewPipeAndPassRemote(), |
| blink::mojom::AILanguageModelCreateOptions::New( |
| std::move(options.sampling_params), options.system_prompt, |
| std::move(options.initial_prompts), std::move(expected_inputs))); |
| creation_run_loop.Run(); |
| |
| if (!options.should_use_supported_language) { |
| return; |
| } |
| |
| TestPromptCall(mock_session, options.prompt_input, |
| options.should_overflow_context); |
| |
| // Test session cloning. |
| mojo::Remote<blink::mojom::AILanguageModel> mock_cloned_session; |
| AITestUtils::MockCreateLanguageModelClient mock_clone_language_model_client; |
| base::RunLoop clone_run_loop; |
| EXPECT_CALL(mock_clone_language_model_client, OnResult(_, _)) |
| .WillOnce(testing::Invoke( |
| [&](mojo::PendingRemote<blink::mojom::AILanguageModel> |
| language_model, |
| blink::mojom::AILanguageModelInstanceInfoPtr info) { |
| EXPECT_TRUE(language_model); |
| mock_cloned_session = mojo::Remote<blink::mojom::AILanguageModel>( |
| std::move(language_model)); |
| clone_run_loop.Quit(); |
| })); |
| |
| mock_session->Fork( |
| mock_clone_language_model_client.BindNewPipeAndPassRemote()); |
| clone_run_loop.Run(); |
| |
| TestPromptCall(mock_cloned_session, options.prompt_input, |
| /*should_overflow_context=*/false); |
| } |
| |
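| // Run `callback` with a fresh session and responder; the prompt must fail |
| // with `kErrorSessionDestroyed` and never reach the model. |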
| void TestSessionDestroy( |
| base::OnceCallback<void( |
| mojo::Remote<blink::mojom::AILanguageModel> mock_session, |
| AITestUtils::MockModelStreamingResponder& mock_responder)> callback) { |
| SetupMockOptimizationGuideKeyedService(); |
| base::OnceClosure size_in_token_callback; |
| EXPECT_CALL(*mock_optimization_guide_keyed_service_, StartSession(_, _)) |
| .WillOnce( |
| [&](optimization_guide::ModelBasedCapabilityKey feature, |
| const std::optional<optimization_guide::SessionConfigParams>& |
| config_params) { |
| auto session = std::make_unique< |
| testing::NiceMock<optimization_guide::MockSession>>(); |
| |
| SetUpMockSession(*session); |
| ON_CALL(*session, GetExecutionInputSizeInTokens(_, _)) |
| .WillByDefault( |
| [&](MultimodalMessageReadView request_metadata, |
| optimization_guide:: |
| OptimizationGuideModelSizeInTokenCallback |
| callback) { |
| size_in_token_callback = |
| base::BindOnce(std::move(callback), 1); |
| }); |
| |
| // The model should not be executed. |
| EXPECT_CALL(*session, ExecuteModelWithResponseJsonSchema(_, _, _)) |
| .Times(0); |
| return session; |
| }); |
| |
| mojo::Remote<blink::mojom::AILanguageModel> mock_session = |
| CreateMockSession(); |
| |
| AITestUtils::MockModelStreamingResponder mock_responder; |
| |
| base::RunLoop responder_run_loop; |
| |
| EXPECT_CALL(mock_responder, OnError(_)) |
| .WillOnce(testing::Invoke( |
| [&](blink::mojom::ModelStreamingResponseStatus status) { |
| EXPECT_EQ(status, blink::mojom::ModelStreamingResponseStatus:: |
| kErrorSessionDestroyed); |
| responder_run_loop.Quit(); |
| })); |
| |
| std::move(callback).Run(std::move(mock_session), mock_responder); |
| // Run the deferred `size_in_token_callback` only after the test callback |
| // that destroys the session has run. |
| if (size_in_token_callback) { |
| std::move(size_in_token_callback).Run(); |
| } |
| responder_run_loop.Run(); |
| } |
| |
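| // Send two prompts ("A" then "B") and verify whether the prompt history is |
| // carried into the second prompt's input, with or without context overflow. |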
| void TestSessionAddContext(bool should_overflow_context) { |
| SetupMockOptimizationGuideKeyedService(); |
| // Use `max_context_tokens / 2 + 1` so that the context overflows on the |
| // second prompt. |
| uint32_t mock_size_in_tokens = |
| should_overflow_context |
| ? 1 + AITestUtils::GetFakeTokenLimits().max_context_tokens / 2 |
| : 1; |
| |
| EXPECT_CALL(*mock_optimization_guide_keyed_service_, StartSession(_, _)) |
| .WillOnce([&](optimization_guide::ModelBasedCapabilityKey feature, |
| const std::optional< |
| optimization_guide::SessionConfigParams>& |
| config_params) { |
| auto session = std::make_unique< |
| testing::NiceMock<optimization_guide::MockSession>>(); |
| |
| SetUpMockSession(*session); |
| |
| ON_CALL(*session, GetContextSizeInTokens(_, _)) |
| .WillByDefault( |
| [&](MultimodalMessageReadView request_metadata, |
| optimization_guide:: |
| OptimizationGuideModelSizeInTokenCallback callback) { |
| std::move(callback).Run(mock_size_in_tokens); |
| }); |
| |
| ON_CALL(*session, GetExecutionInputSizeInTokens(_, _)) |
| .WillByDefault( |
| [&](MultimodalMessageReadView request_metadata, |
| optimization_guide:: |
| OptimizationGuideModelSizeInTokenCallback callback) { |
| std::move(callback).Run(mock_size_in_tokens); |
| }); |
| |
| EXPECT_CALL(*session, SetInput(_, _)) |
| .Times(2) |
| .WillOnce( |
| [&](MultimodalMessage request, SetInputCallback callback) { |
| EXPECT_THAT(ToString(request), "U: A\nM: "); |
| }) |
| .WillOnce([&](MultimodalMessage request, |
| SetInputCallback callback) { |
| // Prompt history should be omitted if it would overflow. |
| EXPECT_THAT(ToString(request), should_overflow_context |
| ? "U: B\nM: " |
| : "U: A\nM: OK\nU: B\nM: "); |
| }); |
| |
| EXPECT_CALL(*session, ExecuteModelWithResponseJsonSchema(_, _, _)) |
| .Times(2) |
| .WillRepeatedly( |
| [&](const google::protobuf::MessageLite& request_metadata, |
| const std::optional<std::string>& response_json_schema, |
| optimization_guide:: |
| OptimizationGuideModelExecutionResultStreamingCallback |
| callback) { |
| EXPECT_THAT(request_metadata.ByteSizeLong(), 0); |
| callback.Run(CreateExecutionResult( |
| "OK", /*is_complete=*/true, /*input_token_count=*/1u, |
| /*output_token_count=*/mock_size_in_tokens)); |
| }); |
| return session; |
| }); |
| |
| mojo::Remote<blink::mojom::AILanguageModel> mock_session = |
| CreateMockSession(); |
| |
| AITestUtils::MockModelStreamingResponder mock_responder_1; |
| AITestUtils::MockModelStreamingResponder mock_responder_2; |
| |
| base::RunLoop responder_run_loop_1; |
| base::RunLoop responder_run_loop_2; |
| |
| EXPECT_CALL(mock_responder_1, OnStreaming(_)) |
| .WillOnce(testing::Invoke( |
| [&](const std::string& text) { EXPECT_THAT(text, "OK"); })); |
| EXPECT_CALL(mock_responder_2, OnStreaming(_)) |
| .WillOnce(testing::Invoke( |
| [&](const std::string& text) { EXPECT_THAT(text, "OK"); })); |
| |
| EXPECT_CALL(mock_responder_2, OnQuotaOverflow()) |
| .Times(should_overflow_context ? 1 : 0); |
| |
| EXPECT_CALL(mock_responder_1, OnCompletion(_)) |
| .WillOnce(testing::Invoke( |
| [&](blink::mojom::ModelExecutionContextInfoPtr context_info) { |
| responder_run_loop_1.Quit(); |
| })); |
| EXPECT_CALL(mock_responder_2, OnCompletion(_)) |
| .WillOnce(testing::Invoke( |
| [&](blink::mojom::ModelExecutionContextInfoPtr context_info) { |
| responder_run_loop_2.Quit(); |
| })); |
| |
| mock_session->Prompt(MakeInput("A"), /*response_json_schema=*/std::nullopt, |
| mock_responder_1.BindNewPipeAndPassRemote()); |
| responder_run_loop_1.Run(); |
| mock_session->Prompt(MakeInput("B"), /*response_json_schema=*/std::nullopt, |
| mock_responder_2.BindNewPipeAndPassRemote()); |
| responder_run_loop_2.Run(); |
| } |
| |
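| // Install the default mock behaviors shared by all mock sessions. |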
| void SetUpMockSession( |
| testing::NiceMock<optimization_guide::MockSession>& session) { |
| ON_CALL(session, GetTokenLimits()) |
| .WillByDefault(AITestUtils::GetFakeTokenLimits); |
| |
| ON_CALL(session, GetOnDeviceFeatureMetadata()) |
| .WillByDefault(ReturnRef(GetPromptApiMetadata())); |
| ON_CALL(session, GetSamplingParams()).WillByDefault([]() { |
| // We don't need to use these values, so just mock them with defaults. |
| return optimization_guide::SamplingParams{ |
| /*top_k=*/kTestDefaultTopK, |
| /*temperature=*/kTestDefaultTemperature}; |
| }); |
| ON_CALL(session, GetSizeInTokens(_, _)) |
| .WillByDefault( |
| [](const std::string& text, |
| optimization_guide::OptimizationGuideModelSizeInTokenCallback |
| callback) { std::move(callback).Run(1); }); |
| ON_CALL(session, GetExecutionInputSizeInTokens(_, _)) |
| .WillByDefault( |
| [](MultimodalMessageReadView request_metadata, |
| optimization_guide::OptimizationGuideModelSizeInTokenCallback |
| callback) { std::move(callback).Run(1); }); |
| ON_CALL(session, GetContextSizeInTokens(_, _)) |
| .WillByDefault( |
| [](MultimodalMessageReadView request_metadata, |
| optimization_guide::OptimizationGuideModelSizeInTokenCallback |
| callback) { std::move(callback).Run(1); }); |
| } |
| |
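| // Stream `kTestResponse` back in two partial chunks followed by an empty |
| // chunk that marks completion. |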
| void StreamResponse( |
| optimization_guide::OptimizationGuideModelExecutionResultStreamingCallback |
| callback) { |
| std::string responses[3]; |
| std::string response = std::string(kTestResponse); |
| responses[0] = response.substr(0, 1); |
| responses[1] = response.substr(1); |
| responses[2] = ""; |
| callback.Run(CreateExecutionResult(responses[0], |
| /*is_complete=*/false, |
| /*input_token_count=*/1u, |
| /*output_token_count=*/1u)); |
| callback.Run(CreateExecutionResult(responses[1], |
| /*is_complete=*/false, |
| /*input_token_count=*/1u, |
| /*output_token_count=*/1u)); |
| callback.Run(CreateExecutionResult(responses[2], |
| /*is_complete=*/true, |
| /*input_token_count=*/1u, |
| /*output_token_count=*/1u)); |
| } |
| |
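| // Send `prompt` to `mock_session` and verify that the streamed response |
| // matches the chunks produced by `StreamResponse()`. |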
| void TestPromptCall(mojo::Remote<blink::mojom::AILanguageModel>& mock_session, |
| const std::string& prompt, |
| bool should_overflow_context) { |
| AITestUtils::MockModelStreamingResponder mock_responder; |
| |
| base::RunLoop responder_run_loop; |
| std::string response = std::string(kTestResponse); |
| EXPECT_CALL(mock_responder, OnStreaming(_)) |
| .Times(3) |
| .WillOnce(testing::Invoke([&](const std::string& text) { |
| EXPECT_THAT(text, response.substr(0, 1)); |
| })) |
| .WillOnce(testing::Invoke([&](const std::string& text) { |
| EXPECT_THAT(text, response.substr(1)); |
| })) |
| .WillOnce(testing::Invoke( |
| [&](const std::string& text) { EXPECT_THAT(text, ""); })); |
| |
| EXPECT_CALL(mock_responder, OnCompletion(_)) |
| .WillOnce(testing::Invoke( |
| [&](blink::mojom::ModelExecutionContextInfoPtr context_info) { |
| responder_run_loop.Quit(); |
| })); |
| |
| mock_session->Prompt(MakeInput(prompt), |
| /*response_json_schema=*/std::nullopt, |
| mock_responder.BindNewPipeAndPassRemote()); |
| responder_run_loop.Run(); |
| } |
| |
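| // Create a language model session with default options and wait for the |
| // creation to complete. |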
| mojo::Remote<blink::mojom::AILanguageModel> CreateMockSession() { |
| mojo::Remote<blink::mojom::AILanguageModel> mock_session; |
| AITestUtils::MockCreateLanguageModelClient |
| mock_create_language_model_client; |
| base::RunLoop creation_run_loop; |
| EXPECT_CALL(mock_create_language_model_client, OnResult(_, _)) |
| .WillOnce([&](mojo::PendingRemote<blink::mojom::AILanguageModel> |
| language_model, |
| blink::mojom::AILanguageModelInstanceInfoPtr info) { |
| EXPECT_TRUE(language_model); |
| mock_session = mojo::Remote<blink::mojom::AILanguageModel>( |
| std::move(language_model)); |
| creation_run_loop.Quit(); |
| }); |
| |
| mojo::Remote<blink::mojom::AIManager> mock_remote = GetAIManagerRemote(); |
| |
| mock_remote->CreateLanguageModel( |
| mock_create_language_model_client.BindNewPipeAndPassRemote(), |
| blink::mojom::AILanguageModelCreateOptions::New()); |
| creation_run_loop.Run(); |
| |
| return mock_session; |
| } |
| |
| private: |
| base::test::ScopedFeatureList scoped_feature_list_; |
| }; |
| |
| TEST_F(AILanguageModelTest, PromptDefaultSession) { |
| RunPromptTest(AILanguageModelTest::Options{ |
| .prompt_input = kTestPrompt, |
| .expected_prompt = kExpectedFormattedTestPrompt, |
| }); |
| } |
| |
| TEST_F(AILanguageModelTest, PromptSessionWithSamplingParams) { |
| RunPromptTest(AILanguageModelTest::Options{ |
| .sampling_params = blink::mojom::AILanguageModelSamplingParams::New( |
| /*top_k=*/kTestMaxTopK - 1, |
| /*temperature=*/kTestMaxTemperature * 0.9), |
| .prompt_input = kTestPrompt, |
| .expected_prompt = kExpectedFormattedTestPrompt, |
| }); |
| } |
| |
| TEST_F(AILanguageModelTest, PromptSessionWithSamplingParams_ExceedMaxTopK) { |
| RunPromptTest(AILanguageModelTest::Options{ |
| .sampling_params = blink::mojom::AILanguageModelSamplingParams::New( |
| /*top_k=*/kTestMaxTopK + 1, |
| /*temperature=*/kTestMaxTemperature * 0.9), |
| .prompt_input = kTestPrompt, |
| .expected_prompt = kExpectedFormattedTestPrompt, |
| }); |
| } |
| |
| TEST_F(AILanguageModelTest, |
| PromptSessionWithSamplingParams_ExceedMaxTemperature) { |
| RunPromptTest(AILanguageModelTest::Options{ |
| .sampling_params = blink::mojom::AILanguageModelSamplingParams::New( |
| /*top_k=*/kTestMaxTopK - 1, |
| /*temperature=*/kTestMaxTemperature + 0.1), |
| .prompt_input = kTestPrompt, |
| .expected_prompt = kExpectedFormattedTestPrompt, |
| }); |
| } |
| |
| TEST_F(AILanguageModelTest, PromptSessionWithSystemPrompt) { |
| RunPromptTest(AILanguageModelTest::Options{ |
| .system_prompt = kTestSystemPrompts, |
| .prompt_input = kTestPrompt, |
| .expected_context = kExpectedFormattedSystemPrompts, |
| .expected_cloned_context = |
| base::StrCat({kExpectedFormattedSystemPrompts, |
| kExpectedFormattedTestPrompt, kTestResponse, "\n"}), |
| .expected_prompt = kExpectedFormattedTestPrompt, |
| }); |
| } |
| |
| TEST_F(AILanguageModelTest, PromptSessionWithInitialPrompts) { |
| RunPromptTest(AILanguageModelTest::Options{ |
| .initial_prompts = GetTestInitialPrompts(), |
| .prompt_input = kTestPrompt, |
| .expected_context = kExpectedFormattedInitialPrompts, |
| .expected_cloned_context = |
| base::StrCat({kExpectedFormattedInitialPrompts, |
| kExpectedFormattedTestPrompt, kTestResponse, "\n"}), |
| .expected_prompt = kExpectedFormattedTestPrompt, |
| }); |
| } |
| |
| TEST_F(AILanguageModelTest, PromptSessionWithSystemPromptAndInitialPrompts) { |
| RunPromptTest(AILanguageModelTest::Options{ |
| .system_prompt = kTestSystemPrompts, |
| .initial_prompts = GetTestInitialPrompts(), |
| .prompt_input = kTestPrompt, |
| .expected_context = kExpectedFormattedSystemPromptAndInitialPrompts, |
| .expected_cloned_context = base::StrCat( |
| {kExpectedFormattedSystemPrompts, kExpectedFormattedInitialPrompts, |
| kExpectedFormattedTestPrompt, kTestResponse, "\n"}), |
| .expected_prompt = kExpectedFormattedTestPrompt, |
| }); |
| } |
| |
| TEST_F(AILanguageModelTest, PromptSessionWithPromptApiRequests) { |
| RunPromptTest(AILanguageModelTest::Options{ |
| .system_prompt = "Test system prompt", |
| .initial_prompts = GetTestInitialPrompts(), |
| .prompt_input = "Test prompt", |
| .expected_context = ("S: Test system prompt\n" |
| "U: How are you?\n" |
| "M: I'm fine, thank you, and you?\n" |
| "U: I'm fine too.\n"), |
| .expected_cloned_context = ("S: Test system prompt\n" |
| "U: How are you?\n" |
| "M: I'm fine, thank you, and you?\n" |
| "U: I'm fine too.\n" |
| "U: Test prompt\n" |
| "M: Test response\n"), |
| .expected_prompt = "U: Test prompt\nM: ", |
| }); |
| } |
| |
| TEST_F(AILanguageModelTest, PromptSessionWithQuotaOverflow) { |
| RunPromptTest({.prompt_input = kTestPrompt, |
| .expected_prompt = kExpectedFormattedTestPrompt, |
| .should_overflow_context = true}); |
| } |
| |
| TEST_F(AILanguageModelTest, PromptSessionWithUnsupportedLanguage) { |
| RunPromptTest({.should_use_supported_language = false}); |
| } |
| |
| // Tests that sending `Prompt()` after destroying the session won't make a real |
| // call to the model. |
| TEST_F(AILanguageModelTest, PromptAfterDestroy) { |
| TestSessionDestroy(base::BindOnce( |
| [](mojo::Remote<blink::mojom::AILanguageModel> mock_session, |
| AITestUtils::MockModelStreamingResponder& mock_responder) { |
| mock_session->Destroy(); |
| mock_session->Prompt(MakeInput(kTestPrompt), |
| /*response_json_schema=*/std::nullopt, |
| mock_responder.BindNewPipeAndPassRemote()); |
| })); |
| } |
| |
| // Tests that sending `Prompt()` right before destroying the session won't make |
| // a real call to the model. |
| TEST_F(AILanguageModelTest, PromptBeforeDestroy) { |
| TestSessionDestroy(base::BindOnce( |
| [](mojo::Remote<blink::mojom::AILanguageModel> mock_session, |
| AITestUtils::MockModelStreamingResponder& mock_responder) { |
| mock_session->Prompt(MakeInput(kTestPrompt), |
| /*response_json_schema=*/std::nullopt, |
| mock_responder.BindNewPipeAndPassRemote()); |
| mock_session->Destroy(); |
| })); |
| } |
| |
| // Tests that the second prompt's input includes the prompt history when |
| // there is no context overflow. |
| TEST_F(AILanguageModelTest, PromptWithHistoryWithoutQuotaOverflow) { |
| TestSessionAddContext(/*should_overflow_context=*/false); |
| } |
| |
| // Tests that the second prompt's input omits the prompt history when the |
| // context overflows. |
| TEST_F(AILanguageModelTest, PromptWithHistoryWithQuotaOverflow) { |
| TestSessionAddContext(/*should_overflow_context=*/true); |
| } |
| |
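| // Tests `CanCreateLanguageModel()` with a supported expected input language. |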
| TEST_F(AILanguageModelTest, CanCreate_LanguagesSupported) { |
| SetupMockOptimizationGuideKeyedService(); |
| EXPECT_CALL(*mock_optimization_guide_keyed_service_, |
| GetOnDeviceModelEligibility(_)) |
| .WillRepeatedly(testing::Return( |
| optimization_guide::OnDeviceModelEligibilityReason::kSuccess)); |
| |
| base::MockCallback<AIManager::CanCreateLanguageModelCallback> callback; |
| auto options = blink::mojom::AILanguageModelCreateOptions::New(); |
| options->expected_inputs = |
| std::vector<blink::mojom::AILanguageModelExpectedInputPtr>(); |
| options->expected_inputs->push_back( |
| blink::mojom::AILanguageModelExpectedInput::New( |
| blink::mojom::AILanguageModelPromptType::kText, |
| AITestUtils::ToMojoLanguageCodes({"en"}))); |
| EXPECT_CALL(callback, |
| Run(blink::mojom::ModelAvailabilityCheckResult::kAvailable)); |
| GetAIManagerInterface()->CanCreateLanguageModel(std::move(options), |
| callback.Get()); |
| } |
| |
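| // Tests `CanCreateLanguageModel()` with an unsupported expected input |
| // language. |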
| TEST_F(AILanguageModelTest, CanCreate_LanguagesUnsupported) { |
| SetupMockOptimizationGuideKeyedService(); |
| EXPECT_CALL(*mock_optimization_guide_keyed_service_, |
| GetOnDeviceModelEligibility(_)) |
| .WillRepeatedly(testing::Return( |
| optimization_guide::OnDeviceModelEligibilityReason::kSuccess)); |
| |
| base::MockCallback<AIManager::CanCreateLanguageModelCallback> callback; |
| EXPECT_CALL(callback, Run(blink::mojom::ModelAvailabilityCheckResult:: |
| kUnavailableUnsupportedLanguage)); |
| auto options = blink::mojom::AILanguageModelCreateOptions::New(); |
| options->expected_inputs = |
| std::vector<blink::mojom::AILanguageModelExpectedInputPtr>(); |
| options->expected_inputs->push_back( |
| blink::mojom::AILanguageModelExpectedInput::New( |
| blink::mojom::AILanguageModelPromptType::kText, |
| AITestUtils::ToMojoLanguageCodes({"ja"}))); |
| GetAIManagerInterface()->CanCreateLanguageModel(std::move(options), |
| callback.Get()); |
| } |
| |
| // Tests `Prompt()` with image and audio input. |
| TEST_F(AILanguageModelTest, MultimodalInput) { |
| SetupMockOptimizationGuideKeyedService(); |
| EXPECT_CALL(*mock_optimization_guide_keyed_service_, StartSession(_, _)) |
| .WillOnce([&](optimization_guide::ModelBasedCapabilityKey feature, |
| const std::optional< |
| optimization_guide::SessionConfigParams>& |
| config_params) { |
| auto session = std::make_unique< |
| testing::NiceMock<optimization_guide::MockSession>>(); |
| SetUpMockSession(*session); |
| EXPECT_CALL(*session, GetCapabilities()) |
| .WillRepeatedly(Return(on_device_model::Capabilities{ |
| on_device_model::CapabilityFlags::kImageInput, |
| on_device_model::CapabilityFlags::kAudioInput})); |
| EXPECT_CALL(*session, SetInput(_, _)) |
| .WillOnce([&](MultimodalMessage request_metadata, |
| SetInputCallback callback) { |
| EXPECT_THAT(ToString(request_metadata), |
| "U: Test prompt\n" |
| "U: <image>\n" |
| "U: <audio>\n" |
| "M: "); |
| }); |
| EXPECT_CALL(*session, ExecuteModelWithResponseJsonSchema(_, _, _)) |
| .WillOnce( |
| [&](const google::protobuf::MessageLite& request_metadata, |
| const std::optional<std::string>& response_json_schema, |
| optimization_guide:: |
| OptimizationGuideModelExecutionResultStreamingCallback |
| callback) { |
| EXPECT_THAT(request_metadata.ByteSizeLong(), 0); |
| callback.Run( |
| CreateExecutionResult("OK", /*is_complete=*/true, |
| /*input_token_count=*/1u, |
| /*output_token_count=*/1u)); |
| }); |
| return session; |
| }); |
| mojo::Remote<blink::mojom::AILanguageModel> mock_session = |
| CreateMockSession(); |
| |
| AITestUtils::MockModelStreamingResponder mock_responder; |
| base::RunLoop run_loop; |
| EXPECT_CALL(mock_responder, OnStreaming("OK")).Times(1); |
| EXPECT_CALL(mock_responder, OnCompletion(_)) |
| .WillOnce(testing::InvokeWithoutArgs(&run_loop, &base::RunLoop::Quit)); |
| |
| std::vector<blink::mojom::AILanguageModelPromptPtr> input = |
| MakeInput(kTestPrompt); |
| input.push_back(blink::mojom::AILanguageModelPrompt::New( |
| Role::kUser, blink::mojom::AILanguageModelPromptContent::NewBitmap( |
| CreateTestBitmap(10, 10)))); |
| input.push_back(blink::mojom::AILanguageModelPrompt::New( |
| Role::kUser, |
| blink::mojom::AILanguageModelPromptContent::NewAudio(CreateTestAudio()))); |
| mock_session->Prompt(std::move(input), /*response_json_schema=*/std::nullopt, |
| mock_responder.BindNewPipeAndPassRemote()); |
| run_loop.Run(); |
| } |
| |
| // Tests `AILanguageModel::Context` creation without initial prompts. |
| TEST(AILanguageModelContextCreationTest, CreateContext_WithoutInitialPrompts) { |
| AILanguageModel::Context context(kTestMaxContextToken, {}); |
| EXPECT_FALSE(context.HasContextItem()); |
| } |
| |
| // Tests `AILanguageModel::Context` creation with valid initial prompts. |
| TEST(AILanguageModelContextCreationTest, |
| CreateContext_WithInitialPrompts_Normal) { |
| AILanguageModel::Context context( |
| kTestMaxContextToken, |
| SimpleContextItem("initial prompts\n", kTestInitialPromptsToken)); |
| EXPECT_TRUE(context.HasContextItem()); |
| } |
| |
| // Tests `AILanguageModel::Context` creation with initial prompts that exceed |
| // the max token limit. |
| TEST(AILanguageModelContextCreationTest, |
| CreateContext_WithInitialPrompts_Overflow) { |
| EXPECT_DEATH_IF_SUPPORTED( |
| AILanguageModel::Context context( |
| kTestMaxContextToken, SimpleContextItem("long initial prompts\n", |
| kTestMaxContextToken + 1u)), |
| ""); |
| } |
| |
| // Tests `AILanguageModel::Context` when it is initialized with or without |
| // initial prompts. |
| class AILanguageModelContextTest : public testing::Test, |
| public testing::WithParamInterface< |
| /*is_init_with_initial_prompts=*/bool> { |
| public: |
| bool IsInitializedWithInitialPrompts() { return GetParam(); } |
| |
| uint32_t GetMaxContextToken() { |
| return IsInitializedWithInitialPrompts() |
| ? kTestMaxContextToken - kTestInitialPromptsToken |
| : kTestMaxContextToken; |
| } |
| |
| std::string GetInitialPromptsPrefix() { |
| return IsInitializedWithInitialPrompts() ? "S: initial prompts\n" : ""; |
| } |
| |
| AILanguageModel::Context context_{ |
| kTestMaxContextToken, |
| IsInitializedWithInitialPrompts() |
| ? SimpleContextItem("initial prompts", kTestInitialPromptsToken) |
| : AILanguageModel::Context::ContextItem()}; |
| }; |
| |
| INSTANTIATE_TEST_SUITE_P(All, |
| AILanguageModelContextTest, |
| testing::Bool(), |
| [](const testing::TestParamInfo<bool>& info) { |
| return info.param ? "WithInitialPrompts" |
| : "WithoutInitialPrompts"; |
| }); |
| |
| // Tests `GetContextString()` and `HasContextItem()` when the context is empty. |
| TEST_P(AILanguageModelContextTest, TestContextOperation_Empty) { |
| EXPECT_EQ(GetContextString(context_), GetInitialPromptsPrefix()); |
| |
| if (IsInitializedWithInitialPrompts()) { |
| EXPECT_TRUE(context_.HasContextItem()); |
| } else { |
| EXPECT_FALSE(context_.HasContextItem()); |
| } |
| } |
| |
| // Tests `GetContextString()` and `HasContextItem()` when some items are added |
| // to the context. |
| TEST_P(AILanguageModelContextTest, TestContextOperation_NonEmpty) { |
| EXPECT_EQ(context_.AddContextItem(SimpleContextItem("test", 1u)), |
| AILanguageModel::Context::SpaceReservationResult::kSufficientSpace); |
| EXPECT_EQ(GetContextString(context_), |
| GetInitialPromptsPrefix() + "S: test\n"); |
| EXPECT_TRUE(context_.HasContextItem()); |
| |
| context_.AddContextItem(SimpleContextItem(" test again", 2u)); |
| EXPECT_EQ(GetContextString(context_), |
| GetInitialPromptsPrefix() + "S: test\nS: test again\n"); |
| EXPECT_TRUE(context_.HasContextItem()); |
| } |
| |
| // Tests `GetContextString()` and `HasContextItem()` when the items overflow. |
| TEST_P(AILanguageModelContextTest, TestContextOperation_Overflow) { |
| EXPECT_EQ(context_.AddContextItem(SimpleContextItem("test", 1u)), |
| AILanguageModel::Context::SpaceReservationResult::kSufficientSpace); |
| EXPECT_EQ(GetContextString(context_), |
| GetInitialPromptsPrefix() + "S: test\n"); |
| EXPECT_TRUE(context_.HasContextItem()); |
| |
| // Since the total number of tokens will exceed `kTestMaxContextToken`, the |
| // old item will be evicted. |
| EXPECT_EQ( |
| context_.AddContextItem( |
| SimpleContextItem("test long token", GetMaxContextToken())), |
| AILanguageModel::Context::SpaceReservationResult::kSpaceMadeAvailable); |
| EXPECT_EQ(GetContextString(context_), |
| GetInitialPromptsPrefix() + "S: test long token\n"); |
| EXPECT_TRUE(context_.HasContextItem()); |
| } |
| |
| // Tests `GetContextString()` and `HasContextItem()` when the items overflow on |
| // the first insertion. |
| TEST_P(AILanguageModelContextTest, TestContextOperation_OverflowOnFirstItem) { |
| EXPECT_EQ( |
| context_.AddContextItem( |
| SimpleContextItem("test very long token", GetMaxContextToken() + 1u)), |
| AILanguageModel::Context::SpaceReservationResult::kInsufficientSpace); |
| EXPECT_EQ(GetContextString(context_), GetInitialPromptsPrefix()); |
| if (IsInitializedWithInitialPrompts()) { |
| EXPECT_TRUE(context_.HasContextItem()); |
| } else { |
| EXPECT_FALSE(context_.HasContextItem()); |
| } |
| } |
| |
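| // Tests that the session priority tracks the visibility of the associated |
| // widget: hiding it moves the session to background priority, showing it |
| // moves it back to foreground. |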
| TEST_F(AILanguageModelTest, Priority) { |
| SetupMockOptimizationGuideKeyedService(); |
| base::test::TestFuture<testing::NiceMock<optimization_guide::MockSession>*> |
| session_future; |
| EXPECT_CALL(*mock_optimization_guide_keyed_service_, StartSession(_, _)) |
| .WillOnce( |
| [&](optimization_guide::ModelBasedCapabilityKey feature, |
| const std::optional<optimization_guide::SessionConfigParams>& |
| config_params) { |
| auto session = std::make_unique< |
| testing::NiceMock<optimization_guide::MockSession>>(); |
| SetUpMockSession(*session); |
| EXPECT_CALL( |
| *session, |
| SetPriority(on_device_model::mojom::Priority::kForeground)); |
| session_future.SetValue(session.get()); |
| return session; |
| }); |
| auto session_remote = CreateMockSession(); |
| auto* session = session_future.Get(); |
| |
| EXPECT_CALL(*session, |
| SetPriority(on_device_model::mojom::Priority::kBackground)); |
| main_rfh()->GetRenderWidgetHost()->GetView()->Hide(); |
| |
| EXPECT_CALL(*session, |
| SetPriority(on_device_model::mojom::Priority::kForeground)); |
| main_rfh()->GetRenderWidgetHost()->GetView()->Show(); |
| } |
| |
| } // namespace |