// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ai/ai_assistant.h"
#include <optional>
#include "base/functional/callback_helpers.h"
#include "base/notreached.h"
#include "base/strings/stringprintf.h"
#include "base/test/bind.h"
#include "chrome/browser/ai/ai_test_utils.h"
#include "components/optimization_guide/core/mock_optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/proto/common_types.pb.h"
#include "components/optimization_guide/proto/features/prompt_api.pb.h"
#include "components/optimization_guide/proto/string_value.pb.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/blink/public/mojom/ai/ai_assistant.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom-shared.h"
using testing::_;
using testing::Test;
using Role = blink::mojom::AIAssistantInitialPromptRole;

namespace {

using optimization_guide::proto::PromptApiRequest;
using optimization_guide::proto::PromptApiRole;

const uint32_t kTestMaxContextToken = 10u;
const uint32_t kTestInitialPromptsToken = 5u;
const uint32_t kDefaultTopK = 1u;
const float kDefaultTemperature = 0.0f;
const char kTestPrompt[] = "Test prompt";
const char kExpectedFormattedTestPrompt[] = "User: Test prompt\nModel: ";
const char kTestSystemPrompts[] = "Test system prompt";
const char kExpectedFormattedSystemPrompts[] = "Test system prompt\n";
const char kTestResponse[] = "Test response";
const char kTestInitialPromptsUser1[] = "How are you?";
const char kTestInitialPromptsAssistant1[] = "I'm fine, thank you, and you?";
const char kTestInitialPromptsUser2[] = "I'm fine too.";
const char kExpectedFormattedInitialPrompts[] =
"User: How are you?\nModel: I'm fine, thank you, and you?\nUser: I'm fine "
"too.\n";
const char kExpectedFormattedSystemPromptAndInitialPrompts[] =
"Test system prompt\nUser: How are you?\nModel: I'm fine, thank you, and "
"you?\nUser: I'm fine too.\n";
std::vector<blink::mojom::AIAssistantInitialPromptPtr> GetTestInitialPrompts() {
auto create_initial_prompt = [](Role role, const char* content) {
return blink::mojom::AIAssistantInitialPrompt::New(role, content);
};
std::vector<blink::mojom::AIAssistantInitialPromptPtr> initial_prompts{};
initial_prompts.push_back(
create_initial_prompt(Role::kUser, kTestInitialPromptsUser1));
initial_prompts.push_back(
      create_initial_prompt(Role::kAssistant, kTestInitialPromptsAssistant1));
initial_prompts.push_back(
create_initial_prompt(Role::kUser, kTestInitialPromptsUser2));
return initial_prompts;
}
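
// Renders `ctx` as a flat string by casting the proto from `MakeRequest()` to
// a `StringValue`; this assumes the context was built with
// `use_prompt_api_request` set to false.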
std::string GetContextString(AIAssistant::Context& ctx) {
auto msg = ctx.MakeRequest();
auto* v = static_cast<optimization_guide::proto::StringValue*>(msg.get());
return v->value();
}
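
// Builds a one-prompt `ContextItem` with the system role, `text` as content,
// and `size` as its token count.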
AIAssistant::Context::ContextItem SimpleContextItem(std::string text,
uint32_t size) {
auto item = AIAssistant::Context::ContextItem();
item.tokens = size;
auto* prompt = item.prompts.Add();
prompt->set_role(PromptApiRole::PROMPT_API_ROLE_SYSTEM);
prompt->set_content(text);
return item;
}
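
// Maps a `PromptApiRole` to the short prefix used by `ToString()` below.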
const char* FormatPromptRole(PromptApiRole role) {
switch (role) {
case PromptApiRole::PROMPT_API_ROLE_SYSTEM:
return "S: ";
case PromptApiRole::PROMPT_API_ROLE_USER:
return "U: ";
case PromptApiRole::PROMPT_API_ROLE_ASSISTANT:
return "M: ";
default:
NOTREACHED();
}
}
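
// Flattens a `PromptApiRequest` into one "<role>: <content>" line per prompt,
// covering initial prompts, prompt history, and current prompts in order. A
// trailing assistant prefix is appended after current prompts, mirroring the
// "Model: " cue in the plain-string format.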
std::string ToString(const PromptApiRequest& request) {
std::ostringstream oss;
for (const auto& prompt : request.initial_prompts()) {
oss << FormatPromptRole(prompt.role()) << prompt.content() << "\n";
}
for (const auto& prompt : request.prompt_history()) {
oss << FormatPromptRole(prompt.role()) << prompt.content() << "\n";
}
for (const auto& prompt : request.current_prompts()) {
oss << FormatPromptRole(prompt.role()) << prompt.content() << "\n";
}
if (request.current_prompts_size() > 0) {
oss << FormatPromptRole(PromptApiRole::PROMPT_API_ROLE_ASSISTANT);
}
return oss.str();
}
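
// Dispatches on the concrete proto type: a `PromptApiRequest` uses the
// flattening above, while a `StringValue` is returned verbatim.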
std::string ToString(const google::protobuf::MessageLite& request_metadata) {
if (request_metadata.GetTypeName() ==
"optimization_guide.proto.PromptApiRequest") {
return ToString(*static_cast<const PromptApiRequest*>(&request_metadata));
}
if (request_metadata.GetTypeName() ==
"optimization_guide.proto.StringValue") {
return static_cast<const optimization_guide::proto::StringValue*>(
&request_metadata)
->value();
}
return "unexpected type";
}
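
// Returns a lazily-constructed `Any` wrapping a version-1 `PromptApiMetadata`,
// used by tests that exercise the `PromptApiRequest` code path.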
const optimization_guide::proto::Any& GetPromptApiMetadata() {
static base::NoDestructor<optimization_guide::proto::Any> data([]() {
optimization_guide::proto::PromptApiMetadata metadata;
metadata.set_version(1);
optimization_guide::proto::Any any;
any.set_type_url("type.googleapis.com/" + metadata.GetTypeName());
any.set_value(metadata.SerializeAsString());
return any;
}());
return *data;
}
} // namespace
class AIAssistantTest : public AITestUtils::AITestBase {
public:
struct Options {
blink::mojom::AIAssistantSamplingParamsPtr sampling_params = nullptr;
std::optional<std::string> system_prompt = std::nullopt;
std::vector<blink::mojom::AIAssistantInitialPromptPtr> initial_prompts;
std::string prompt_input = kTestPrompt;
std::string expected_context = "";
std::string expected_prompt = kExpectedFormattedTestPrompt;
bool use_prompt_api_proto = false;
};
protected:
  // Helper function that creates an `AIAssistant` and executes the prompt.
void RunPromptTest(Options options) {
blink::mojom::AIAssistantSamplingParamsPtr sampling_params_copy;
if (options.sampling_params) {
sampling_params_copy = options.sampling_params->Clone();
}
// Set up mock service.
SetupMockOptimizationGuideKeyedService();
EXPECT_CALL(*mock_optimization_guide_keyed_service_, StartSession(_, _))
.WillOnce([&](optimization_guide::ModelBasedCapabilityKey feature,
const std::optional<
optimization_guide::SessionConfigParams>&
config_params) {
auto session = std::make_unique<
testing::NiceMock<optimization_guide::MockSession>>();
if (sampling_params_copy) {
EXPECT_EQ(config_params->sampling_params->top_k,
sampling_params_copy->top_k);
EXPECT_EQ(config_params->sampling_params->temperature,
sampling_params_copy->temperature);
}
ON_CALL(*session, GetTokenLimits())
.WillByDefault(AITestUtils::GetFakeTokenLimits);
ON_CALL(*session, GetOnDeviceFeatureMetadata())
.WillByDefault(options.use_prompt_api_proto
? GetPromptApiMetadata
: AITestUtils::GetFakeFeatureMetadata);
ON_CALL(*session, GetSamplingParams()).WillByDefault([]() {
          // These values aren't used by the tests, so return defaults.
return optimization_guide::SamplingParams{
/*top_k=*/kDefaultTopK,
/*temperature=*/kDefaultTemperature};
});
ON_CALL(*session, GetSizeInTokens(_, _))
.WillByDefault(
[](const std::string& text,
optimization_guide::
OptimizationGuideModelSizeInTokenCallback callback) {
std::move(callback).Run(text.size());
});
ON_CALL(*session, GetContextSizeInTokens(_, _))
.WillByDefault(
[](const google::protobuf::MessageLite& request_metadata,
optimization_guide::
OptimizationGuideModelSizeInTokenCallback callback) {
std::move(callback).Run(ToString(request_metadata).size());
});
ON_CALL(*session, AddContext(_))
.WillByDefault(
[&](const google::protobuf::MessageLite& request_metadata) {
EXPECT_THAT(ToString(request_metadata),
options.expected_context);
});
EXPECT_CALL(*session, ExecuteModel(_, _))
.WillOnce(
[&](const google::protobuf::MessageLite& request_metadata,
optimization_guide::
OptimizationGuideModelExecutionResultStreamingCallback
callback) {
EXPECT_THAT(ToString(request_metadata),
options.expected_prompt);
callback.Run(CreateExecutionResult(kTestResponse,
/*is_complete=*/true));
});
return session;
});
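
    // Create the assistant over mojo and run a single prompt through it.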
mojo::Remote<blink::mojom::AIManager> ai_manager = GetAIManagerRemote();
mojo::Remote<blink::mojom::AIAssistant> mock_session;
ai_manager->CreateAssistant(
mock_session.BindNewPipeAndPassReceiver(),
std::move(options.sampling_params), options.system_prompt,
std::move(options.initial_prompts), base::NullCallback());
AITestUtils::MockModelStreamingResponder mock_responder;
base::RunLoop run_loop;
    // `OnResponse()` is called twice: first with the streamed text
    // (`kOngoing`), then to signal completion (`kComplete`), because the mock
    // returns the full response in one chunk with `is_complete` set to true.
EXPECT_CALL(mock_responder, OnResponse(_, _, _))
.WillOnce([&](blink::mojom::ModelStreamingResponseStatus status,
const std::optional<std::string>& text,
std::optional<uint64_t> current_tokens) {
EXPECT_THAT(text, kTestResponse);
EXPECT_EQ(status,
blink::mojom::ModelStreamingResponseStatus::kOngoing);
})
.WillOnce([&](blink::mojom::ModelStreamingResponseStatus status,
const std::optional<std::string>& text,
std::optional<uint64_t> current_tokens) {
EXPECT_EQ(status,
blink::mojom::ModelStreamingResponseStatus::kComplete);
run_loop.Quit();
});
mock_session->Prompt(options.prompt_input,
mock_responder.BindNewPipeAndPassRemote());
run_loop.Run();
}
private:
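  // Wraps `output` in a serialized `StringValue` proto and returns it as an
  // on-device streaming execution result.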
optimization_guide::OptimizationGuideModelStreamingExecutionResult
CreateExecutionResult(const std::string& output, bool is_complete) {
optimization_guide::proto::StringValue response;
response.set_value(output);
std::string serialized_metadata;
response.SerializeToString(&serialized_metadata);
optimization_guide::proto::Any any;
any.set_value(serialized_metadata);
any.set_type_url(AITestUtils::GetTypeURLForProto(response.GetTypeName()));
return optimization_guide::OptimizationGuideModelStreamingExecutionResult(
optimization_guide::StreamingResponse{
.response = any,
.is_complete = is_complete,
},
/*provided_by_on_device=*/true);
}
std::unique_ptr<AITestUtils::MockSupportsUserData> mock_host_;
};
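
// Tests a prompt through a session created with default options.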
TEST_F(AIAssistantTest, PromptDefaultSession) {
RunPromptTest(AIAssistantTest::Options{
.prompt_input = kTestPrompt,
.expected_prompt = kExpectedFormattedTestPrompt,
});
}
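
// Tests that custom sampling params are forwarded to `StartSession()`.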
TEST_F(AIAssistantTest, PromptSessionWithSamplingParams) {
RunPromptTest(AIAssistantTest::Options{
.sampling_params = blink::mojom::AIAssistantSamplingParams::New(
/*top_k=*/10, /*temperature=*/0.6),
.prompt_input = kTestPrompt,
.expected_prompt = kExpectedFormattedTestPrompt,
});
}
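
// Tests that the system prompt is sent to the session as context.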
TEST_F(AIAssistantTest, PromptSessionWithSystemPrompt) {
RunPromptTest(AIAssistantTest::Options{
.system_prompt = kTestSystemPrompts,
.prompt_input = kTestPrompt,
.expected_context = kExpectedFormattedSystemPrompts,
.expected_prompt = kExpectedFormattedTestPrompt,
});
}
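
// Tests that the initial prompts are sent to the session as context.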
TEST_F(AIAssistantTest, PromptSessionWithInitialPrompts) {
RunPromptTest(AIAssistantTest::Options{
.initial_prompts = GetTestInitialPrompts(),
.prompt_input = kTestPrompt,
.expected_context = kExpectedFormattedInitialPrompts,
.expected_prompt = kExpectedFormattedTestPrompt,
});
}
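
// Tests that the system prompt and initial prompts are combined into the
// session context.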
TEST_F(AIAssistantTest, PromptSessionWithSystemPromptAndInitialPrompts) {
RunPromptTest(AIAssistantTest::Options{
.system_prompt = kTestSystemPrompts,
.initial_prompts = GetTestInitialPrompts(),
.prompt_input = kTestPrompt,
.expected_context = kExpectedFormattedSystemPromptAndInitialPrompts,
.expected_prompt = kExpectedFormattedTestPrompt,
});
}
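
// Tests that prompts are formatted as `PromptApiRequest` protos when the
// feature metadata selects the prompt API format.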
TEST_F(AIAssistantTest, PromptSessionWithPromptApiRequests) {
RunPromptTest(AIAssistantTest::Options{
.system_prompt = "Test system prompt",
.initial_prompts = GetTestInitialPrompts(),
.prompt_input = "Test prompt",
.expected_context = ("S: Test system prompt\n"
"U: How are you?\n"
"M: I'm fine, thank you, and you?\n"
"U: I'm fine too.\n"),
.expected_prompt = "U: Test prompt\nM: ",
.use_prompt_api_proto = true,
});
}
// Tests `AIAssistant::Context` creation without initial prompts.
TEST(AIAssistantContextCreationTest, CreateContext_WithoutInitialPrompts) {
  AIAssistant::Context context(kTestMaxContextToken, {},
                               /*use_prompt_api_request=*/false);
EXPECT_FALSE(context.HasContextItem());
}
// Tests `AIAssistant::Context` creation with valid initial prompts.
TEST(AIAssistantContextCreationTest, CreateContext_WithInitialPrompts_Normal) {
AIAssistant::Context context(
kTestMaxContextToken,
SimpleContextItem("initial prompts\n", kTestInitialPromptsToken),
      /*use_prompt_api_request=*/false);
EXPECT_TRUE(context.HasContextItem());
}
// Tests `AIAssistant::Context` creation with initial prompts that exceed the
// max token limit.
TEST(AIAssistantContextCreationTest,
CreateContext_WithInitialPrompts_Overflow) {
EXPECT_DEATH_IF_SUPPORTED(
AIAssistant::Context context(kTestMaxContextToken,
SimpleContextItem("long initial prompts\n",
kTestMaxContextToken + 1u),
                                   /*use_prompt_api_request=*/false),
"");
}
// Tests an `AIAssistant::Context` that is initialized with or without
// initial prompts.
class AIAssistantContextTest : public testing::Test,
public testing::WithParamInterface<
/*is_init_with_initial_prompts=*/bool> {
public:
bool IsInitializedWithInitialPrompts() { return GetParam(); }
uint32_t GetMaxContextToken() {
return IsInitializedWithInitialPrompts()
? kTestMaxContextToken - kTestInitialPromptsToken
: kTestMaxContextToken;
}
std::string GetInitialPromptsPrefix() {
return IsInitializedWithInitialPrompts() ? "initial prompts\n" : "";
}
AIAssistant::Context context_{
kTestMaxContextToken,
IsInitializedWithInitialPrompts()
? SimpleContextItem("initial prompts", kTestInitialPromptsToken)
: AIAssistant::Context::ContextItem(),
      /*use_prompt_api_request=*/false};
};
INSTANTIATE_TEST_SUITE_P(All,
AIAssistantContextTest,
testing::Bool(),
[](const testing::TestParamInfo<bool>& info) {
return info.param ? "WithInitialPrompts"
: "WithoutInitialPrompts";
});
// Tests `GetContextString()` and `HasContextItem()` when the context is empty.
TEST_P(AIAssistantContextTest, TestContextOperation_Empty) {
EXPECT_EQ(GetContextString(context_), GetInitialPromptsPrefix());
if (IsInitializedWithInitialPrompts()) {
EXPECT_TRUE(context_.HasContextItem());
} else {
EXPECT_FALSE(context_.HasContextItem());
}
}
// Tests `GetContextString()` and `HasContextItem()` when some items are added
// to the context.
TEST_P(AIAssistantContextTest, TestContextOperation_NonEmpty) {
context_.AddContextItem(SimpleContextItem("test", 1u));
EXPECT_EQ(GetContextString(context_), GetInitialPromptsPrefix() + "test\n");
EXPECT_TRUE(context_.HasContextItem());
context_.AddContextItem(SimpleContextItem(" test again", 2u));
EXPECT_EQ(GetContextString(context_),
GetInitialPromptsPrefix() + "test\n test again\n");
EXPECT_TRUE(context_.HasContextItem());
}
// Tests `GetContextString()` and `HasContextItem()` when the items overflow.
TEST_P(AIAssistantContextTest, TestContextOperation_Overflow) {
context_.AddContextItem(SimpleContextItem("test", 1u));
EXPECT_EQ(GetContextString(context_), GetInitialPromptsPrefix() + "test\n");
EXPECT_TRUE(context_.HasContextItem());
// Since the total number of tokens will exceed `kTestMaxContextToken`, the
// old item will be evicted.
context_.AddContextItem(
SimpleContextItem("test long token", GetMaxContextToken()));
EXPECT_EQ(GetContextString(context_),
GetInitialPromptsPrefix() + "test long token\n");
EXPECT_TRUE(context_.HasContextItem());
}
// Tests `GetContextString()` and `HasContextItem()` when the items overflow on
// the first insertion.
TEST_P(AIAssistantContextTest, TestContextOperation_OverflowOnFirstItem) {
context_.AddContextItem(
SimpleContextItem("test very long token", GetMaxContextToken() + 1u));
EXPECT_EQ(GetContextString(context_), GetInitialPromptsPrefix());
if (IsInitializedWithInitialPrompts()) {
EXPECT_TRUE(context_.HasContextItem());
} else {
EXPECT_FALSE(context_.HasContextItem());
}
}