// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ai/ai_language_model.h"
#include <optional>
#include <vector>
#include "base/functional/bind.h"
#include "base/functional/callback_forward.h"
#include "base/functional/callback_helpers.h"
#include "base/no_destructor.h"
#include "base/notreached.h"
#include "base/strings/stringprintf.h"
#include "base/task/current_thread.h"
#include "base/test/mock_callback.h"
#include "base/test/scoped_feature_list.h"
#include "chrome/browser/ai/ai_test_utils.h"
#include "chrome/browser/ai/features.h"
#include "components/optimization_guide/core/mock_optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_proto_util.h"
#include "components/optimization_guide/proto/common_types.pb.h"
#include "components/optimization_guide/proto/features/prompt_api.pb.h"
#include "components/optimization_guide/proto/on_device_model_execution_config.pb.h"
#include "components/optimization_guide/proto/string_value.pb.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom-shared.h"
#include "third_party/blink/public/mojom/ai/model_download_progress_observer.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom-shared.h"
using testing::_;
using testing::ReturnRef;
using testing::Test;
using Role = blink::mojom::AILanguageModelInitialPromptRole;
namespace {
using optimization_guide::proto::PromptApiRequest;
using optimization_guide::proto::PromptApiRole;
constexpr uint32_t kTestMaxContextToken = 10u;
constexpr uint32_t kTestInitialPromptsToken = 5u;
constexpr uint32_t kTestDefaultTopK = 1u;
constexpr float kTestDefaultTemperature = 0.3;
constexpr uint32_t kTestMaxTopK = 5u;
constexpr float kTestMaxTemperature = 1.5;
constexpr uint64_t kTestModelDownloadSize = 572u;
static_assert(kTestDefaultTopK <= kTestMaxTopK);
static_assert(kTestDefaultTemperature <= kTestMaxTemperature);
const char kTestPrompt[] = "Test prompt";
const char kExpectedFormattedTestPrompt[] = "U: Test prompt\nM: ";
const char kTestSystemPrompts[] = "Test system prompt";
const char kExpectedFormattedSystemPrompts[] = "S: Test system prompt\n";
const char kTestResponse[] = "Test response";
const char kTestInitialPromptsUser1[] = "How are you?";
const char kTestInitialPromptsSystem1[] = "I'm fine, thank you, and you?";
const char kTestInitialPromptsUser2[] = "I'm fine too.";
const char kExpectedFormattedInitialPrompts[] =
("U: How are you?\n"
"M: I'm fine, thank you, and you?\n"
"U: I'm fine too.\n");
const char kExpectedFormattedSystemPromptAndInitialPrompts[] =
("S: Test system prompt\n"
"U: How are you?\n"
"M: I'm fine, thank you, and you?\n"
"U: I'm fine too.\n");
std::vector<blink::mojom::AILanguageModelInitialPromptPtr>
GetTestInitialPrompts() {
auto create_initial_prompt = [](Role role, const char* content) {
return blink::mojom::AILanguageModelInitialPrompt::New(role, content);
};
std::vector<blink::mojom::AILanguageModelInitialPromptPtr> initial_prompts{};
initial_prompts.push_back(
create_initial_prompt(Role::kUser, kTestInitialPromptsUser1));
initial_prompts.push_back(
create_initial_prompt(Role::kAssistant, kTestInitialPromptsSystem1));
initial_prompts.push_back(
create_initial_prompt(Role::kUser, kTestInitialPromptsUser2));
return initial_prompts;
}
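
// Creates a ContextItem containing a single system-role prompt with `text`
// and a token count of `size`.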
AILanguageModel::Context::ContextItem SimpleContextItem(std::string text,
uint32_t size) {
auto item = AILanguageModel::Context::ContextItem();
item.tokens = size;
auto* prompt = item.prompts.Add();
prompt->set_role(PromptApiRole::PROMPT_API_ROLE_SYSTEM);
prompt->set_content(text);
return item;
}
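
// Maps a prompt role to the short prefix ("S: ", "U: ", "M: ") used in the
// expected request strings.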
const char* FormatPromptRole(PromptApiRole role) {
switch (role) {
case PromptApiRole::PROMPT_API_ROLE_SYSTEM:
return "S: ";
case PromptApiRole::PROMPT_API_ROLE_USER:
return "U: ";
case PromptApiRole::PROMPT_API_ROLE_ASSISTANT:
return "M: ";
default:
NOTREACHED();
}
}
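
// Flattens a PromptApiRequest into the "S:/U:/M:" text form used by the
// expectation strings. A trailing assistant prefix is appended when the
// request carries current prompts.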
std::string ToString(const PromptApiRequest& request) {
std::ostringstream oss;
for (const auto& prompt : request.initial_prompts()) {
oss << FormatPromptRole(prompt.role()) << prompt.content() << "\n";
}
for (const auto& prompt : request.prompt_history()) {
oss << FormatPromptRole(prompt.role()) << prompt.content() << "\n";
}
for (const auto& prompt : request.current_prompts()) {
oss << FormatPromptRole(prompt.role()) << prompt.content() << "\n";
}
if (request.current_prompts_size() > 0) {
oss << FormatPromptRole(PromptApiRole::PROMPT_API_ROLE_ASSISTANT);
}
return oss.str();
}
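
// Converts a request proto of either supported type (PromptApiRequest or
// StringValue) into a string for comparison.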
std::string ToString(const google::protobuf::MessageLite& request_metadata) {
if (request_metadata.GetTypeName() ==
"optimization_guide.proto.PromptApiRequest") {
return ToString(*static_cast<const PromptApiRequest*>(&request_metadata));
}
if (request_metadata.GetTypeName() ==
"optimization_guide.proto.StringValue") {
return static_cast<const optimization_guide::proto::StringValue*>(
&request_metadata)
->value();
}
return "unexpected type";
}
std::string GetContextString(AILanguageModel::Context& ctx) {
return ToString(*ctx.MakeRequest());
}
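
// Returns a cached Any-wrapped PromptApiMetadata for the given flag
// combination. The static map provides the stable reference required by the
// `ReturnRef()` default action on `GetOnDeviceFeatureMetadata()`.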
const optimization_guide::proto::Any& GetPromptApiMetadata(
bool use_prompt_api_proto,
bool is_streaming_chunk_by_chunk) {
static base::NoDestructor<
std::map<std::pair<bool, bool>, optimization_guide::proto::Any>>
metadata_map;
auto key = std::make_pair(use_prompt_api_proto, is_streaming_chunk_by_chunk);
if (metadata_map->find(key) == metadata_map->end()) {
metadata_map->emplace(key, [use_prompt_api_proto,
is_streaming_chunk_by_chunk]() {
optimization_guide::proto::PromptApiMetadata metadata;
metadata.set_version(
use_prompt_api_proto ? AILanguageModel::kMinVersionUsingProto : 0);
metadata.set_is_streaming_chunk_by_chunk(is_streaming_chunk_by_chunk);
return optimization_guide::AnyWrapProto(metadata);
}());
}
return metadata_map->at(key);
}
} // namespace
class AILanguageModelTest : public AITestUtils::AITestBase,
public testing::WithParamInterface<testing::tuple<
/*is_model_streaming_chunk_by_chunk=*/bool,
/*is_api_streaming_chunk_by_chunk=*/bool>> {
public:
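// Per-test configuration consumed by `RunPromptTest()`.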
struct Options {
blink::mojom::AILanguageModelSamplingParamsPtr sampling_params = nullptr;
std::optional<std::string> system_prompt = std::nullopt;
std::vector<blink::mojom::AILanguageModelInitialPromptPtr> initial_prompts;
std::string prompt_input = kTestPrompt;
std::string expected_context = "";
std::string expected_cloned_context =
base::StrCat({kExpectedFormattedTestPrompt, kTestResponse, "\n"});
std::string expected_prompt = kExpectedFormattedTestPrompt;
bool use_prompt_api_proto = false;
bool should_overflow_context = false;
bool should_use_supported_language = true;
};
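
// Enables or disables `kAILanguageModelForceStreamingFullResponse` depending
// on the API streaming mode under test.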
void SetUp() override {
std::vector<base::test::FeatureRefAndParams> enabled_features{};
std::vector<base::test::FeatureRef> disabled_features{};
if (IsAPIStreamingChunkByChunk()) {
disabled_features.push_back(
features::kAILanguageModelForceStreamingFullResponse);
} else {
enabled_features.push_back(base::test::FeatureRefAndParams(
features::kAILanguageModelForceStreamingFullResponse, {}));
}
AITestUtils::AITestBase::SetUp();
scoped_feature_list_.InitWithFeaturesAndParameters(enabled_features,
disabled_features);
}
protected:
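// Mocks the sampling params config and the feature metadata (maximum top-k
// and temperature) returned by the optimization guide service.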
void SetupMockOptimizationGuideKeyedService() override {
AITestUtils::AITestBase::SetupMockOptimizationGuideKeyedService();
ON_CALL(*mock_optimization_guide_keyed_service_, GetSamplingParamsConfig(_))
.WillByDefault([](optimization_guide::ModelBasedCapabilityKey feature) {
return optimization_guide::SamplingParamsConfig{
.default_top_k = kTestDefaultTopK,
.default_temperature = kTestDefaultTemperature};
});
ON_CALL(*mock_optimization_guide_keyed_service_, GetFeatureMetadata(_))
.WillByDefault([](optimization_guide::ModelBasedCapabilityKey feature) {
optimization_guide::proto::SamplingParams sampling_params;
sampling_params.set_top_k(kTestMaxTopK);
sampling_params.set_temperature(kTestMaxTemperature);
optimization_guide::proto::PromptApiMetadata metadata;
*metadata.mutable_max_sampling_params() = sampling_params;
optimization_guide::proto::Any any;
any.set_value(metadata.SerializeAsString());
any.set_type_url("type.googleapis.com/" + metadata.GetTypeName());
return any;
});
}
bool IsModelStreamingChunkByChunk() { return std::get<0>(GetParam()); }
bool IsAPIStreamingChunkByChunk() { return std::get<1>(GetParam()); }
// Helper that creates an `AILanguageModel`, executes the prompt, and then
// clones the session and repeats the prompt.
void RunPromptTest(Options options) {
blink::mojom::AILanguageModelSamplingParamsPtr sampling_params_copy;
if (options.sampling_params) {
sampling_params_copy = options.sampling_params->Clone();
}
// Set up mock service.
SetupMockOptimizationGuideKeyedService();
if (options.should_use_supported_language) {
// `StartSession()` will run twice when creating and cloning the session.
EXPECT_CALL(*mock_optimization_guide_keyed_service_, StartSession(_, _))
.Times(2)
.WillOnce([&](optimization_guide::ModelBasedCapabilityKey feature,
const std::optional<
optimization_guide::SessionConfigParams>&
config_params) {
auto session = std::make_unique<
testing::NiceMock<optimization_guide::MockSession>>();
if (sampling_params_copy) {
EXPECT_EQ(config_params->sampling_params->top_k,
std::min(kTestMaxTopK, sampling_params_copy->top_k));
EXPECT_EQ(config_params->sampling_params->temperature,
std::min(kTestMaxTemperature,
sampling_params_copy->temperature));
}
SetUpMockSession(*session, options.use_prompt_api_proto,
IsModelStreamingChunkByChunk());
ON_CALL(*session, GetContextSizeInTokens(_, _))
.WillByDefault(
[&](const google::protobuf::MessageLite& request_metadata,
optimization_guide::
OptimizationGuideModelSizeInTokenCallback
callback) {
std::move(callback).Run(
options.should_overflow_context
? AITestUtils::GetFakeTokenLimits()
.max_context_tokens +
1
: 1);
});
ON_CALL(*session, AddContext(_))
.WillByDefault(
[&](const google::protobuf::MessageLite& request_metadata) {
EXPECT_THAT(ToString(request_metadata),
options.expected_context);
});
EXPECT_CALL(*session, ExecuteModel(_, _))
.WillOnce(
[&](const google::protobuf::MessageLite& request_metadata,
optimization_guide::
OptimizationGuideModelExecutionResultStreamingCallback
callback) {
EXPECT_THAT(ToString(request_metadata),
options.expected_prompt);
StreamResponse(callback);
});
return session;
})
.WillOnce([&](optimization_guide::ModelBasedCapabilityKey feature,
const std::optional<
optimization_guide::SessionConfigParams>&
config_params) {
auto session = std::make_unique<
testing::NiceMock<optimization_guide::MockSession>>();
SetUpMockSession(*session, options.use_prompt_api_proto,
IsModelStreamingChunkByChunk());
ON_CALL(*session, AddContext(_))
.WillByDefault(
[&](const google::protobuf::MessageLite& request_metadata) {
EXPECT_THAT(ToString(request_metadata),
options.expected_cloned_context);
});
EXPECT_CALL(*session, ExecuteModel(_, _))
.WillOnce(
[&](const google::protobuf::MessageLite& request_metadata,
optimization_guide::
OptimizationGuideModelExecutionResultStreamingCallback
callback) {
EXPECT_THAT(ToString(request_metadata),
options.expected_prompt);
StreamResponse(callback);
});
return session;
});
}
// Test session creation.
mojo::Remote<blink::mojom::AILanguageModel> mock_session;
AITestUtils::MockCreateLanguageModelClient
mock_create_language_model_client;
base::RunLoop creation_run_loop;
bool is_initial_prompts_or_system_prompt_set =
options.initial_prompts.size() > 0 ||
(options.system_prompt.has_value() &&
options.system_prompt->size() > 0);
if (options.should_use_supported_language) {
EXPECT_CALL(mock_create_language_model_client, OnResult(_, _))
.WillOnce([&](mojo::PendingRemote<blink::mojom::AILanguageModel>
language_model,
blink::mojom::AILanguageModelInstanceInfoPtr info) {
EXPECT_TRUE(language_model);
EXPECT_EQ(info->max_tokens,
AITestUtils::GetFakeTokenLimits().max_context_tokens);
if (is_initial_prompts_or_system_prompt_set) {
EXPECT_GT(info->current_tokens, 0ul);
} else {
EXPECT_EQ(info->current_tokens, 0ul);
}
mock_session = mojo::Remote<blink::mojom::AILanguageModel>(
std::move(language_model));
creation_run_loop.Quit();
});
} else {
EXPECT_CALL(mock_create_language_model_client, OnError(_))
.WillOnce([&](blink::mojom::AIManagerCreateClientError error) {
EXPECT_EQ(
error,
blink::mojom::AIManagerCreateClientError::kUnsupportedLanguage);
creation_run_loop.Quit();
});
}
mojo::Remote<blink::mojom::AIManager> mock_remote = GetAIManagerRemote();
if (options.should_use_supported_language) {
EXPECT_EQ(GetAIManagerDownloadProgressObserversSize(), 0u);
AITestUtils::MockModelDownloadProgressMonitor mock_monitor;
base::RunLoop download_progress_run_loop;
EXPECT_CALL(mock_monitor, OnDownloadProgressUpdate(_, _))
.WillOnce(testing::Invoke(
[&](uint64_t downloaded_bytes, uint64_t total_bytes) {
EXPECT_EQ(downloaded_bytes, kTestModelDownloadSize);
EXPECT_EQ(total_bytes, kTestModelDownloadSize);
download_progress_run_loop.Quit();
}));
mock_remote->AddModelDownloadProgressObserver(
mock_monitor.BindNewPipeAndPassRemote());
ASSERT_TRUE(base::test::RunUntil([this] {
return GetAIManagerDownloadProgressObserversSize() == 1u;
}));
MockDownloadProgressUpdate(kTestModelDownloadSize,
kTestModelDownloadSize);
download_progress_run_loop.Run();
}
std::vector<blink::mojom::AILanguageCodePtr> expected_input_languages;
if (!options.should_use_supported_language) {
expected_input_languages.emplace_back(
blink::mojom::AILanguageCode::New("ja"));
}
mock_remote->CreateLanguageModel(
mock_create_language_model_client.BindNewPipeAndPassRemote(),
blink::mojom::AILanguageModelCreateOptions::New(
std::move(options.sampling_params), options.system_prompt,
std::move(options.initial_prompts),
std::move(expected_input_languages)));
creation_run_loop.Run();
if (!options.should_use_supported_language) {
return;
}
// Test the prompt call on the created session.
TestPromptCall(mock_session, options.prompt_input,
options.should_overflow_context);
// Test session cloning.
mojo::Remote<blink::mojom::AILanguageModel> mock_cloned_session;
AITestUtils::MockCreateLanguageModelClient mock_clone_language_model_client;
base::RunLoop clone_run_loop;
EXPECT_CALL(mock_clone_language_model_client, OnResult(_, _))
.WillOnce(testing::Invoke(
[&](mojo::PendingRemote<blink::mojom::AILanguageModel>
language_model,
blink::mojom::AILanguageModelInstanceInfoPtr info) {
EXPECT_TRUE(language_model);
mock_cloned_session = mojo::Remote<blink::mojom::AILanguageModel>(
std::move(language_model));
clone_run_loop.Quit();
}));
mock_session->Fork(
mock_clone_language_model_client.BindNewPipeAndPassRemote());
clone_run_loop.Run();
TestPromptCall(mock_cloned_session, options.prompt_input,
/*should_overflow_context=*/false);
}
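
// Creates a session whose token-size computation is deferred, runs `callback`
// (which prompts and destroys the session in a test-specific order), and
// verifies that the responder reports `kErrorSessionDestroyed` without the
// model being executed.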
void TestSessionDestroy(
base::OnceCallback<void(
mojo::Remote<blink::mojom::AILanguageModel> mock_session,
AITestUtils::MockModelStreamingResponder& mock_responder)> callback) {
SetupMockOptimizationGuideKeyedService();
base::OnceClosure size_in_token_callback;
EXPECT_CALL(*mock_optimization_guide_keyed_service_, StartSession(_, _))
.WillOnce(
[&](optimization_guide::ModelBasedCapabilityKey feature,
const std::optional<optimization_guide::SessionConfigParams>&
config_params) {
auto session = std::make_unique<
testing::NiceMock<optimization_guide::MockSession>>();
SetUpMockSession(*session, /*use_prompt_api_proto=*/true,
/*is_streaming_chunk_by_chunk=*/true);
ON_CALL(*session, GetExecutionInputSizeInTokens(_, _))
.WillByDefault(
[&](const google::protobuf::MessageLite& request_metadata,
optimization_guide::
OptimizationGuideModelSizeInTokenCallback
callback) {
size_in_token_callback =
base::BindOnce(std::move(callback),
ToString(request_metadata).size());
});
// The model should not be executed.
EXPECT_CALL(*session, ExecuteModel(_, _)).Times(0);
return session;
});
mojo::Remote<blink::mojom::AILanguageModel> mock_session =
CreateMockSession();
AITestUtils::MockModelStreamingResponder mock_responder;
base::RunLoop responder_run_loop;
EXPECT_CALL(mock_responder, OnError(_))
.WillOnce(testing::Invoke(
[&](blink::mojom::ModelStreamingResponseStatus status) {
EXPECT_EQ(status, blink::mojom::ModelStreamingResponseStatus::
kErrorSessionDestroyed);
responder_run_loop.Quit();
}));
std::move(callback).Run(std::move(mock_session), mock_responder);
// Run the deferred `size_in_token_callback` only after the testing callback
// above has destroyed the session.
if (size_in_token_callback) {
std::move(size_in_token_callback).Run();
}
responder_run_loop.Run();
}
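
// Builds a streaming execution result that wraps `output` in a StringValue
// proto with the given completion flag and token counts.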
optimization_guide::OptimizationGuideModelStreamingExecutionResult
CreateExecutionResult(const std::string& output,
bool is_complete,
uint32_t input_token_count,
uint32_t output_token_count) {
optimization_guide::proto::StringValue response;
response.set_value(output);
return optimization_guide::OptimizationGuideModelStreamingExecutionResult(
optimization_guide::StreamingResponse{
.response = optimization_guide::AnyWrapProto(response),
.is_complete = is_complete,
.input_token_count = input_token_count,
.output_token_count = output_token_count},
/*provided_by_on_device=*/true);
}
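
// Sends two prompts ("A" and "B") and verifies that the prompt history is
// only added back via `AddContext()` before the second prompt when it still
// fits within the context window.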
void TestSessionAddContext(bool should_overflow_context) {
SetupMockOptimizationGuideKeyedService();
// Use `max_context_tokens / 2 + 1` tokens so that the context overflows on
// the second prompt.
uint32_t mock_size_in_tokens =
should_overflow_context
? 1 + AITestUtils::GetFakeTokenLimits().max_context_tokens / 2
: 1;
EXPECT_CALL(*mock_optimization_guide_keyed_service_, StartSession(_, _))
.WillOnce([&](optimization_guide::ModelBasedCapabilityKey feature,
const std::optional<
optimization_guide::SessionConfigParams>&
config_params) {
auto session = std::make_unique<
testing::NiceMock<optimization_guide::MockSession>>();
SetUpMockSession(*session, /*use_prompt_api_proto=*/false,
IsModelStreamingChunkByChunk());
ON_CALL(*session, GetContextSizeInTokens(_, _))
.WillByDefault(
[&](const google::protobuf::MessageLite& request_metadata,
optimization_guide::
OptimizationGuideModelSizeInTokenCallback callback) {
std::move(callback).Run(mock_size_in_tokens);
});
ON_CALL(*session, GetExecutionInputSizeInTokens(_, _))
.WillByDefault(
[&](const google::protobuf::MessageLite& request_metadata,
optimization_guide::
OptimizationGuideModelSizeInTokenCallback callback) {
std::move(callback).Run(mock_size_in_tokens);
});
// If the context overflows, the previous prompt history should not
// be added to the context.
EXPECT_CALL(*session, AddContext(_))
.Times(should_overflow_context ? 0 : 1);
EXPECT_CALL(*session, ExecuteModel(_, _))
.Times(2)
.WillOnce(
[&](const google::protobuf::MessageLite& request_metadata,
optimization_guide::
OptimizationGuideModelExecutionResultStreamingCallback
callback) {
EXPECT_THAT(ToString(request_metadata), "U: A\nM: ");
callback.Run(CreateExecutionResult(
"OK", /*is_complete=*/true, /*input_token_count=*/1u,
/*output_token_count=*/mock_size_in_tokens));
})
.WillOnce(
[&](const google::protobuf::MessageLite& request_metadata,
optimization_guide::
OptimizationGuideModelExecutionResultStreamingCallback
callback) {
EXPECT_THAT(ToString(request_metadata), "U: B\nM: ");
callback.Run(CreateExecutionResult(
"OK", /*is_complete=*/true, /*input_token_count=*/1u,
/*output_token_count=*/mock_size_in_tokens));
});
return session;
});
mojo::Remote<blink::mojom::AILanguageModel> mock_session =
CreateMockSession();
AITestUtils::MockModelStreamingResponder mock_responder_1;
AITestUtils::MockModelStreamingResponder mock_responder_2;
base::RunLoop responder_run_loop_1;
base::RunLoop responder_run_loop_2;
EXPECT_CALL(mock_responder_1, OnStreaming(_, _))
.WillOnce(testing::Invoke(
[&](const std::string& text,
blink::mojom::ModelStreamingResponderAction action) {
EXPECT_THAT(text, "OK");
EXPECT_EQ(
IsAPIStreamingChunkByChunk()
? blink::mojom::ModelStreamingResponderAction::kAppend
: blink::mojom::ModelStreamingResponderAction::kReplace,
action);
}));
EXPECT_CALL(mock_responder_2, OnStreaming(_, _))
.WillOnce(testing::Invoke(
[&](const std::string& text,
blink::mojom::ModelStreamingResponderAction action) {
EXPECT_THAT(text, "OK");
EXPECT_EQ(
IsAPIStreamingChunkByChunk()
? blink::mojom::ModelStreamingResponderAction::kAppend
: blink::mojom::ModelStreamingResponderAction::kReplace,
action);
}));
EXPECT_CALL(mock_responder_2, OnContextOverflow())
.Times(should_overflow_context ? 1 : 0);
EXPECT_CALL(mock_responder_1, OnCompletion(_))
.WillOnce(testing::Invoke(
[&](blink::mojom::ModelExecutionContextInfoPtr context_info) {
responder_run_loop_1.Quit();
}));
EXPECT_CALL(mock_responder_2, OnCompletion(_))
.WillOnce(testing::Invoke(
[&](blink::mojom::ModelExecutionContextInfoPtr context_info) {
responder_run_loop_2.Quit();
}));
mock_session->Prompt("A", mock_responder_1.BindNewPipeAndPassRemote());
responder_run_loop_1.Run();
mock_session->Prompt("B", mock_responder_2.BindNewPipeAndPassRemote());
responder_run_loop_2.Run();
}
private:
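// Installs the default behaviors (token limits, feature metadata, sampling
// params and token-size callbacks) on a mock session.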
void SetUpMockSession(
testing::NiceMock<optimization_guide::MockSession>& session,
bool use_prompt_api_proto,
bool is_streaming_chunk_by_chunk) {
ON_CALL(session, GetTokenLimits())
.WillByDefault(AITestUtils::GetFakeTokenLimits);
ON_CALL(session, GetOnDeviceFeatureMetadata())
.WillByDefault(ReturnRef(GetPromptApiMetadata(
use_prompt_api_proto, is_streaming_chunk_by_chunk)));
ON_CALL(session, GetSamplingParams()).WillByDefault([]() {
// These values are not used by the tests, so just mock them with defaults.
return optimization_guide::SamplingParams{
/*top_k=*/kTestDefaultTopK,
/*temperature=*/kTestDefaultTemperature};
});
ON_CALL(session, GetSizeInTokens(_, _))
.WillByDefault(
[](const std::string& text,
optimization_guide::OptimizationGuideModelSizeInTokenCallback
callback) { std::move(callback).Run(text.size()); });
ON_CALL(session, GetExecutionInputSizeInTokens(_, _))
.WillByDefault(
[](const google::protobuf::MessageLite& request_metadata,
optimization_guide::OptimizationGuideModelSizeInTokenCallback
callback) {
std::move(callback).Run(ToString(request_metadata).size());
});
ON_CALL(session, GetContextSizeInTokens(_, _))
.WillByDefault(
[](const google::protobuf::MessageLite& request_metadata,
optimization_guide::OptimizationGuideModelSizeInTokenCallback
callback) {
std::move(callback).Run(ToString(request_metadata).size());
});
}
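
// Simulates the on-device model streaming `kTestResponse` in three chunks,
// either chunk by chunk or as the accumulated current response.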
void StreamResponse(
optimization_guide::OptimizationGuideModelExecutionResultStreamingCallback
callback) {
std::string responses[3];
std::string response = std::string(kTestResponse);
if (IsModelStreamingChunkByChunk()) {
responses[0] = response.substr(0, 1);
responses[1] = response.substr(1);
responses[2] = "";
} else {
responses[0] = response.substr(0, 1);
responses[1] = kTestResponse;
responses[2] = kTestResponse;
}
callback.Run(CreateExecutionResult(responses[0],
/*is_complete=*/false,
/*input_token_count=*/1u,
/*output_token_count=*/1u));
callback.Run(CreateExecutionResult(responses[1],
/*is_complete=*/false,
/*input_token_count=*/1u,
/*output_token_count=*/1u));
callback.Run(CreateExecutionResult(responses[2],
/*is_complete=*/true,
/*input_token_count=*/1u,
/*output_token_count=*/1u));
}
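
// Prompts `mock_session` and verifies the three streamed chunks and the
// completion callback for both API streaming modes.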
void TestPromptCall(mojo::Remote<blink::mojom::AILanguageModel>& mock_session,
std::string& prompt,
bool should_overflow_context) {
AITestUtils::MockModelStreamingResponder mock_responder;
base::RunLoop responder_run_loop;
std::string response = std::string(kTestResponse);
EXPECT_CALL(mock_responder, OnStreaming(_, _))
.Times(3)
.WillOnce(testing::Invoke(
[&](const std::string& text,
blink::mojom::ModelStreamingResponderAction action) {
EXPECT_THAT(text, response.substr(0, 1));
EXPECT_EQ(
IsAPIStreamingChunkByChunk()
? blink::mojom::ModelStreamingResponderAction::kAppend
: blink::mojom::ModelStreamingResponderAction::kReplace,
action);
}))
.WillOnce(testing::Invoke(
[&](const std::string& text,
blink::mojom::ModelStreamingResponderAction action) {
EXPECT_THAT(text, IsAPIStreamingChunkByChunk()
? response.substr(1)
: kTestResponse);
EXPECT_EQ(
IsAPIStreamingChunkByChunk()
? blink::mojom::ModelStreamingResponderAction::kAppend
: blink::mojom::ModelStreamingResponderAction::kReplace,
action);
}))
.WillOnce(testing::Invoke(
[&](const std::string& text,
blink::mojom::ModelStreamingResponderAction action) {
EXPECT_THAT(text,
IsAPIStreamingChunkByChunk() ? "" : kTestResponse);
EXPECT_EQ(
IsAPIStreamingChunkByChunk()
? blink::mojom::ModelStreamingResponderAction::kAppend
: blink::mojom::ModelStreamingResponderAction::kReplace,
action);
}));
EXPECT_CALL(mock_responder, OnCompletion(_))
.WillOnce(testing::Invoke(
[&](blink::mojom::ModelExecutionContextInfoPtr context_info) {
responder_run_loop.Quit();
}));
mock_session->Prompt(prompt, mock_responder.BindNewPipeAndPassRemote());
responder_run_loop.Run();
}
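
// Creates a language model session with default options and returns the
// bound remote.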
mojo::Remote<blink::mojom::AILanguageModel> CreateMockSession() {
mojo::Remote<blink::mojom::AILanguageModel> mock_session;
AITestUtils::MockCreateLanguageModelClient
mock_create_language_model_client;
base::RunLoop creation_run_loop;
EXPECT_CALL(mock_create_language_model_client, OnResult(_, _))
.WillOnce([&](mojo::PendingRemote<blink::mojom::AILanguageModel>
language_model,
blink::mojom::AILanguageModelInstanceInfoPtr info) {
EXPECT_TRUE(language_model);
mock_session = mojo::Remote<blink::mojom::AILanguageModel>(
std::move(language_model));
creation_run_loop.Quit();
});
mojo::Remote<blink::mojom::AIManager> mock_remote = GetAIManagerRemote();
mock_remote->CreateLanguageModel(
mock_create_language_model_client.BindNewPipeAndPassRemote(),
blink::mojom::AILanguageModelCreateOptions::New());
creation_run_loop.Run();
return mock_session;
}
base::test::ScopedFeatureList scoped_feature_list_;
};
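
// Instantiates the suite over all combinations of the model and API
// streaming modes.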
INSTANTIATE_TEST_SUITE_P(
All,
AILanguageModelTest,
testing::Combine(testing::Bool(), testing::Bool()),
[](const testing::TestParamInfo<testing::tuple<bool, bool>>& info) {
std::string description = "";
description += std::get<0>(info.param)
? "IsModelStreamingChunkByChunk"
: "IsModelStreamingWithCurrentResponse";
description += "_";
description += std::get<1>(info.param)
? "IsAPIStreamingChunkByChunk"
: "IsAPIStreamingWithCurrentResponse";
return description;
});
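
// Tests a basic prompt round trip on a session created with default options.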
TEST_P(AILanguageModelTest, PromptDefaultSession) {
RunPromptTest(AILanguageModelTest::Options{
.prompt_input = kTestPrompt,
.expected_prompt = kExpectedFormattedTestPrompt,
});
}
TEST_P(AILanguageModelTest, PromptSessionWithSamplingParams) {
RunPromptTest(AILanguageModelTest::Options{
.sampling_params = blink::mojom::AILanguageModelSamplingParams::New(
/*top_k=*/kTestMaxTopK - 1,
/*temperature=*/kTestMaxTemperature * 0.9),
.prompt_input = kTestPrompt,
.expected_prompt = kExpectedFormattedTestPrompt,
});
}
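
// Tests that a requested top-k above the configured maximum is clamped to
// `kTestMaxTopK`.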
TEST_P(AILanguageModelTest, PromptSessionWithSamplingParams_ExceedMaxTopK) {
RunPromptTest(AILanguageModelTest::Options{
.sampling_params = blink::mojom::AILanguageModelSamplingParams::New(
/*top_k=*/kTestMaxTopK + 1,
/*temperature=*/kTestMaxTemperature * 0.9),
.prompt_input = kTestPrompt,
.expected_prompt = kExpectedFormattedTestPrompt,
});
}
TEST_P(AILanguageModelTest,
PromptSessionWithSamplingParams_ExceedMaxTemperature) {
RunPromptTest(AILanguageModelTest::Options{
.sampling_params = blink::mojom::AILanguageModelSamplingParams::New(
/*top_k=*/kTestMaxTopK - 1,
/*temperature=*/kTestMaxTemperature + 0.1),
.prompt_input = kTestPrompt,
.expected_prompt = kExpectedFormattedTestPrompt,
});
}
TEST_P(AILanguageModelTest, PromptSessionWithSystemPrompt) {
RunPromptTest(AILanguageModelTest::Options{
.system_prompt = kTestSystemPrompts,
.prompt_input = kTestPrompt,
.expected_context = kExpectedFormattedSystemPrompts,
.expected_cloned_context =
base::StrCat({kExpectedFormattedSystemPrompts,
kExpectedFormattedTestPrompt, kTestResponse, "\n"}),
.expected_prompt = kExpectedFormattedTestPrompt,
});
}
TEST_P(AILanguageModelTest, PromptSessionWithInitialPrompts) {
RunPromptTest(AILanguageModelTest::Options{
.initial_prompts = GetTestInitialPrompts(),
.prompt_input = kTestPrompt,
.expected_context = kExpectedFormattedInitialPrompts,
.expected_cloned_context =
base::StrCat({kExpectedFormattedInitialPrompts,
kExpectedFormattedTestPrompt, kTestResponse, "\n"}),
.expected_prompt = kExpectedFormattedTestPrompt,
});
}
TEST_P(AILanguageModelTest, PromptSessionWithSystemPromptAndInitialPrompts) {
RunPromptTest(AILanguageModelTest::Options{
.system_prompt = kTestSystemPrompts,
.initial_prompts = GetTestInitialPrompts(),
.prompt_input = kTestPrompt,
.expected_context = kExpectedFormattedSystemPromptAndInitialPrompts,
.expected_cloned_context = base::StrCat(
{kExpectedFormattedSystemPrompts, kExpectedFormattedInitialPrompts,
kExpectedFormattedTestPrompt, kTestResponse, "\n"}),
.expected_prompt = kExpectedFormattedTestPrompt,
});
}
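
// Tests the prompt flow when the session is configured to use the
// PromptApiRequest proto instead of StringValue.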
TEST_P(AILanguageModelTest, PromptSessionWithPromptApiRequests) {
RunPromptTest(AILanguageModelTest::Options{
.system_prompt = "Test system prompt",
.initial_prompts = GetTestInitialPrompts(),
.prompt_input = "Test prompt",
.expected_context = ("S: Test system prompt\n"
"U: How are you?\n"
"M: I'm fine, thank you, and you?\n"
"U: I'm fine too.\n"),
.expected_cloned_context = ("S: Test system prompt\n"
"U: How are you?\n"
"M: I'm fine, thank you, and you?\n"
"U: I'm fine too.\n"
"U: Test prompt\n"
"M: Test response\n"),
.expected_prompt = "U: Test prompt\nM: ",
.use_prompt_api_proto = true,
});
}
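
// Tests the prompt flow when recording the prompt in the session context
// overflows the context window.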
TEST_P(AILanguageModelTest, PromptSessionWithContextOverflow) {
RunPromptTest({.prompt_input = kTestPrompt,
.expected_prompt = kExpectedFormattedTestPrompt,
.should_overflow_context = true});
}
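
// Tests that session creation fails with `kUnsupportedLanguage` when an
// unsupported expected input language ("ja") is requested.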
TEST_P(AILanguageModelTest, PromptSessionWithUnsupportedLanguage) {
RunPromptTest({.should_use_supported_language = false});
}
// Tests that sending `Prompt()` after destroying the session won't make a real
// call to the model.
TEST_P(AILanguageModelTest, PromptAfterDestroy) {
TestSessionDestroy(base::BindOnce(
[](mojo::Remote<blink::mojom::AILanguageModel> mock_session,
AITestUtils::MockModelStreamingResponder& mock_responder) {
mock_session->Destroy();
mock_session->Prompt(kTestPrompt,
mock_responder.BindNewPipeAndPassRemote());
}));
}
// Tests that sending `Prompt()` right before destroying the session won't make
// a real call to the model.
TEST_P(AILanguageModelTest, PromptBeforeDestroy) {
TestSessionDestroy(base::BindOnce(
[](mojo::Remote<blink::mojom::AILanguageModel> mock_session,
AITestUtils::MockModelStreamingResponder& mock_responder) {
mock_session->Prompt(kTestPrompt,
mock_responder.BindNewPipeAndPassRemote());
mock_session->Destroy();
}));
}
// Tests that the session will call `AddContext()` from the second prompt when
// there is no context overflow.
TEST_P(AILanguageModelTest, PromptWithHistoryWithoutContextOverflow) {
TestSessionAddContext(/*should_overflow_context=*/false);
}
// Tests that the session will not call `AddContext()` from the second prompt
// when there is context overflow.
TEST_P(AILanguageModelTest, PromptWithHistoryWithContextOverflow) {
TestSessionAddContext(/*should_overflow_context=*/true);
}
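
// Tests that `CanCreateLanguageModel()` reports the model as available when
// the expected input languages are supported.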
TEST_P(AILanguageModelTest, CanCreate_IsLanguagesSupported) {
SetupMockOptimizationGuideKeyedService();
EXPECT_CALL(*mock_optimization_guide_keyed_service_,
GetOnDeviceModelEligibility(_))
.WillRepeatedly(testing::Return(
optimization_guide::OnDeviceModelEligibilityReason::kSuccess));
auto options = blink::mojom::AILanguageModelAvailabilityOptions::New(
/*top_k=*/std::nullopt,
/*temperature=*/std::nullopt,
/*expected_input_languages=*/
std::vector<blink::mojom::AILanguageCodePtr>{
AITestUtils::ToMojoLanguageCodes({"en"})});
base::MockCallback<AIManager::CanCreateLanguageModelCallback> callback;
EXPECT_CALL(callback,
Run(blink::mojom::ModelAvailabilityCheckResult::kAvailable));
GetAIManagerInterface()->CanCreateLanguageModel(std::move(options),
callback.Get());
}
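
// Tests that `CanCreateLanguageModel()` reports
// `kUnavailableUnsupportedLanguage` when an expected input language is not
// supported.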
TEST_P(AILanguageModelTest, CanCreate_IsLanguagesNotSupported) {
SetupMockOptimizationGuideKeyedService();
EXPECT_CALL(*mock_optimization_guide_keyed_service_,
GetOnDeviceModelEligibility(_))
.WillRepeatedly(testing::Return(
optimization_guide::OnDeviceModelEligibilityReason::kSuccess));
auto options = blink::mojom::AILanguageModelAvailabilityOptions::New(
/*top_k=*/std::nullopt,
/*temperature=*/std::nullopt,
/*expected_input_languages=*/
std::vector<blink::mojom::AILanguageCodePtr>{
AITestUtils::ToMojoLanguageCodes({"ja"})});
base::MockCallback<AIManager::CanCreateLanguageModelCallback> callback;
EXPECT_CALL(callback, Run(blink::mojom::ModelAvailabilityCheckResult::
kUnavailableUnsupportedLanguage));
GetAIManagerInterface()->CanCreateLanguageModel(std::move(options),
callback.Get());
}
// Tests `AILanguageModel::Context` creation without initial prompts.
TEST(AILanguageModelContextCreationTest, CreateContext_WithoutInitialPrompts) {
AILanguageModel::Context context(kTestMaxContextToken, {});
EXPECT_FALSE(context.HasContextItem());
}
// Tests `AILanguageModel::Context` creation with valid initial prompts.
TEST(AILanguageModelContextCreationTest,
CreateContext_WithInitialPrompts_Normal) {
AILanguageModel::Context context(
kTestMaxContextToken,
SimpleContextItem("initial prompts\n", kTestInitialPromptsToken));
EXPECT_TRUE(context.HasContextItem());
}
// Tests `AILanguageModel::Context` creation with initial prompts that exceeds
// the max token limit.
TEST(AILanguageModelContextCreationTest,
CreateContext_WithInitialPrompts_Overflow) {
EXPECT_DEATH_IF_SUPPORTED(
AILanguageModel::Context context(
kTestMaxContextToken, SimpleContextItem("long initial prompts\n",
kTestMaxContextToken + 1u)),
"");
}
// Tests the `AILanguageModel::Context` that's initialized with/without any
// initial prompt.
class AILanguageModelContextTest : public testing::Test,
public testing::WithParamInterface<
/*is_init_with_initial_prompts=*/bool> {
public:
bool IsInitializedWithInitialPrompts() { return GetParam(); }
uint32_t GetMaxContextToken() {
return IsInitializedWithInitialPrompts()
? kTestMaxContextToken - kTestInitialPromptsToken
: kTestMaxContextToken;
}
std::string GetInitialPromptsPrefix() {
return IsInitializedWithInitialPrompts() ? "S: initial prompts\n" : "";
}
AILanguageModel::Context context_{
kTestMaxContextToken,
IsInitializedWithInitialPrompts()
? SimpleContextItem("initial prompts", kTestInitialPromptsToken)
: AILanguageModel::Context::ContextItem()};
};
INSTANTIATE_TEST_SUITE_P(All,
AILanguageModelContextTest,
testing::Bool(),
[](const testing::TestParamInfo<bool>& info) {
return info.param ? "WithInitialPrompts"
: "WithoutInitialPrompts";
});
// Tests `GetContextString()` and `HasContextItem()` when the context is empty.
TEST_P(AILanguageModelContextTest, TestContextOperation_Empty) {
EXPECT_EQ(GetContextString(context_), GetInitialPromptsPrefix());
if (IsInitializedWithInitialPrompts()) {
EXPECT_TRUE(context_.HasContextItem());
} else {
EXPECT_FALSE(context_.HasContextItem());
}
}
// Tests `GetContextString()` and `HasContextItem()` when some items are added
// to the context.
TEST_P(AILanguageModelContextTest, TestContextOperation_NonEmpty) {
EXPECT_EQ(context_.AddContextItem(SimpleContextItem("test", 1u)),
AILanguageModel::Context::SpaceReservationResult::kSufficientSpace);
EXPECT_EQ(GetContextString(context_),
GetInitialPromptsPrefix() + "S: test\n");
EXPECT_TRUE(context_.HasContextItem());
context_.AddContextItem(SimpleContextItem(" test again", 2u));
EXPECT_EQ(GetContextString(context_),
GetInitialPromptsPrefix() + "S: test\nS: test again\n");
EXPECT_TRUE(context_.HasContextItem());
}
// Tests `GetContextString()` and `HasContextItem()` when the items overflow.
TEST_P(AILanguageModelContextTest, TestContextOperation_Overflow) {
EXPECT_EQ(context_.AddContextItem(SimpleContextItem("test", 1u)),
AILanguageModel::Context::SpaceReservationResult::kSufficientSpace);
EXPECT_EQ(GetContextString(context_),
GetInitialPromptsPrefix() + "S: test\n");
EXPECT_TRUE(context_.HasContextItem());
// Since the total number of tokens will exceed `kTestMaxContextToken`, the
// old item will be evicted.
EXPECT_EQ(
context_.AddContextItem(
SimpleContextItem("test long token", GetMaxContextToken())),
AILanguageModel::Context::SpaceReservationResult::kSpaceMadeAvailable);
EXPECT_EQ(GetContextString(context_),
GetInitialPromptsPrefix() + "S: test long token\n");
EXPECT_TRUE(context_.HasContextItem());
}
// Tests `GetContextString()` and `HasContextItem()` when the items overflow on
// the first insertion.
TEST_P(AILanguageModelContextTest, TestContextOperation_OverflowOnFirstItem) {
EXPECT_EQ(
context_.AddContextItem(
SimpleContextItem("test very long token", GetMaxContextToken() + 1u)),
AILanguageModel::Context::SpaceReservationResult::kInsufficientSpace);
EXPECT_EQ(GetContextString(context_), GetInitialPromptsPrefix());
if (IsInitializedWithInitialPrompts()) {
EXPECT_TRUE(context_.HasContextItem());
} else {
EXPECT_FALSE(context_.HasContextItem());
}
}