// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ai/ai_language_model.h"

#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "base/functional/callback_helpers.h"
#include "base/no_destructor.h"
#include "base/notreached.h"
#include "base/run_loop.h"
#include "base/strings/strcat.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h"
#include "base/task/current_thread.h"
#include "base/test/run_until.h"
#include "base/test/scoped_feature_list.h"
| #include "chrome/browser/ai/ai_test_utils.h" |
| #include "chrome/browser/ai/features.h" |
| #include "components/optimization_guide/core/mock_optimization_guide_model_executor.h" |
| #include "components/optimization_guide/core/optimization_guide_model_executor.h" |
| #include "components/optimization_guide/core/optimization_guide_proto_util.h" |
| #include "components/optimization_guide/proto/common_types.pb.h" |
| #include "components/optimization_guide/proto/features/prompt_api.pb.h" |
| #include "components/optimization_guide/proto/string_value.pb.h" |
| #include "testing/gmock/include/gmock/gmock.h" |
| #include "testing/gtest/include/gtest/gtest.h" |
| #include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h" |
| #include "third_party/blink/public/mojom/ai/ai_manager.mojom-shared.h" |
| #include "third_party/blink/public/mojom/ai/model_download_progress_observer.mojom-forward.h" |
| |
| using testing::_; |
| using testing::ReturnRef; |
| using testing::Test; |
| using Role = blink::mojom::AILanguageModelInitialPromptRole; |
| |
| namespace { |
| |
| using optimization_guide::proto::PromptApiRequest; |
| using optimization_guide::proto::PromptApiRole; |
| |
| const uint32_t kTestMaxContextToken = 10u; |
| const uint32_t kTestInitialPromptsToken = 5u; |
| const uint32_t kDefaultTopK = 1u; |
| const uint32_t kOverrideMaxTopK = 5u; |
| const float kDefaultTemperature = 0.0; |
| const uint64_t kTestModelDownloadSize = 572u; |
| |
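// The expected strings below assume the plain-string execution path: user
// turns render as "User: <text>\n", model turns as "Model: <text>\n", and the
// system prompt appears as a bare leading line.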
const char kTestPrompt[] = "Test prompt";
const char kExpectedFormattedTestPrompt[] = "User: Test prompt\nModel: ";
const char kTestSystemPrompts[] = "Test system prompt";
const char kExpectedFormattedSystemPrompts[] = "Test system prompt\n";
const char kTestResponse[] = "Test response";

const char kTestInitialPromptsUser1[] = "How are you?";
const char kTestInitialPromptsAssistant1[] = "I'm fine, thank you, and you?";
const char kTestInitialPromptsUser2[] = "I'm fine too.";
const char kExpectedFormattedInitialPrompts[] =
    "User: How are you?\nModel: I'm fine, thank you, and you?\nUser: I'm fine "
    "too.\n";
const char kExpectedFormattedSystemPromptAndInitialPrompts[] =
    "Test system prompt\nUser: How are you?\nModel: I'm fine, thank you, and "
    "you?\nUser: I'm fine too.\n";

std::vector<blink::mojom::AILanguageModelInitialPromptPtr>
GetTestInitialPrompts() {
  auto create_initial_prompt = [](Role role, const char* content) {
    return blink::mojom::AILanguageModelInitialPrompt::New(role, content);
  };
  std::vector<blink::mojom::AILanguageModelInitialPromptPtr> initial_prompts;
  initial_prompts.push_back(
      create_initial_prompt(Role::kUser, kTestInitialPromptsUser1));
  initial_prompts.push_back(
      create_initial_prompt(Role::kAssistant, kTestInitialPromptsAssistant1));
  initial_prompts.push_back(
      create_initial_prompt(Role::kUser, kTestInitialPromptsUser2));
  return initial_prompts;
}

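// Serializes `ctx` through the request proto it builds and returns the string
// payload. This helper assumes the plain-string (non-PromptApiRequest) code
// path, so the request is expected to be a `StringValue`.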
std::string GetContextString(AILanguageModel::Context& ctx) {
  auto msg = ctx.MakeRequest();
  auto* v = static_cast<optimization_guide::proto::StringValue*>(msg.get());
  return v->value();
}

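// Builds a minimal `ContextItem` holding a single system-role prompt with the
// given content and token count.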
AILanguageModel::Context::ContextItem SimpleContextItem(std::string text,
                                                        uint32_t size) {
  auto item = AILanguageModel::Context::ContextItem();
  item.tokens = size;
  auto* prompt = item.prompts.Add();
  prompt->set_role(PromptApiRole::PROMPT_API_ROLE_SYSTEM);
  prompt->set_content(text);
  return item;
}

const char* FormatPromptRole(PromptApiRole role) {
  switch (role) {
    case PromptApiRole::PROMPT_API_ROLE_SYSTEM:
      return "S: ";
    case PromptApiRole::PROMPT_API_ROLE_USER:
      return "U: ";
    case PromptApiRole::PROMPT_API_ROLE_ASSISTANT:
      return "M: ";
    default:
      NOTREACHED();
  }
}

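// Renders `request` in a short role-prefixed form. For example, a request
// with the initial prompt {USER, "Hi"} and the current prompt {USER, "Test"}
// serializes to "U: Hi\nU: Test\nM: "; the trailing assistant marker cues the
// model to respond.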
std::string ToString(const PromptApiRequest& request) {
  std::ostringstream oss;
  for (const auto& prompt : request.initial_prompts()) {
    oss << FormatPromptRole(prompt.role()) << prompt.content() << "\n";
  }
  for (const auto& prompt : request.prompt_history()) {
    oss << FormatPromptRole(prompt.role()) << prompt.content() << "\n";
  }
  for (const auto& prompt : request.current_prompts()) {
    oss << FormatPromptRole(prompt.role()) << prompt.content() << "\n";
  }
  if (request.current_prompts_size() > 0) {
    oss << FormatPromptRole(PromptApiRole::PROMPT_API_ROLE_ASSISTANT);
  }
  return oss.str();
}

std::string ToString(const google::protobuf::MessageLite& request_metadata) {
  if (request_metadata.GetTypeName() ==
      "optimization_guide.proto.PromptApiRequest") {
    return ToString(*static_cast<const PromptApiRequest*>(&request_metadata));
  }
  if (request_metadata.GetTypeName() ==
      "optimization_guide.proto.StringValue") {
    return static_cast<const optimization_guide::proto::StringValue*>(
               &request_metadata)
        ->value();
  }
  return "unexpected type";
}

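// Returns the feature metadata for the given flag pair, built lazily and
// cached so that `ReturnRef` can hand out a stable reference. The version
// field selects between the PromptApiRequest proto and the legacy string
// request format.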
const optimization_guide::proto::Any& GetPromptApiMetadata(
    bool use_prompt_api_proto,
    bool is_streaming_chunk_by_chunk) {
  static base::NoDestructor<
      std::map<std::pair<bool, bool>, optimization_guide::proto::Any>>
      metadata_map;
  auto key = std::make_pair(use_prompt_api_proto, is_streaming_chunk_by_chunk);

  if (metadata_map->find(key) == metadata_map->end()) {
    metadata_map->emplace(key, [use_prompt_api_proto,
                                is_streaming_chunk_by_chunk]() {
      optimization_guide::proto::PromptApiMetadata metadata;
      metadata.set_version(
          use_prompt_api_proto ? AILanguageModel::kMinVersionUsingProto : 0);
      metadata.set_is_streaming_chunk_by_chunk(is_streaming_chunk_by_chunk);
      return optimization_guide::AnyWrapProto(metadata);
    }());
  }

  return metadata_map->at(key);
}

}  // namespace

class AILanguageModelTest : public AITestUtils::AITestBase,
                            public testing::WithParamInterface<
                                /*is_model_streaming_chunk_by_chunk=*/bool> {
 public:
  struct Options {
    blink::mojom::AILanguageModelSamplingParamsPtr sampling_params = nullptr;
    std::optional<std::string> system_prompt = std::nullopt;
    std::vector<blink::mojom::AILanguageModelInitialPromptPtr> initial_prompts;
    std::string prompt_input = kTestPrompt;
    std::string expected_context;
    std::string expected_cloned_context =
        base::StrCat({kExpectedFormattedTestPrompt, kTestResponse, "\n"});
    std::string expected_prompt = kExpectedFormattedTestPrompt;
    bool use_prompt_api_proto = false;
    bool should_overflow_context = false;
  };

  void SetUp() override {
    AITestUtils::AITestBase::SetUp();
    scoped_feature_list_.InitWithFeaturesAndParameters(
        {base::test::FeatureRefAndParams(
            features::kAILanguageModelOverrideConfiguration,
            {{"max_top_k", base::NumberToString(kOverrideMaxTopK)}})},
        {});
  }

 protected:
  bool IsModelStreamingChunkByChunk() { return GetParam(); }

  // Helper that creates an `AILanguageModel` and executes the prompt.
  void RunPromptTest(Options options) {
    blink::mojom::AILanguageModelSamplingParamsPtr sampling_params_copy;
    if (options.sampling_params) {
      sampling_params_copy = options.sampling_params->Clone();
    }

    // Set up the mock service.
    SetupMockOptimizationGuideKeyedService();
    // When sampling params are not specified, `StartSession()` will run three
    // times:
    // 1. When getting the default sampling params.
    // 2. When creating the session.
    // 3. When cloning the session.
    // Otherwise, it will only run twice, since the first call is unnecessary.
    auto& expectation =
        EXPECT_CALL(*mock_optimization_guide_keyed_service_, StartSession(_, _))
            .Times(sampling_params_copy ? 2 : 3);
    if (!sampling_params_copy) {
      expectation.WillOnce(
          [&](optimization_guide::ModelBasedCapabilityKey feature,
              const std::optional<optimization_guide::SessionConfigParams>&
                  config_params) {
            auto session = std::make_unique<
                testing::NiceMock<optimization_guide::MockSession>>();
            SetUpMockSession(*session, options.use_prompt_api_proto,
                             IsModelStreamingChunkByChunk());
            ON_CALL(*session, GetSamplingParams()).WillByDefault([]() {
              return optimization_guide::SamplingParams{
                  .top_k = kDefaultTopK, .temperature = kDefaultTemperature};
            });
            return session;
          });
    }
    expectation
        .WillOnce(
            [&](optimization_guide::ModelBasedCapabilityKey feature,
                const std::optional<optimization_guide::SessionConfigParams>&
                    config_params) {
              auto session = std::make_unique<
                  testing::NiceMock<optimization_guide::MockSession>>();
              if (sampling_params_copy) {
                EXPECT_EQ(
                    config_params->sampling_params->top_k,
                    std::min(kOverrideMaxTopK, sampling_params_copy->top_k));
                EXPECT_EQ(config_params->sampling_params->temperature,
                          sampling_params_copy->temperature);
              }

              SetUpMockSession(*session, options.use_prompt_api_proto,
                               IsModelStreamingChunkByChunk());

              ON_CALL(*session, GetContextSizeInTokens(_, _))
                  .WillByDefault(
                      [&](const google::protobuf::MessageLite& request_metadata,
                          optimization_guide::
                              OptimizationGuideModelSizeInTokenCallback
                                  callback) {
                        // Report a size just over the limit when the test
                        // wants to trigger the overflow path.
                        uint32_t size =
                            options.should_overflow_context
                                ? AITestUtils::GetFakeTokenLimits()
                                          .max_context_tokens +
                                      1
                                : 1;
                        std::move(callback).Run(size);
                      });
              ON_CALL(*session, AddContext(_))
                  .WillByDefault(
                      [&](const google::protobuf::MessageLite&
                              request_metadata) {
                        EXPECT_THAT(ToString(request_metadata),
                                    options.expected_context);
                      });
              EXPECT_CALL(*session, ExecuteModel(_, _))
                  .WillOnce(
                      [&](const google::protobuf::MessageLite& request_metadata,
                          optimization_guide::
                              OptimizationGuideModelExecutionResultStreamingCallback
                                  callback) {
                        EXPECT_THAT(ToString(request_metadata),
                                    options.expected_prompt);
                        callback.Run(CreateExecutionResult(
                            kTestResponse, /*is_complete=*/true));
                      });
              return session;
            })
        .WillOnce(
            [&](optimization_guide::ModelBasedCapabilityKey feature,
                const std::optional<optimization_guide::SessionConfigParams>&
                    config_params) {
              auto session = std::make_unique<
                  testing::NiceMock<optimization_guide::MockSession>>();

              SetUpMockSession(*session, options.use_prompt_api_proto,
                               IsModelStreamingChunkByChunk());

              ON_CALL(*session, AddContext(_))
                  .WillByDefault(
                      [&](const google::protobuf::MessageLite&
                              request_metadata) {
                        EXPECT_THAT(ToString(request_metadata),
                                    options.expected_cloned_context);
                      });
              EXPECT_CALL(*session, ExecuteModel(_, _))
                  .WillOnce(
                      [&](const google::protobuf::MessageLite& request_metadata,
                          optimization_guide::
                              OptimizationGuideModelExecutionResultStreamingCallback
                                  callback) {
                        EXPECT_THAT(ToString(request_metadata),
                                    options.expected_prompt);
                        callback.Run(CreateExecutionResult(
                            kTestResponse, /*is_complete=*/true));
                      });
              return session;
            });

    // Test session creation.
    mojo::Remote<blink::mojom::AILanguageModel> mock_session;
    AITestUtils::MockCreateLanguageModelClient
        mock_create_language_model_client;
    base::RunLoop creation_run_loop;
    bool is_initial_prompts_or_system_prompt_set =
        !options.initial_prompts.empty() ||
        (options.system_prompt.has_value() && !options.system_prompt->empty());
    EXPECT_CALL(mock_create_language_model_client, OnResult(_, _))
        .WillOnce([&](mojo::PendingRemote<blink::mojom::AILanguageModel>
                          language_model,
                      blink::mojom::AILanguageModelInfoPtr info) {
          EXPECT_TRUE(language_model);
          EXPECT_EQ(info->max_tokens,
                    AITestUtils::GetFakeTokenLimits().max_context_tokens);
          if (is_initial_prompts_or_system_prompt_set) {
            EXPECT_GT(info->current_tokens, 0ul);
          } else {
            EXPECT_EQ(info->current_tokens, 0ul);
          }
          mock_session = mojo::Remote<blink::mojom::AILanguageModel>(
              std::move(language_model));
          creation_run_loop.Quit();
        });

    mojo::Remote<blink::mojom::AIManager> mock_remote = GetAIManagerRemote();

    EXPECT_EQ(GetAIManagerDownloadProgressObserversSize(), 0u);
    AITestUtils::MockModelDownloadProgressMonitor mock_monitor;
    base::RunLoop download_progress_run_loop;
    EXPECT_CALL(mock_monitor, OnDownloadProgressUpdate(_, _))
        .WillOnce(testing::Invoke(
            [&](uint64_t downloaded_bytes, uint64_t total_bytes) {
              EXPECT_EQ(downloaded_bytes, kTestModelDownloadSize);
              EXPECT_EQ(total_bytes, kTestModelDownloadSize);
              download_progress_run_loop.Quit();
            }));

    mock_remote->AddModelDownloadProgressObserver(
        mock_monitor.BindNewPipeAndPassRemote());
    ASSERT_TRUE(base::test::RunUntil(
        [this] { return GetAIManagerDownloadProgressObserversSize() == 1u; }));

    MockDownloadProgressUpdate(kTestModelDownloadSize, kTestModelDownloadSize);
    download_progress_run_loop.Run();

    mock_remote->CreateLanguageModel(
        mock_create_language_model_client.BindNewPipeAndPassRemote(),
        blink::mojom::AILanguageModelCreateOptions::New(
            std::move(options.sampling_params), options.system_prompt,
            std::move(options.initial_prompts)));
    creation_run_loop.Run();

    TestPromptCall(mock_session, options.prompt_input,
                   options.should_overflow_context);

    // Test session cloning.
    mojo::Remote<blink::mojom::AILanguageModel> mock_cloned_session;
    AITestUtils::MockCreateLanguageModelClient mock_clone_language_model_client;
    base::RunLoop clone_run_loop;
    EXPECT_CALL(mock_clone_language_model_client, OnResult(_, _))
        .WillOnce(testing::Invoke(
            [&](mojo::PendingRemote<blink::mojom::AILanguageModel>
                    language_model,
                blink::mojom::AILanguageModelInfoPtr info) {
              EXPECT_TRUE(language_model);
              mock_cloned_session = mojo::Remote<blink::mojom::AILanguageModel>(
                  std::move(language_model));
              clone_run_loop.Quit();
            }));

    mock_session->Fork(
        mock_clone_language_model_client.BindNewPipeAndPassRemote());
    clone_run_loop.Run();

    TestPromptCall(mock_cloned_session, options.prompt_input,
                   /*should_overflow_context=*/false);
  }

 private:
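  // Builds a streaming execution result that wraps `output` in a
  // `StringValue` proto and marks it as provided by the on-device model.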
  optimization_guide::OptimizationGuideModelStreamingExecutionResult
  CreateExecutionResult(const std::string& output, bool is_complete) {
    optimization_guide::proto::StringValue response;
    response.set_value(output);
    return optimization_guide::OptimizationGuideModelStreamingExecutionResult(
        optimization_guide::StreamingResponse{
            .response = optimization_guide::AnyWrapProto(response),
            .is_complete = is_complete,
        },
        /*provided_by_on_device=*/true);
  }

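  // Installs default behaviors on `session`: fake token limits, prompt API
  // metadata for the given flags, default sampling params, and size-in-token
  // callbacks that approximate token counts by string length.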
  void SetUpMockSession(
      testing::NiceMock<optimization_guide::MockSession>& session,
      bool use_prompt_api_proto,
      bool is_streaming_chunk_by_chunk) {
    ON_CALL(session, GetTokenLimits())
        .WillByDefault(AITestUtils::GetFakeTokenLimits);

    ON_CALL(session, GetOnDeviceFeatureMetadata())
        .WillByDefault(ReturnRef(GetPromptApiMetadata(
            use_prompt_api_proto, is_streaming_chunk_by_chunk)));
    ON_CALL(session, GetSamplingParams()).WillByDefault([]() {
      // These values are never consumed, so just mock them with the defaults.
      return optimization_guide::SamplingParams{
          /*top_k=*/kDefaultTopK,
          /*temperature=*/kDefaultTemperature};
    });
    ON_CALL(session, GetSizeInTokens(_, _))
        .WillByDefault(
            [](const std::string& text,
               optimization_guide::OptimizationGuideModelSizeInTokenCallback
                   callback) { std::move(callback).Run(text.size()); });
    ON_CALL(session, GetExecutionInputSizeInTokens(_, _))
        .WillByDefault(
            [](const google::protobuf::MessageLite& request_metadata,
               optimization_guide::OptimizationGuideModelSizeInTokenCallback
                   callback) {
              std::move(callback).Run(ToString(request_metadata).size());
            });
    ON_CALL(session, GetContextSizeInTokens(_, _))
        .WillByDefault(
            [](const google::protobuf::MessageLite& request_metadata,
               optimization_guide::OptimizationGuideModelSizeInTokenCallback
                   callback) {
              std::move(callback).Run(ToString(request_metadata).size());
            });
  }

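  // Sends `prompt` through `mock_session` and verifies that the responder
  // streams back `kTestResponse` followed by a completion signal.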
  void TestPromptCall(
      mojo::Remote<blink::mojom::AILanguageModel>& mock_session,
      const std::string& prompt,
      bool should_overflow_context) {
    AITestUtils::MockModelStreamingResponder mock_responder;

    base::RunLoop responder_run_loop;
    EXPECT_CALL(mock_responder, OnStreaming(_))
        .WillOnce(testing::Invoke([&](const std::string& text) {
          EXPECT_THAT(text, kTestResponse);
        }));

    EXPECT_CALL(mock_responder, OnCompletion(_))
        .WillOnce(testing::Invoke(
            [&](blink::mojom::ModelExecutionContextInfoPtr context_info) {
              responder_run_loop.Quit();
            }));

    mock_session->Prompt(prompt, mock_responder.BindNewPipeAndPassRemote());
    responder_run_loop.Run();
  }

  base::test::ScopedFeatureList scoped_feature_list_;
};

INSTANTIATE_TEST_SUITE_P(All,
                         AILanguageModelTest,
                         testing::Bool(),
                         [](const testing::TestParamInfo<bool>& info) {
                           return info.param
                                      ? "IsModelStreamingChunkByChunk"
                                      : "IsModelStreamingWithCurrentResponse";
                         });

TEST_P(AILanguageModelTest, PromptDefaultSession) {
  RunPromptTest(AILanguageModelTest::Options{
      .prompt_input = kTestPrompt,
      .expected_prompt = kExpectedFormattedTestPrompt,
  });
}

TEST_P(AILanguageModelTest, PromptSessionWithSamplingParams) {
  RunPromptTest(AILanguageModelTest::Options{
      .sampling_params = blink::mojom::AILanguageModelSamplingParams::New(
          /*top_k=*/kOverrideMaxTopK - 1, /*temperature=*/0.6),
      .prompt_input = kTestPrompt,
      .expected_prompt = kExpectedFormattedTestPrompt,
  });
}

TEST_P(AILanguageModelTest, PromptSessionWithSamplingParams_ExceedMaxTopK) {
  RunPromptTest(AILanguageModelTest::Options{
      .sampling_params = blink::mojom::AILanguageModelSamplingParams::New(
          /*top_k=*/kOverrideMaxTopK + 1, /*temperature=*/0.6),
      .prompt_input = kTestPrompt,
      .expected_prompt = kExpectedFormattedTestPrompt,
  });
}

TEST_P(AILanguageModelTest, PromptSessionWithSystemPrompt) {
  RunPromptTest(AILanguageModelTest::Options{
      .system_prompt = kTestSystemPrompts,
      .prompt_input = kTestPrompt,
      .expected_context = kExpectedFormattedSystemPrompts,
      .expected_cloned_context =
          base::StrCat({kExpectedFormattedSystemPrompts,
                        kExpectedFormattedTestPrompt, kTestResponse, "\n"}),
      .expected_prompt = kExpectedFormattedTestPrompt,
  });
}

TEST_P(AILanguageModelTest, PromptSessionWithInitialPrompts) {
  RunPromptTest(AILanguageModelTest::Options{
      .initial_prompts = GetTestInitialPrompts(),
      .prompt_input = kTestPrompt,
      .expected_context = kExpectedFormattedInitialPrompts,
      .expected_cloned_context =
          base::StrCat({kExpectedFormattedInitialPrompts,
                        kExpectedFormattedTestPrompt, kTestResponse, "\n"}),
      .expected_prompt = kExpectedFormattedTestPrompt,
  });
}

TEST_P(AILanguageModelTest, PromptSessionWithSystemPromptAndInitialPrompts) {
  RunPromptTest(AILanguageModelTest::Options{
      .system_prompt = kTestSystemPrompts,
      .initial_prompts = GetTestInitialPrompts(),
      .prompt_input = kTestPrompt,
      .expected_context = kExpectedFormattedSystemPromptAndInitialPrompts,
      .expected_cloned_context = base::StrCat(
          {kExpectedFormattedSystemPrompts, kExpectedFormattedInitialPrompts,
           kExpectedFormattedTestPrompt, kTestResponse, "\n"}),
      .expected_prompt = kExpectedFormattedTestPrompt,
  });
}

TEST_P(AILanguageModelTest, PromptSessionWithPromptApiRequests) {
  RunPromptTest(AILanguageModelTest::Options{
      .system_prompt = "Test system prompt",
      .initial_prompts = GetTestInitialPrompts(),
      .prompt_input = "Test prompt",
      .expected_context = ("S: Test system prompt\n"
                           "U: How are you?\n"
                           "M: I'm fine, thank you, and you?\n"
                           "U: I'm fine too.\n"),
      .expected_cloned_context = ("S: Test system prompt\n"
                                  "U: How are you?\n"
                                  "M: I'm fine, thank you, and you?\n"
                                  "U: I'm fine too.\n"
                                  "U: Test prompt\n"
                                  "M: Test response\n"),
      .expected_prompt = "U: Test prompt\nM: ",
      .use_prompt_api_proto = true,
  });
}

TEST_P(AILanguageModelTest, PromptSessionWithContextOverflow) {
  RunPromptTest({.prompt_input = kTestPrompt,
                 .expected_prompt = kExpectedFormattedTestPrompt,
                 .should_overflow_context = true});
}

// Tests `AILanguageModel::Context` creation without initial prompts.
TEST(AILanguageModelContextCreationTest, CreateContext_WithoutInitialPrompts) {
  AILanguageModel::Context context(kTestMaxContextToken, {},
                                   /*use_prompt_api_request=*/false);
  EXPECT_FALSE(context.HasContextItem());
}

// Tests `AILanguageModel::Context` creation with valid initial prompts.
TEST(AILanguageModelContextCreationTest,
     CreateContext_WithInitialPrompts_Normal) {
  AILanguageModel::Context context(
      kTestMaxContextToken,
      SimpleContextItem("initial prompts\n", kTestInitialPromptsToken),
      /*use_prompt_api_request=*/false);
  EXPECT_TRUE(context.HasContextItem());
}

// Tests `AILanguageModel::Context` creation with initial prompts that exceed
// the max token limit.
TEST(AILanguageModelContextCreationTest,
     CreateContext_WithInitialPrompts_Overflow) {
  EXPECT_DEATH_IF_SUPPORTED(
      AILanguageModel::Context context(
          kTestMaxContextToken,
          SimpleContextItem("long initial prompts\n",
                            kTestMaxContextToken + 1u),
          /*use_prompt_api_request=*/false),
      "");
}

// Tests the `AILanguageModel::Context` initialized with or without initial
// prompts.
class AILanguageModelContextTest : public testing::Test,
                                   public testing::WithParamInterface<
                                       /*is_init_with_initial_prompts=*/bool> {
 public:
  bool IsInitializedWithInitialPrompts() { return GetParam(); }

  uint32_t GetMaxContextToken() {
    return IsInitializedWithInitialPrompts()
               ? kTestMaxContextToken - kTestInitialPromptsToken
               : kTestMaxContextToken;
  }

  std::string GetInitialPromptsPrefix() {
    return IsInitializedWithInitialPrompts() ? "initial prompts\n" : "";
  }

  AILanguageModel::Context context_{
      kTestMaxContextToken,
      IsInitializedWithInitialPrompts()
          ? SimpleContextItem("initial prompts", kTestInitialPromptsToken)
          : AILanguageModel::Context::ContextItem(),
      /*use_prompt_api_request=*/false};
};

INSTANTIATE_TEST_SUITE_P(All,
                         AILanguageModelContextTest,
                         testing::Bool(),
                         [](const testing::TestParamInfo<bool>& info) {
                           return info.param ? "WithInitialPrompts"
                                             : "WithoutInitialPrompts";
                         });

// Tests `GetContextString()` and `HasContextItem()` when the context is empty.
TEST_P(AILanguageModelContextTest, TestContextOperation_Empty) {
  EXPECT_EQ(GetContextString(context_), GetInitialPromptsPrefix());

  if (IsInitializedWithInitialPrompts()) {
    EXPECT_TRUE(context_.HasContextItem());
  } else {
    EXPECT_FALSE(context_.HasContextItem());
  }
}

// Tests `GetContextString()` and `HasContextItem()` when some items are added
// to the context.
TEST_P(AILanguageModelContextTest, TestContextOperation_NonEmpty) {
  context_.AddContextItem(SimpleContextItem("test", 1u));
  EXPECT_EQ(GetContextString(context_), GetInitialPromptsPrefix() + "test\n");
  EXPECT_TRUE(context_.HasContextItem());

  context_.AddContextItem(SimpleContextItem(" test again", 2u));
  EXPECT_EQ(GetContextString(context_),
            GetInitialPromptsPrefix() + "test\n test again\n");
  EXPECT_TRUE(context_.HasContextItem());
}

// Tests `GetContextString()` and `HasContextItem()` when the items overflow.
TEST_P(AILanguageModelContextTest, TestContextOperation_Overflow) {
  context_.AddContextItem(SimpleContextItem("test", 1u));
  EXPECT_EQ(GetContextString(context_), GetInitialPromptsPrefix() + "test\n");
  EXPECT_TRUE(context_.HasContextItem());

  // Since the total number of tokens will exceed `kTestMaxContextToken`, the
  // old item will be evicted.
  context_.AddContextItem(
      SimpleContextItem("test long token", GetMaxContextToken()));
  EXPECT_EQ(GetContextString(context_),
            GetInitialPromptsPrefix() + "test long token\n");
  EXPECT_TRUE(context_.HasContextItem());
}

// Tests `GetContextString()` and `HasContextItem()` when the items overflow on
// the first insertion.
TEST_P(AILanguageModelContextTest, TestContextOperation_OverflowOnFirstItem) {
  context_.AddContextItem(
      SimpleContextItem("test very long token", GetMaxContextToken() + 1u));
  EXPECT_EQ(GetContextString(context_), GetInitialPromptsPrefix());
  if (IsInitializedWithInitialPrompts()) {
    EXPECT_TRUE(context_.HasContextItem());
  } else {
    EXPECT_FALSE(context_.HasContextItem());
  }
}