// blob: 7ebf0fd04c50d3da733ea03da77725e16ede51b3 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ai/ai_language_model.h"
#include <cstdint>
#include <initializer_list>
#include <optional>
#include <string>
#include <vector>
#include "base/functional/bind.h"
#include "base/functional/callback_forward.h"
#include "base/functional/callback_helpers.h"
#include "base/notreached.h"
#include "base/run_loop.h"
#include "base/strings/strcat.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h"
#include "base/task/current_thread.h"
#include "base/test/mock_callback.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/test_future.h"
#include "chrome/browser/ai/ai_test_utils.h"
#include "chrome/browser/ai/ai_utils.h"
#include "chrome/browser/ai/features.h"
#include "chrome/browser/component_updater/optimization_guide_on_device_model_installer.h"
#include "components/optimization_guide/core/mock_optimization_guide_model_executor.h"
#include "components/optimization_guide/core/model_execution/multimodal_message.h"
#include "components/optimization_guide/core/model_execution/test/fake_model_assets.h"
#include "components/optimization_guide/core/model_execution/test/fake_model_broker.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_proto_util.h"
#include "components/optimization_guide/proto/common_types.pb.h"
#include "components/optimization_guide/proto/descriptors.pb.h"
#include "components/optimization_guide/proto/features/prompt_api.pb.h"
#include "components/optimization_guide/proto/on_device_model_execution_config.pb.h"
#include "components/optimization_guide/proto/string_value.pb.h"
#include "components/update_client/update_client.h"
#include "content/public/browser/render_widget_host_view.h"
#include "services/on_device_model/public/cpp/capabilities.h"
#include "services/on_device_model/public/mojom/on_device_model.mojom.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/blink/public/common/features_generated.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom-shared.h"
#include "third_party/blink/public/mojom/ai/model_download_progress_observer.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom-shared.h"
namespace {
// Shorten frequently used optimization_guide and gMock names.
using ::optimization_guide::FieldSubstitution;
using ::optimization_guide::ForbidUnsafe;
using ::optimization_guide::StringValueField;
using ::testing::_;
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
using ::testing::Return;
// Role attached to each mojo prompt built by the helpers below.
using Role = ::blink::mojom::AILanguageModelPromptRole;
// NOTE(review): kTestMaxContextToken is not referenced in the portion of the
// file reviewed here - confirm it is still used further down.
constexpr uint32_t kTestMaxContextToken = 10u;
// Defaults returned by the mocked GetSamplingParamsConfig().
constexpr uint32_t kTestDefaultTopK = 1u;
constexpr float kTestDefaultTemperature = 0.0f;
// Sampling params that CreateConfig() writes both to the feature config and to
// the PromptApiMetadata max_sampling_params.
constexpr uint32_t kTestMaxTopK = 5u;
constexpr float kTestMaxTemperature = 1.5;
// max_context_tokens of the input config built by CreateConfig().
constexpr uint32_t kTestMaxTokens = 100u;
// Value given to the "on_device_model_max_tokens_for_context" feature param
// by the AILanguageModelTest fixture.
constexpr uint32_t kTestModelMaxTokens = 200u;
// NOTE(review): kTestModelDownloadSize is not referenced in the portion of the
// file reviewed here - confirm it is still used further down.
constexpr uint64_t kTestModelDownloadSize = 572u;
// The defaults must never exceed the configured maximums.
static_assert(kTestDefaultTopK <= kTestMaxTopK);
static_assert(kTestDefaultTemperature <= kTestMaxTemperature);
// Returns a `width` x `height` N32 bitmap filled with SK_ColorRED.
SkBitmap CreateTestBitmap(int width, int height) {
  SkBitmap red_bitmap;
  red_bitmap.allocN32Pixels(width, height);
  red_bitmap.eraseColor(SK_ColorRED);
  return red_bitmap;
}
// Returns a freshly created, otherwise untouched AudioData struct.
on_device_model::mojom::AudioDataPtr CreateTestAudio() {
  on_device_model::mojom::AudioDataPtr audio =
      on_device_model::mojom::AudioData::New();
  return audio;
}
// Wraps a single AILanguageModelPromptContentPtr in a one-element vector.
std::vector<blink::mojom::AILanguageModelPromptContentPtr> ToVector(
    blink::mojom::AILanguageModelPromptContentPtr content) {
  std::vector<blink::mojom::AILanguageModelPromptContentPtr> contents;
  contents.reserve(1);
  contents.emplace_back(std::move(content));
  return contents;
}
// Convert a list of strings to a AILanguageModelPromptContentPtr vector.
std::vector<blink::mojom::AILanguageModelPromptContentPtr> ToContentVector(
std::initializer_list<std::string> texts) {
std::vector<blink::mojom::AILanguageModelPromptContentPtr> vector;
for (const std::string& text : texts) {
vector.push_back(blink::mojom::AILanguageModelPromptContent::NewText(text));
}
return vector;
}
// Builds a text-safety configuration for the Prompt API feature that carries
// a single ForbidUnsafe() category threshold.
optimization_guide::proto::FeatureTextSafetyConfiguration CreateSafetyConfig() {
  optimization_guide::proto::FeatureTextSafetyConfiguration config;
  auto* thresholds = config.mutable_safety_category_thresholds();
  thresholds->Add(ForbidUnsafe());
  config.set_feature(
      optimization_guide::proto::MODEL_EXECUTION_FEATURE_PROMPT_API);
  return config;
}
// Builds a mojo prompt with the given `role` whose only content is `text`.
// `is_prefix` is forwarded unchanged to the prompt struct.
blink::mojom::AILanguageModelPromptPtr MakePrompt(Role role,
                                                  const std::string& text,
                                                  bool is_prefix = false) {
  auto contents = ToContentVector({text});
  return blink::mojom::AILanguageModelPrompt::New(role, std::move(contents),
                                                  is_prefix);
}
// Build a vector with a single prompt that has multiple user text contents.
std::vector<blink::mojom::AILanguageModelPromptPtr> MakeInput(
std::initializer_list<std::string> texts) {
std::vector<blink::mojom::AILanguageModelPromptPtr> prompts;
prompts.push_back(blink::mojom::AILanguageModelPrompt::New(
Role::kUser, ToContentVector(std::move(texts)), /*is_prefix=*/false));
return prompts;
}
// Build a vector with a single prompt that has a single user text content.
std::vector<blink::mojom::AILanguageModelPromptPtr> MakeInput(
const std::string& text) {
return MakeInput({text});
}
// Builds a ContextItem whose input is the system token followed by `text`,
// and whose token count is recorded as `size`.
AILanguageModel::Context::ContextItem SimpleContextItem(std::string text,
                                                        uint32_t size) {
  AILanguageModel::Context::ContextItem system_item{};
  system_item.input = on_device_model::mojom::Input::New();
  system_item.input->pieces = {ml::Token::kSystem, text};
  system_item.tokens = size;
  return system_item;
}
// Maps a role token to the short prefix used in expectation strings.
// Only the system, user and model tokens are supported.
const char* FormatToken(ml::Token token) {
  if (token == ml::Token::kSystem) {
    return "S: ";
  }
  if (token == ml::Token::kUser) {
    return "U: ";
  }
  if (token == ml::Token::kModel) {
    return "M: ";
  }
  NOTREACHED();
}
// Flattens the pieces of `input` into one string for expectation matching:
// tokens via FormatToken(), text verbatim, images as "<image>" and audio as
// "<audio>". Pieces holding any other alternative contribute nothing.
std::string FormatInput(const on_device_model::mojom::Input& input) {
  std::string formatted;
  for (const auto& piece : input.pieces) {
    if (const auto* token = std::get_if<ml::Token>(&piece)) {
      formatted.append(FormatToken(*token));
    } else if (const auto* text = std::get_if<std::string>(&piece)) {
      formatted.append(*text);
    } else if (std::get_if<SkBitmap>(&piece) != nullptr) {
      formatted.append("<image>");
    } else if (std::get_if<ml::AudioBuffer>(&piece) != nullptr) {
      formatted.append("<audio>");
    }
  }
  return formatted;
}
// Formats the non-initial prompts held by `ctx` with FormatInput().
std::string GetContextString(AILanguageModel::Context& ctx) {
  auto&& non_initial_prompts = ctx.GetNonInitialPrompts();
  return FormatInput(*non_initial_prompts);
}
class TestStreamingResponder : public blink::mojom::ModelStreamingResponder {
public:
TestStreamingResponder() = default;
~TestStreamingResponder() override = default;
mojo::PendingRemote<blink::mojom::ModelStreamingResponder> BindRemote() {
return receiver_.BindNewPipeAndPassRemote();
}
// Returns true on successful completion and false on error.
bool WaitForCompletion() {
run_loop_.Run();
return !error_status_.has_value();
}
void WaitForQuotaOverflow() { quota_overflow_run_loop_.Run(); }
blink::mojom::ModelStreamingResponseStatus error_status() const {
EXPECT_TRUE(error_status_.has_value());
return *error_status_;
}
blink::mojom::QuotaErrorInfo quota_error_info() const {
return *quota_error_info_;
}
const std::vector<std::string> responses() const { return responses_; }
uint64_t current_tokens() const { return current_tokens_; }
private:
// blink::mojom::ModelStreamingResponder:
void OnError(blink::mojom::ModelStreamingResponseStatus status,
blink::mojom::QuotaErrorInfoPtr quota_error_info) override {
error_status_ = status;
quota_error_info_ = std::move(quota_error_info);
run_loop_.Quit();
}
void OnStreaming(const std::string& text) override {
responses_.push_back(text);
}
void OnCompletion(
blink::mojom::ModelExecutionContextInfoPtr context_info) override {
current_tokens_ = context_info->current_tokens;
run_loop_.Quit();
}
void OnQuotaOverflow() override { quota_overflow_run_loop_.Quit(); }
std::optional<blink::mojom::ModelStreamingResponseStatus> error_status_;
blink::mojom::QuotaErrorInfoPtr quota_error_info_;
std::vector<std::string> responses_;
uint64_t current_tokens_ = 0;
base::RunLoop run_loop_;
base::RunLoop quota_overflow_run_loop_;
mojo::Receiver<blink::mojom::ModelStreamingResponder> receiver_{this};
};
// Builds the Prompt API feature config used by the fake model broker: text
// safety may be skipped, the context is capped at kTestMaxTokens, and the
// same (kTestMaxTopK, kTestMaxTemperature) sampling params are used both as
// the config's sampling params and as the metadata's max sampling params.
optimization_guide::proto::OnDeviceModelExecutionFeatureConfig CreateConfig() {
  optimization_guide::proto::SamplingParams max_params;
  max_params.set_top_k(kTestMaxTopK);
  max_params.set_temperature(kTestMaxTemperature);

  optimization_guide::proto::PromptApiMetadata prompt_metadata;
  *prompt_metadata.mutable_max_sampling_params() = max_params;

  optimization_guide::proto::OnDeviceModelExecutionFeatureConfig feature_config;
  feature_config.set_feature(optimization_guide::proto::ModelExecutionFeature::
                                 MODEL_EXECUTION_FEATURE_PROMPT_API);
  feature_config.set_can_skip_text_safety(true);
  *feature_config.mutable_sampling_params() = max_params;
  feature_config.mutable_input_config()->set_max_context_tokens(
      kTestMaxTokens);
  *feature_config.mutable_feature_metadata() =
      optimization_guide::AnyWrapProto(prompt_metadata);
  return feature_config;
}
// Builds the list of chunks the fake on-device model service is expected to
// stream back for a sequence of prompts.
//
// The fake service remembers every input added to a session and replays all
// of them on each Generate() call. The language model additionally feeds the
// model output (plus an end token) back into the session, so the replay grows
// with each prompt. With the fake's default behaviour a session looks like:
//  - s1.Prompt("foo")
//    - Adds "UfooEM" to the session
//    - Gets output of ["UfooEM"] from fake service
//    - Adds "UfooEME" to the session (fake response + end token)
//  - s1.Prompt("bar")
//    - Adds "UbarEM" to the session
//    - Gets output of ["UfooEM", "UfooEME", "UbarEM"].
//    - Adds "UfooEMUfooEMEUbarEM"
//      (concatenated output from fake service) to the session
//
// `responses` holds the raw input added by each prompt, oldest first; the
// result is the chunk list expected from the last of those prompts.
// TODO(crbug.com/415808003): Simplify this in the fake service.
std::vector<std::string> FormatResponses(
    const std::vector<std::string>& responses) {
  std::vector<std::string> expected_chunks;
  // Concatenation of every chunk emitted so far, i.e. what the previous
  // prompt would have produced as its full output.
  std::string history;
  for (const std::string& prompt_input : responses) {
    const bool has_history = !history.empty();
    if (has_history) {
      // The previous output was appended to the session with an end token.
      std::string echoed_output = history + "E";
      history += echoed_output;
      expected_chunks.push_back(std::move(echoed_output));
    }
    history += prompt_input;
    expected_chunks.push_back(prompt_input);
  }
  return expected_chunks;
}
// Fixture for AILanguageModel tests. It creates an AIManager in SetUp() and
// routes the mocked optimization guide service's model broker client to
// `fake_broker_`, which is configured with CreateConfig().
class AILanguageModelTest : public AITestUtils::AITestBase {
public:
// Enables multimodal input, sets the language model output buffer to 100, and
// sets the on-device model context limit to kTestModelMaxTokens (the execute
// and output limits are set to 0).
AILanguageModelTest()
: fake_broker_(optimization_guide::FakeAdaptationAsset(
{.config = CreateConfig()})) {
scoped_feature_list_.InitWithFeaturesAndParameters(
{{blink::features::kAIPromptAPIMultimodalInput, {}},
{features::kAILanguageModelOverrideConfiguration,
{{"ai_language_model_output_buffer", "100"}}},
{optimization_guide::features::kOptimizationGuideOnDeviceModel,
{{"on_device_model_max_tokens_for_execute", "0"},
{"on_device_model_max_tokens_for_output", "0"},
{"on_device_model_max_tokens_for_context",
base::NumberToString(kTestModelMaxTokens)}}}},
{});
// Reset the adaptation to make sure the feature params get picked up.
// (`fake_broker_` is constructed in the initializer list, i.e. before the
// feature list above is initialized.)
fake_broker_.UpdateModelAdaptation(
optimization_guide::FakeAdaptationAsset({.config = CreateConfig()}));
}
// Runs the base set-up, installs the mock service defaults below and creates
// the AIManager under test for the main frame.
void SetUp() override {
AITestBase::SetUp();
SetupMockOptimizationGuideKeyedService();
ai_manager_ =
std::make_unique<AIManager>(main_rfh()->GetBrowserContext(),
&component_update_service_, main_rfh());
}
protected:
// Installs default actions on the mocked optimization guide service: a broker
// client bound to `fake_broker_`, the test default sampling params, the
// feature metadata from CreateConfig(), image and audio input capabilities,
// and a kSuccess on-device eligibility for any feature.
void SetupMockOptimizationGuideKeyedService() override {
AITestUtils::AITestBase::SetupMockOptimizationGuideKeyedService();
ON_CALL(*mock_optimization_guide_keyed_service_, CreateModelBrokerClient())
.WillByDefault([&]() {
return std::make_unique<optimization_guide::ModelBrokerClient>(
fake_broker_.BindAndPassRemote(),
optimization_guide::CreateSessionArgs(nullptr, {}));
});
ON_CALL(*mock_optimization_guide_keyed_service_,
GetSamplingParamsConfig(
optimization_guide::ModelBasedCapabilityKey::kPromptApi))
.WillByDefault([]() {
return optimization_guide::SamplingParamsConfig{
.default_top_k = kTestDefaultTopK,
.default_temperature = kTestDefaultTemperature};
});
ON_CALL(*mock_optimization_guide_keyed_service_,
GetFeatureMetadata(
optimization_guide::ModelBasedCapabilityKey::kPromptApi))
.WillByDefault([]() { return CreateConfig().feature_metadata(); });
ON_CALL(*mock_optimization_guide_keyed_service_, GetOnDeviceCapabilities())
.WillByDefault(Return(on_device_model::Capabilities(
{on_device_model::CapabilityFlags::kImageInput,
on_device_model::CapabilityFlags::kAudioInput})));
ON_CALL(*mock_optimization_guide_keyed_service_,
GetOnDeviceModelEligibility(_))
.WillByDefault(Return(
optimization_guide::OnDeviceModelEligibilityReason::kSuccess));
}
// Asks the AIManager to create a language model with `options` and returns
// the remote delivered to the client's OnResult(). The instance info passed
// alongside it is ignored.
mojo::Remote<blink::mojom::AILanguageModel> CreateSession(
blink::mojom::AILanguageModelCreateOptionsPtr options =
blink::mojom::AILanguageModelCreateOptions::New()) {
base::test::TestFuture<mojo::Remote<blink::mojom::AILanguageModel>> future;
AITestUtils::MockCreateLanguageModelClient language_model_client;
EXPECT_CALL(language_model_client, OnResult(_, _))
.WillOnce([&](mojo::PendingRemote<blink::mojom::AILanguageModel>
language_model,
blink::mojom::AILanguageModelInstanceInfoPtr info) {
future.SetValue(mojo::Remote<blink::mojom::AILanguageModel>(
std::move(language_model)));
});
GetAIManagerRemote()->CreateLanguageModel(
language_model_client.BindNewPipeAndPassRemote(), std::move(options));
return future.Take();
}
// Sends `input` (with the optional `constraint`) to `model`, expects the
// prompt to complete without error and returns the streamed chunks.
std::vector<std::string> Prompt(
blink::mojom::AILanguageModel& model,
std::vector<blink::mojom::AILanguageModelPromptPtr> input,
on_device_model::mojom::ResponseConstraintPtr constraint = nullptr) {
TestStreamingResponder responder;
model.Prompt(std::move(input), std::move(constraint),
responder.BindRemote());
EXPECT_TRUE(responder.WaitForCompletion());
return responder.responses();
}
// Appends `input` to `model` and expects the append to complete without
// error.
void Append(blink::mojom::AILanguageModel& model,
std::vector<blink::mojom::AILanguageModelPromptPtr> input) {
TestStreamingResponder responder;
model.Append(std::move(input), responder.BindRemote());
EXPECT_TRUE(responder.WaitForCompletion());
}
// Forks `model` and returns the remote delivered to the client's OnResult().
mojo::Remote<blink::mojom::AILanguageModel> Fork(
blink::mojom::AILanguageModel& model) {
base::test::TestFuture<mojo::Remote<blink::mojom::AILanguageModel>> future;
AITestUtils::MockCreateLanguageModelClient language_model_client;
EXPECT_CALL(language_model_client, OnResult(_, _))
.WillOnce([&](mojo::PendingRemote<blink::mojom::AILanguageModel>
language_model,
blink::mojom::AILanguageModelInstanceInfoPtr info) {
future.SetValue(mojo::Remote<blink::mojom::AILanguageModel>(
std::move(language_model)));
});
model.Fork(language_model_client.BindNewPipeAndPassRemote());
return future.Take();
}
protected:
// Fake broker handed to the model broker client; tests adjust its settings()
// to control the fake model's output.
optimization_guide::FakeModelBroker fake_broker_;
// Holds the feature overrides applied in the constructor.
base::test::ScopedFeatureList scoped_feature_list_;
};
// A single prompt streams back the formatted input from the fake service.
TEST_F(AILanguageModelTest, Prompt) {
  auto model = CreateSession();
  const std::vector<std::string> streamed = Prompt(*model, MakeInput("foo"));
  EXPECT_THAT(streamed, ElementsAreArray(FormatResponses({"UfooEM"})));
}
// Each later prompt on the same session also replays the earlier prompts.
TEST_F(AILanguageModelTest, MultiplePrompts) {
  auto model = CreateSession();

  const std::vector<std::string> first = Prompt(*model, MakeInput("foo"));
  EXPECT_THAT(first, ElementsAreArray(FormatResponses({"UfooEM"})));

  const std::vector<std::string> second = Prompt(*model, MakeInput("bar"));
  EXPECT_THAT(second,
              ElementsAreArray(FormatResponses({"UfooEM", "UbarEM"})));

  const std::vector<std::string> third = Prompt(*model, MakeInput("baz"));
  EXPECT_THAT(
      third,
      ElementsAreArray(FormatResponses({"UfooEM", "UbarEM", "UbazEM"})));
}
// Several text contents in one prompt are concatenated into one user turn.
TEST_F(AILanguageModelTest, PromptMultipleContents) {
  auto model = CreateSession();
  const std::vector<std::string> streamed =
      Prompt(*model, MakeInput({"foo", "bar"}));
  EXPECT_THAT(streamed, ElementsAreArray(FormatResponses({"UfoobarEM"})));
}
// Appended input shows up ahead of the next prompt.
TEST_F(AILanguageModelTest, Append) {
  auto model = CreateSession();
  Append(*model, MakeInput("foo"));
  const std::vector<std::string> streamed = Prompt(*model, MakeInput("bar"));
  EXPECT_THAT(streamed, ElementsAre("UfooE", "UbarEM"));
}
// Several text contents appended at once form a single user turn.
TEST_F(AILanguageModelTest, AppendMultipleContents) {
  auto model = CreateSession();
  Append(*model, MakeInput({"foo", "bar"}));
  const std::vector<std::string> streamed = Prompt(*model, MakeInput("baz"));
  EXPECT_THAT(streamed, ElementsAre("UfoobarE", "UbazEM"));
}
// The token count reported on completion grows with each prompt and carries
// over to a fork of the session.
TEST_F(AILanguageModelTest, PromptTokenCounts) {
  fake_broker_.settings().set_execute_result({"hi"});
  auto model = CreateSession();

  // Prompts `target` with `text`, expects success and returns the token count
  // reported on completion.
  auto prompt_for_token_count = [](blink::mojom::AILanguageModel& target,
                                   const std::string& text) {
    TestStreamingResponder responder;
    target.Prompt(MakeInput(text), nullptr, responder.BindRemote());
    EXPECT_TRUE(responder.WaitForCompletion());
    return responder.current_tokens();
  };

  std::string expected_context = "UfooEMhiE";
  EXPECT_EQ(prompt_for_token_count(*model, "foo"), expected_context.size());

  expected_context += "UbarEMhiE";
  EXPECT_EQ(prompt_for_token_count(*model, "bar"), expected_context.size());

  auto forked = Fork(*model);
  expected_context += "UbazEMhiE";
  EXPECT_EQ(prompt_for_token_count(*forked, "baz"), expected_context.size());
}
// The token count reported on completion grows with each append and carries
// over to a fork of the session.
TEST_F(AILanguageModelTest, AppendTokenCounts) {
  auto model = CreateSession();

  // Appends `text` to `target`, expects success and returns the token count
  // reported on completion.
  auto append_for_token_count = [](blink::mojom::AILanguageModel& target,
                                   const std::string& text) {
    TestStreamingResponder responder;
    target.Append(MakeInput(text), responder.BindRemote());
    EXPECT_TRUE(responder.WaitForCompletion());
    return responder.current_tokens();
  };

  std::string expected_context = "UfooE";
  EXPECT_EQ(append_for_token_count(*model, "foo"), expected_context.size());

  expected_context += "UbarE";
  EXPECT_EQ(append_for_token_count(*model, "bar"), expected_context.size());

  auto forked = Fork(*model);
  expected_context += "UbazE";
  EXPECT_EQ(append_for_token_count(*forked, "baz"), expected_context.size());
}
// User, system and assistant prompts each get their own role marker.
TEST_F(AILanguageModelTest, Roles) {
  auto model = CreateSession();
  std::vector<blink::mojom::AILanguageModelPromptPtr> mixed_roles;
  mixed_roles.emplace_back(MakePrompt(Role::kUser, "user"));
  mixed_roles.emplace_back(MakePrompt(Role::kSystem, "system"));
  mixed_roles.emplace_back(MakePrompt(Role::kAssistant, "model"));
  const std::vector<std::string> streamed =
      Prompt(*model, std::move(mixed_roles));
  EXPECT_THAT(streamed,
              ElementsAreArray(FormatResponses({"UuserESsystemEMmodelEM"})));
}
// A fork sees the history of its parent at the time of forking, and parent
// and fork evolve independently afterwards.
TEST_F(AILanguageModelTest, Fork) {
  auto model = CreateSession();

  // Forked before any prompt: starts with an empty history.
  auto fork_of_empty = Fork(*model);
  EXPECT_THAT(Prompt(*model, MakeInput("foo")),
              ElementsAreArray(FormatResponses({"UfooEM"})));

  // Forked after "foo".
  auto fork_after_foo = Fork(*model);
  EXPECT_THAT(Prompt(*model, MakeInput("bar")),
              ElementsAreArray(FormatResponses({"UfooEM", "UbarEM"})));

  // Forked after "foo" and "bar".
  auto fork_after_bar = Fork(*model);

  EXPECT_THAT(Prompt(*fork_of_empty, MakeInput("fork")),
              ElementsAreArray(FormatResponses({"UforkEM"})));
  EXPECT_THAT(Prompt(*fork_after_foo, MakeInput("fork")),
              ElementsAreArray(FormatResponses({"UfooEM", "UforkEM"})));

  // A fork of a fork inherits the fork's own history.
  auto fork_of_fork = Fork(*fork_after_foo);
  EXPECT_THAT(
      Prompt(*fork_after_bar, MakeInput("fork")),
      ElementsAreArray(FormatResponses({"UfooEM", "UbarEM", "UforkEM"})));

  // The original session is unaffected by prompts sent to its forks.
  EXPECT_THAT(
      Prompt(*model, MakeInput("baz")),
      ElementsAreArray(FormatResponses({"UfooEM", "UbarEM", "UbazEM"})));
  EXPECT_THAT(
      Prompt(*fork_of_fork, MakeInput("more")),
      ElementsAreArray(FormatResponses({"UfooEM", "UforkEM", "UmoreEM"})));
}
// In-range sampling params are used as given, both by the session and by a
// fork of it.
TEST_F(AILanguageModelTest, SamplingParams) {
  auto options = blink::mojom::AILanguageModelCreateOptions::New();
  options->sampling_params =
      blink::mojom::AILanguageModelSamplingParams::New();
  options->sampling_params->top_k = 2;
  options->sampling_params->temperature = 1.0;

  auto model = CreateSession(std::move(options));
  auto forked = Fork(*model);

  const std::vector<std::string> from_model = Prompt(*model, MakeInput("foo"));
  EXPECT_THAT(from_model, ElementsAre("UfooEM", "TopK: 2, Temp: 1"));

  const std::vector<std::string> from_fork = Prompt(*forked, MakeInput("bar"));
  EXPECT_THAT(from_fork, ElementsAre("UbarEM", "TopK: 2, Temp: 1"));
}
// A top_k of 0 is reported back by the model as 1; the temperature is kept.
TEST_F(AILanguageModelTest, SamplingParamsTopKOutOfRange) {
  auto options = blink::mojom::AILanguageModelCreateOptions::New();
  options->sampling_params =
      blink::mojom::AILanguageModelSamplingParams::New();
  options->sampling_params->top_k = 0;
  options->sampling_params->temperature = 1.5f;

  auto model = CreateSession(std::move(options));
  const std::vector<std::string> streamed = Prompt(*model, MakeInput("foo"));
  EXPECT_THAT(streamed, ElementsAre("UfooEM", "TopK: 1, Temp: 1.5"));
}
// A negative temperature is reported back by the model as 0; top_k is kept.
TEST_F(AILanguageModelTest, SamplingParamsTemperatureOutOfRange) {
  auto options = blink::mojom::AILanguageModelCreateOptions::New();
  options->sampling_params =
      blink::mojom::AILanguageModelSamplingParams::New();
  options->sampling_params->top_k = 2;
  options->sampling_params->temperature = -1.0f;

  auto model = CreateSession(std::move(options));
  const std::vector<std::string> streamed = Prompt(*model, MakeInput("foo"));
  EXPECT_THAT(streamed, ElementsAre("UfooEM", "TopK: 2, Temp: 0"));
}
// Sampling params above the configured maximums are reported back as the
// maximums (kTestMaxTopK, kTestMaxTemperature).
TEST_F(AILanguageModelTest, MaxSamplingParams) {
  auto options = blink::mojom::AILanguageModelCreateOptions::New();
  options->sampling_params =
      blink::mojom::AILanguageModelSamplingParams::New();
  options->sampling_params->top_k = kTestMaxTopK + 1;
  options->sampling_params->temperature = kTestMaxTemperature + 1;

  auto model = CreateSession(std::move(options));
  const std::vector<std::string> streamed = Prompt(*model, MakeInput("foo"));
  EXPECT_THAT(streamed, ElementsAre("UfooEM", "TopK: 5, Temp: 1.5"));
}
// Initial prompts given at creation precede the first prompt's input.
TEST_F(AILanguageModelTest, InitialPrompts) {
  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  auto& initial_prompts = create_options->initial_prompts;
  initial_prompts.push_back(MakePrompt(Role::kSystem, "hi"));
  initial_prompts.push_back(MakePrompt(Role::kUser, "bye"));

  auto model = CreateSession(std::move(create_options));
  const std::vector<std::string> streamed = Prompt(*model, MakeInput("foo"));
  EXPECT_THAT(streamed, ElementsAre("ShiEUbyeE", "UfooEM"));
}
// An initial prompt with several text contents forms a single user turn.
TEST_F(AILanguageModelTest, InitialPromptsMultipleContents) {
  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->initial_prompts = MakeInput({"foo", "bar"});

  auto model = CreateSession(std::move(create_options));
  const std::vector<std::string> streamed = Prompt(*model, MakeInput("baz"));
  EXPECT_THAT(streamed, ElementsAre("UfoobarE", "UbazEM"));
}
// The instance info delivered on creation reports the input quota and the
// usage consumed by the initial prompts.
TEST_F(AILanguageModelTest, InitialPromptsInstanceInfo) {
  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->initial_prompts.push_back(MakePrompt(Role::kSystem, "hi"));

  base::test::TestFuture<blink::mojom::AILanguageModelInstanceInfoPtr>
      info_future;
  AITestUtils::MockCreateLanguageModelClient client;
  EXPECT_CALL(client, OnResult(_, _))
      .WillOnce(
          [&](mojo::PendingRemote<blink::mojom::AILanguageModel> language_model,
              blink::mojom::AILanguageModelInstanceInfoPtr info) {
            info_future.SetValue(std::move(info));
          });

  GetAIManagerRemote()->CreateLanguageModel(client.BindNewPipeAndPassRemote(),
                                            std::move(create_options));

  auto instance_info = info_future.Take();
  EXPECT_EQ(instance_info->input_quota, kTestMaxTokens);
  EXPECT_EQ(instance_info->input_usage, std::strlen("ShiE"));
}
// Initial prompts larger than the context quota make creation fail with
// kInitialInputTooLarge and matching quota details.
TEST_F(AILanguageModelTest, InitialPromptsTooLarge) {
  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  const std::string oversized_text(kTestMaxTokens + 1, 'a');
  create_options->initial_prompts.push_back(
      MakePrompt(Role::kSystem, oversized_text));

  base::test::TestFuture<blink::mojom::AIManagerCreateClientError>
      creation_error;
  base::test::TestFuture<blink::mojom::QuotaErrorInfoPtr> quota_details;
  AITestUtils::MockCreateLanguageModelClient client;
  EXPECT_CALL(client, OnError(_, _))
      .WillOnce([&](auto error, auto quota_error_info) {
        creation_error.SetValue(error);
        quota_details.SetValue(std::move(quota_error_info));
      });

  GetAIManagerRemote()->CreateLanguageModel(client.BindNewPipeAndPassRemote(),
                                            std::move(create_options));

  EXPECT_EQ(creation_error.Take(),
            blink::mojom::AIManagerCreateClientError::kInitialInputTooLarge);
  auto quota = quota_details.Take();
  // Stop here if no quota details were delivered; they are dereferenced next.
  ASSERT_TRUE(quota);
  ASSERT_GT(quota->requested, kTestMaxTokens);
  ASSERT_EQ(quota->quota, kTestMaxTokens);
}
// A prompt larger than the context quota fails with kErrorInputTooLarge and
// matching quota details.
TEST_F(AILanguageModelTest, InputTooLarge) {
  auto model = CreateSession();
  const std::string oversized_text(kTestMaxTokens + 1, 'a');

  TestStreamingResponder responder;
  model->Prompt(MakeInput(oversized_text), nullptr, responder.BindRemote());

  EXPECT_FALSE(responder.WaitForCompletion());
  EXPECT_EQ(responder.error_status(),
            blink::mojom::ModelStreamingResponseStatus::kErrorInputTooLarge);
  const blink::mojom::QuotaErrorInfo quota = responder.quota_error_info();
  ASSERT_GT(quota.requested, kTestMaxTokens);
  ASSERT_EQ(quota.quota, kTestMaxTokens);
}
// When a prompt's input overflows the quota, the oldest non-initial history
// is dropped while the initial prompts and the new input are kept.
TEST_F(AILanguageModelTest, QuotaOverflowOnPromptInput) {
  // Use a fixed execute result so the long prompt is not echoed back as the
  // response.
  fake_broker_.settings().set_execute_result({"hi"});

  // The initial prompt must survive the overflow.
  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->initial_prompts.push_back(MakePrompt(Role::kSystem, "init"));
  auto model = CreateSession(std::move(create_options));

  // Fill the context with a prompt close to the max token length. This is the
  // entry expected to be stripped from the history later on.
  const std::string filler_prompt(kTestMaxTokens - 20, 'a');
  EXPECT_THAT(Prompt(*model, MakeInput(filler_prompt)), ElementsAre("hi"));
  EXPECT_THAT(Prompt(*model, MakeInput("foo")), ElementsAre("hi"));

  // Clear the execute result so the input can be verified via the response.
  fake_broker_.settings().set_execute_result({});

  const std::string long_prompt(kTestMaxTokens / 3, 'a');
  TestStreamingResponder responder;
  model->Prompt(MakeInput(long_prompt), nullptr, responder.BindRemote());
  responder.WaitForQuotaOverflow();
  EXPECT_TRUE(responder.WaitForCompletion());

  // The response holds the initial prompt, the input/output of the "foo"
  // prompt and the new long prompt; the filler prompt is gone.
  EXPECT_THAT(responder.responses(),
              ElementsAre("SinitE", "UfooEMhiE", "U" + long_prompt + "EM"));
}
// When appended input overflows the quota, the older appended input is
// dropped while the initial prompts and the new input are kept.
TEST_F(AILanguageModelTest, QuotaOverflowOnAppend) {
  // The initial prompt must survive the overflow.
  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->initial_prompts.push_back(MakePrompt(Role::kSystem, "init"));
  auto model = CreateSession(std::move(create_options));

  // Fill the context with input close to the max token length.
  const std::string filler_prompt(kTestMaxTokens - 20, 'a');
  Append(*model, MakeInput(filler_prompt));

  const std::string long_prompt(kTestMaxTokens / 3, 'a');
  TestStreamingResponder responder;
  model->Append(MakeInput(long_prompt), responder.BindRemote());
  responder.WaitForQuotaOverflow();
  EXPECT_TRUE(responder.WaitForCompletion());

  const std::vector<std::string> streamed = Prompt(*model, MakeInput("foo"));
  EXPECT_THAT(streamed,
              ElementsAre("SinitE", "U" + long_prompt + "E", "UfooEM"));
}
// When the model output overflows the quota, the older history is dropped
// while the prompt that produced the output, and the output, are kept.
TEST_F(AILanguageModelTest, QuotaOverflowOnOutput) {
  // Use a fixed execute result so the long prompt is not echoed back as the
  // response.
  fake_broker_.settings().set_execute_result({"hi"});
  auto model = CreateSession();

  // Fill the context with a prompt close to the max token length. This entry
  // is expected to be stripped from the history, while the next prompt's
  // input and output are kept.
  const std::string filler_prompt(kTestMaxTokens - 20, 'a');
  EXPECT_THAT(Prompt(*model, MakeInput(filler_prompt)), ElementsAre("hi"));

  // Switch to a long response that causes the overflow. `long_response`
  // itself is kept, but the previous prompt is removed.
  const std::string long_response(kTestMaxTokens / 3, 'a');
  fake_broker_.settings().set_execute_result({long_response});
  TestStreamingResponder responder;
  model->Prompt(MakeInput("foo"), nullptr, responder.BindRemote());
  responder.WaitForQuotaOverflow();
  EXPECT_TRUE(responder.WaitForCompletion());
  EXPECT_THAT(responder.responses(), ElementsAre(long_response));

  // Verify the original long prompt was removed. The response should contain:
  //  - "foo"+long_response from the previous prompt call
  //  - "bar" from the current prompt call
  fake_broker_.settings().set_execute_result({});
  const std::vector<std::string> streamed = Prompt(*model, MakeInput("bar"));
  EXPECT_THAT(streamed, ElementsAre("UfooEM" + long_response + "E", "UbarEM"));
}
// Output that overruns the model's max tokens fails the prompt, and the
// failed prompt leaves no trace in the session history.
TEST_F(AILanguageModelTest, OutputOverflowsModelMaxTokens) {
  auto model = CreateSession();

  // Seed the history with a prompt that must survive the failure.
  EXPECT_THAT(Prompt(*model, MakeInput("foo")),
              ElementsAreArray(FormatResponses({"UfooEM"})));

  // Make the fake respond with output that overruns the max model tokens.
  const std::string oversized_output(kTestModelMaxTokens, 'a');
  fake_broker_.settings().set_execute_result({oversized_output});
  TestStreamingResponder responder;
  model->Prompt(MakeInput("bar"), nullptr, responder.BindRemote());
  EXPECT_FALSE(responder.WaitForCompletion());
  EXPECT_EQ(responder.error_status(),
            blink::mojom::ModelStreamingResponseStatus::kErrorGenericFailure);

  // Prompt again: the failed "bar" prompt must not be present.
  fake_broker_.settings().set_execute_result({});
  EXPECT_THAT(Prompt(*model, MakeInput("baz")),
              ElementsAreArray(FormatResponses({"UfooEM", "UbazEM"})));
}
// With a smaller configured output buffer, output that exceeds the buffer
// fails the prompt.
TEST_F(AILanguageModelTest, OutputOverflowsAdditionalBuffer) {
  // Override the output buffer with a smaller value to check it is honoured.
  base::test::ScopedFeatureList output_buffer_override;
  output_buffer_override.InitWithFeaturesAndParameters(
      {{features::kAILanguageModelOverrideConfiguration,
        {{"ai_language_model_output_buffer", "10"}}}},
      {});
  auto model = CreateSession();

  // Append input just below max tokens so the next output overflows the
  // buffer and causes an error.
  const std::string near_limit_input(kTestMaxTokens - 5, 'a');
  Append(*model, MakeInput(near_limit_input));

  // Produce a response just larger than the output buffer.
  const std::string oversized_output(15, 'a');
  fake_broker_.settings().set_execute_result({oversized_output});
  TestStreamingResponder responder;
  model->Prompt(MakeInput(""), nullptr, responder.BindRemote());
  EXPECT_FALSE(responder.WaitForCompletion());
  EXPECT_EQ(responder.error_status(),
            blink::mojom::ModelStreamingResponseStatus::kErrorGenericFailure);
}
// Output that overflows the maximum context size fails the prompt, and the
// failed prompt leaves no trace in the session history.
TEST_F(AILanguageModelTest, OutputOverflowsContextMaxTokens) {
  auto model = CreateSession();

  // Seed the history with a prompt that must survive the failure.
  EXPECT_THAT(Prompt(*model, MakeInput("foo")),
              ElementsAreArray(FormatResponses({"UfooEM"})));

  // Make the fake respond with output that overflows the max context size.
  const std::string oversized_output(kTestMaxTokens, 'a');
  fake_broker_.settings().set_execute_result({oversized_output});
  TestStreamingResponder responder;
  model->Prompt(MakeInput("bar"), nullptr, responder.BindRemote());
  EXPECT_FALSE(responder.WaitForCompletion());
  EXPECT_EQ(responder.error_status(),
            blink::mojom::ModelStreamingResponseStatus::kErrorGenericFailure);

  // Prompt again: the failed "bar" prompt must not be present.
  fake_broker_.settings().set_execute_result({});
  EXPECT_THAT(Prompt(*model, MakeInput("baz")),
              ElementsAreArray(FormatResponses({"UfooEM", "UbazEM"})));
}
// Destroy() disconnects the session remote.
TEST_F(AILanguageModelTest, Destroy) {
  auto model = CreateSession();
  base::RunLoop disconnect_loop;
  model.set_disconnect_handler(disconnect_loop.QuitClosure());

  const std::vector<std::string> streamed = Prompt(*model, MakeInput("foo"));
  EXPECT_THAT(streamed, ElementsAreArray(FormatResponses({"UfooEM"})));

  model->Destroy();
  // Returns once the disconnect handler has fired.
  disconnect_loop.Run();
}
// Destroy() while a prompt is still executing disconnects the session and the
// pending prompt ends in an error.
TEST_F(AILanguageModelTest, DestroyWithActivePrompt) {
  // Delay execution so the prompt is still in flight when Destroy() is sent.
  fake_broker_.settings().set_execute_delay(base::Minutes(1));
  auto model = CreateSession();
  base::RunLoop disconnect_loop;
  model.set_disconnect_handler(disconnect_loop.QuitClosure());

  TestStreamingResponder pending_responder;
  model->Prompt(MakeInput("foo"), nullptr, pending_responder.BindRemote());
  model->Destroy();

  disconnect_loop.Run();
  EXPECT_FALSE(pending_responder.WaitForCompletion());
}
// Parameters for AILanguageModelTestWithLanguageParams.
struct LanguageParams {
// Value given to the "langs" param of blink::features::kAIPromptAPI.
std::string enabled_languages;
// Language codes requested as expected input languages at creation.
std::vector<std::string> expected_input_language;
// True when creation is expected to fail with kUnsupportedLanguage; false
// when it is expected to succeed.
bool expect_error;
};
// Streams `params` in a readable form, e.g. for parameterized test output.
// Returns `os` so that calls can be chained.
std::ostream& operator<<(std::ostream& os, const LanguageParams& params) {
  const std::string joined_languages =
      base::JoinString(params.expected_input_language, ", ");
  os << "enabled_languages:" << params.enabled_languages;
  os << ", expected_input_language:" << joined_languages;
  os << ", expect_error:" << params.expect_error;
  return os;
}
// Parameterized fixture that enables the Prompt API feature with its "langs"
// param taken from the test parameter's `enabled_languages`.
class AILanguageModelTestWithLanguageParams
    : public AILanguageModelTest,
      public testing::WithParamInterface<LanguageParams> {
 public:
  AILanguageModelTestWithLanguageParams() {
    const std::string& enabled_languages = GetParam().enabled_languages;
    features_.InitWithFeaturesAndParameters(
        {{blink::features::kAIPromptAPI, {{"langs", enabled_languages}}}}, {});
  }

  // Holds the kAIPromptAPI override applied in the constructor.
  base::test::ScopedFeatureList features_;
};
// Requests a language model whose expected input languages come from the test
// parameter and checks that creation either succeeds or fails with
// `kUnsupportedLanguage`, as dictated by `expect_error`.
TEST_P(AILanguageModelTestWithLanguageParams, PromptWithEnabledLanguages) {
  using CreateError = blink::mojom::AIManagerCreateClientError;
  using LanguageModelRemote = mojo::Remote<blink::mojom::AILanguageModel>;

  const LanguageParams& params = GetParam();
  base::test::TestFuture<CreateError> creation_error;
  base::test::TestFuture<LanguageModelRemote> created_model;
  AITestUtils::MockCreateLanguageModelClient client;

  // Exactly one of the two client callbacks is allowed to fire.
  if (params.expect_error) {
    EXPECT_CALL(client, OnResult(_, _)).Times(0);
    EXPECT_CALL(client, OnError(_, _))
        .WillOnce([&](auto error, auto quota_error_info) {
          creation_error.SetValue(error);
        });
  } else {
    EXPECT_CALL(client, OnResult(_, _))
        .WillOnce(
            [&](mojo::PendingRemote<blink::mojom::AILanguageModel> pending,
                blink::mojom::AILanguageModelInstanceInfoPtr info) {
              created_model.SetValue(LanguageModelRemote(std::move(pending)));
            });
    EXPECT_CALL(client, OnError(_, _)).Times(0);
  }

  // Describe the expected input using the parameterized language codes.
  auto expected = blink::mojom::AILanguageModelExpected::New();
  expected->languages.emplace();
  for (const auto& code : params.expected_input_language) {
    expected->languages->push_back(blink::mojom::AILanguageCode::New(code));
  }

  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->expected_inputs.emplace();
  create_options->expected_inputs->push_back(std::move(expected));

  GetAIManagerRemote()->CreateLanguageModel(client.BindNewPipeAndPassRemote(),
                                            std::move(create_options));

  if (params.expect_error) {
    EXPECT_EQ(creation_error.Take(), CreateError::kUnsupportedLanguage);
  } else {
    EXPECT_TRUE(created_model.Wait());
  }
}
// Instantiates `AILanguageModelTestWithLanguageParams` without a prefix. Each
// entry is {enabled_languages, expected_input_language, expect_error}: the
// value for the feature's "langs" parameter, the language codes requested as
// expected input, and whether session creation is expected to fail.
INSTANTIATE_TEST_SUITE_P(
/* no prefix */,
AILanguageModelTestWithLanguageParams,
::testing::Values(LanguageParams{"en,es,ja", {"en"}, false},
LanguageParams{"*", {"en"}, false},
LanguageParams{"*", {"fr"}, true},
LanguageParams{"", {"en"}, true},
LanguageParams{"", {"fr"}, true},
LanguageParams{"es,ja", {"es"}, false},
LanguageParams{"en,es,ja", {"ja"}, false},
LanguageParams{"en,es,ja", {"ja", "es"}, false},
LanguageParams{"en,es,ja", {"ja", "fr"}, true}));
// Checks that requesting image input fails session creation with
// `kUnableToCreateSession` when the keyed service reports default-constructed
// on-device capabilities.
TEST_F(AILanguageModelTest, UnsupportedInputCapability) {
  using CreateError = blink::mojom::AIManagerCreateClientError;

  ON_CALL(*mock_optimization_guide_keyed_service_, GetOnDeviceCapabilities())
      .WillByDefault(Return(on_device_model::Capabilities()));

  base::test::TestFuture<CreateError> creation_error;
  AITestUtils::MockCreateLanguageModelClient client;
  EXPECT_CALL(client, OnError(_, _))
      .WillOnce([&](auto error, auto quota_error_info) {
        creation_error.SetValue(error);
      });

  // Ask for image input, which the capabilities above do not advertise.
  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->expected_inputs.emplace();
  {
    auto image_input = blink::mojom::AILanguageModelExpected::New();
    image_input->type = blink::mojom::AILanguageModelPromptType::kImage;
    create_options->expected_inputs->push_back(std::move(image_input));
  }

  GetAIManagerRemote()->CreateLanguageModel(client.BindNewPipeAndPassRemote(),
                                            std::move(create_options));
  EXPECT_EQ(creation_error.Take(), CreateError::kUnableToCreateSession);
}
// Checks that requesting image output fails session creation with
// `kUnableToCreateSession` when the keyed service reports default-constructed
// on-device capabilities.
TEST_F(AILanguageModelTest, UnsupportedOutputCapability) {
  using CreateError = blink::mojom::AIManagerCreateClientError;

  ON_CALL(*mock_optimization_guide_keyed_service_, GetOnDeviceCapabilities())
      .WillByDefault(Return(on_device_model::Capabilities()));

  base::test::TestFuture<CreateError> creation_error;
  AITestUtils::MockCreateLanguageModelClient client;
  EXPECT_CALL(client, OnError(_, _))
      .WillOnce([&](auto error, auto quota_error_info) {
        creation_error.SetValue(error);
      });

  // Ask for image output via `expected_outputs`.
  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->expected_outputs.emplace();
  {
    auto image_output = blink::mojom::AILanguageModelExpected::New();
    image_output->type = blink::mojom::AILanguageModelPromptType::kImage;
    create_options->expected_outputs->push_back(std::move(image_output));
  }

  GetAIManagerRemote()->CreateLanguageModel(client.BindNewPipeAndPassRemote(),
                                            std::move(create_options));
  EXPECT_EQ(creation_error.Take(), CreateError::kUnableToCreateSession);
}
// Checks that a session created with audio as its only expected input rejects
// input containing a bitmap: `Prompt()` and `Append()` report
// `kErrorInvalidRequest` and `MeasureInputUsage()` yields `std::nullopt`.
TEST_F(AILanguageModelTest, MultimodalInputImageNotSpecified) {
  using Status = blink::mojom::ModelStreamingResponseStatus;

  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->expected_inputs.emplace();
  {
    auto audio_only = blink::mojom::AILanguageModelExpected::New();
    audio_only->type = blink::mojom::AILanguageModelPromptType::kAudio;
    create_options->expected_inputs->push_back(std::move(audio_only));
  }
  auto session = CreateSession(std::move(create_options));

  // Builds the "foo" input followed by a user prompt carrying a 10x10 bitmap.
  const auto build_image_input = [] {
    std::vector<blink::mojom::AILanguageModelPromptPtr> prompts =
        MakeInput("foo");
    prompts.push_back(blink::mojom::AILanguageModelPrompt::New(
        Role::kUser,
        ToVector(blink::mojom::AILanguageModelPromptContent::NewBitmap(
            CreateTestBitmap(10, 10))),
        /*is_prefix=*/false));
    return prompts;
  };

  {
    TestStreamingResponder prompt_responder;
    session->Prompt(build_image_input(), nullptr,
                    prompt_responder.BindRemote());
    EXPECT_FALSE(prompt_responder.WaitForCompletion());
    EXPECT_EQ(prompt_responder.error_status(), Status::kErrorInvalidRequest);
  }
  {
    TestStreamingResponder append_responder;
    session->Append(build_image_input(), append_responder.BindRemote());
    EXPECT_FALSE(append_responder.WaitForCompletion());
    EXPECT_EQ(append_responder.error_status(), Status::kErrorInvalidRequest);
  }

  base::test::TestFuture<std::optional<uint32_t>> usage_future;
  session->MeasureInputUsage(build_image_input(), usage_future.GetCallback());
  EXPECT_EQ(usage_future.Get(), std::nullopt);
}
// Checks that a session created with image as its only expected input rejects
// input containing audio: `Prompt()` and `Append()` report
// `kErrorInvalidRequest` and `MeasureInputUsage()` yields `std::nullopt`.
TEST_F(AILanguageModelTest, MultimodalInputAudioNotSpecified) {
  using Status = blink::mojom::ModelStreamingResponseStatus;

  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->expected_inputs.emplace();
  {
    auto image_only = blink::mojom::AILanguageModelExpected::New();
    image_only->type = blink::mojom::AILanguageModelPromptType::kImage;
    create_options->expected_inputs->push_back(std::move(image_only));
  }
  auto session = CreateSession(std::move(create_options));

  // Builds the "foo" input followed by a user prompt carrying test audio.
  const auto build_audio_input = [] {
    std::vector<blink::mojom::AILanguageModelPromptPtr> prompts =
        MakeInput("foo");
    prompts.push_back(blink::mojom::AILanguageModelPrompt::New(
        Role::kUser,
        ToVector(blink::mojom::AILanguageModelPromptContent::NewAudio(
            CreateTestAudio())),
        /*is_prefix=*/false));
    return prompts;
  };

  {
    TestStreamingResponder prompt_responder;
    session->Prompt(build_audio_input(), nullptr,
                    prompt_responder.BindRemote());
    EXPECT_FALSE(prompt_responder.WaitForCompletion());
    EXPECT_EQ(prompt_responder.error_status(), Status::kErrorInvalidRequest);
  }
  {
    TestStreamingResponder append_responder;
    session->Append(build_audio_input(), append_responder.BindRemote());
    EXPECT_FALSE(append_responder.WaitForCompletion());
    EXPECT_EQ(append_responder.error_status(), Status::kErrorInvalidRequest);
  }

  base::test::TestFuture<std::optional<uint32_t>> usage_future;
  session->MeasureInputUsage(build_audio_input(), usage_future.GetCallback());
  EXPECT_EQ(usage_future.Get(), std::nullopt);
}
// Checks that a session created with both audio and image as expected inputs
// accepts a prompt that mixes text, a bitmap and audio.
TEST_F(AILanguageModelTest, MultimodalInput) {
  using PromptType = blink::mojom::AILanguageModelPromptType;

  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->expected_inputs.emplace();
  // Audio is registered first, then image, as expected input types.
  for (const PromptType type : {PromptType::kAudio, PromptType::kImage}) {
    auto expected = blink::mojom::AILanguageModelExpected::New();
    expected->type = type;
    create_options->expected_inputs->push_back(std::move(expected));
  }
  auto session = CreateSession(std::move(create_options));

  // Text prompt, then a 10x10 bitmap, then test audio, all as user prompts.
  std::vector<blink::mojom::AILanguageModelPromptPtr> prompts =
      MakeInput("foo");
  prompts.push_back(blink::mojom::AILanguageModelPrompt::New(
      Role::kUser,
      ToVector(blink::mojom::AILanguageModelPromptContent::NewBitmap(
          CreateTestBitmap(10, 10))),
      /*is_prefix=*/false));
  prompts.push_back(blink::mojom::AILanguageModelPrompt::New(
      Role::kUser,
      ToVector(blink::mojom::AILanguageModelPromptContent::NewAudio(
          CreateTestAudio())),
      /*is_prefix=*/false));

  const auto responses = Prompt(*session, std::move(prompts));
  EXPECT_THAT(responses,
              ElementsAreArray(FormatResponses({"UfooEU<image>EU<audio>EM"})));
}
// Checks that a download progress observer registered with the `AIManager`
// receives normalized progress updates for the on-device model component.
TEST_F(AILanguageModelTest, ModelDownload) {
  // Component id of the on-device model. The `AIManager` forwards updates for
  // this component to the `CreateMonitor`s.
  const std::string component_id =
      component_updater::OptimizationGuideOnDeviceModelInstallerPolicy::
          GetOnDeviceModelExtensionId();
  AITestUtils::FakeComponent fake_component(component_id,
                                            kTestModelDownloadSize);

  // No observers are registered before the monitor is added.
  EXPECT_EQ(GetAIManagerDownloadProgressObserversSize(), 0u);

  AITestUtils::FakeMonitor monitor;
  // The component details lookup reports a new component with no bytes
  // downloaded.
  EXPECT_CALL(component_update_service_, GetComponentDetails(_, _))
      .WillOnce(
          [&](const std::string& id, component_updater::CrxUpdateItem* item) {
            *item = fake_component.CreateUpdateItem(
                update_client::ComponentState::kNew, 0);
            return true;
          });

  GetAIManagerRemote()->AddModelDownloadProgressObserver(
      monitor.BindNewPipeAndPassRemote());
  ASSERT_TRUE(base::test::RunUntil(
      [this] { return GetAIManagerDownloadProgressObserversSize() == 1u; }));

  // Report the component as downloading with all bytes received.
  component_update_service_.SendUpdate(fake_component.CreateUpdateItem(
      update_client::ComponentState::kDownloading, kTestModelDownloadSize));

  monitor.ExpectReceivedNormalizedUpdate(0, kTestModelDownloadSize);
  monitor.ExpectReceivedNormalizedUpdate(kTestModelDownloadSize,
                                         kTestModelDownloadSize);
}
// Checks that `MeasureInputUsage()` for the "foo" input reports a usage equal
// to the length of the string "UfooEM".
TEST_F(AILanguageModelTest, MeasureInputUsage) {
  auto session = CreateSession();

  base::test::TestFuture<std::optional<uint32_t>> usage_future;
  session->MeasureInputUsage(MakeInput("foo"), usage_future.GetCallback());

  const auto expected_usage = std::string("UfooEM").size();
  EXPECT_EQ(usage_future.Get(), expected_usage);
}
// Checks that session creation fails with `kUnableToCreateSession` when text
// safety cannot be skipped and the initial system prompt is "unsafe".
TEST_F(AILanguageModelTest, TextSafetyInitialPrompts) {
  using CreateError = blink::mojom::AIManagerCreateClientError;

  // Require text safety for the model adaptation.
  auto adaptation_config = CreateConfig();
  adaptation_config.set_can_skip_text_safety(false);
  fake_broker_.UpdateModelAdaptation(
      optimization_guide::FakeAdaptationAsset({.config = adaptation_config}));

  // Install a safety model with a request check over the string value field.
  auto safety_config = CreateSafetyConfig();
  safety_config.add_request_check()->mutable_input_template()->Add(
      FieldSubstitution("%s", StringValueField()));
  optimization_guide::FakeSafetyModelAsset safety_asset(
      std::move(safety_config));
  fake_broker_.UpdateSafetyModel(safety_asset.model_info());

  base::test::TestFuture<CreateError> creation_error;
  AITestUtils::MockCreateLanguageModelClient client;
  EXPECT_CALL(client, OnError(_, _))
      .WillOnce([&](auto error, auto quota_error_info) {
        creation_error.SetValue(error);
      });

  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->initial_prompts.push_back(
      MakePrompt(Role::kSystem, "unsafe"));
  GetAIManagerRemote()->CreateLanguageModel(client.BindNewPipeAndPassRemote(),
                                            std::move(create_options));

  EXPECT_EQ(creation_error.Take(), CreateError::kUnableToCreateSession);
}
// Checks that, with a request safety check installed, a "safe" prompt is
// answered normally while an "unsafe" prompt fails with `kErrorFiltered`.
TEST_F(AILanguageModelTest, TextSafetyInput) {
  // Require text safety for the model adaptation.
  auto adaptation_config = CreateConfig();
  adaptation_config.set_can_skip_text_safety(false);
  fake_broker_.UpdateModelAdaptation(
      optimization_guide::FakeAdaptationAsset({.config = adaptation_config}));

  // Install a safety model with a request check over the string value field.
  auto safety_config = CreateSafetyConfig();
  safety_config.add_request_check()->mutable_input_template()->Add(
      FieldSubstitution("%s", StringValueField()));
  optimization_guide::FakeSafetyModelAsset safety_asset(
      std::move(safety_config));
  fake_broker_.UpdateSafetyModel(safety_asset.model_info());

  fake_broker_.settings().set_execute_result({"hi"});
  auto session = CreateSession();

  const auto safe_responses = Prompt(*session, MakeInput("safe"));
  EXPECT_THAT(safe_responses, ElementsAre("hi"));

  // Fake text safety checker looks for the string "unsafe".
  TestStreamingResponder unsafe_responder;
  session->Prompt(MakeInput("unsafe"), nullptr, unsafe_responder.BindRemote());
  EXPECT_FALSE(unsafe_responder.WaitForCompletion());
  EXPECT_EQ(unsafe_responder.error_status(),
            blink::mojom::ModelStreamingResponseStatus::kErrorFiltered);
}
// Checks that output containing "unsafe" fails the prompt with
// `kErrorFiltered` and that nothing reaches the responder when the partial
// output check minimum (1000 tokens) is larger than the streamed output.
TEST_F(AILanguageModelTest, TextSafetyOutput) {
  // Require text safety for the model adaptation.
  auto adaptation_config = CreateConfig();
  adaptation_config.set_can_skip_text_safety(false);
  fake_broker_.UpdateModelAdaptation(
      optimization_guide::FakeAdaptationAsset({.config = adaptation_config}));

  // Install a safety model with a raw output check over the string value
  // field and a partial output check minimum of 1000 tokens.
  auto safety_config = CreateSafetyConfig();
  safety_config.mutable_raw_output_check()->mutable_input_template()->Add(
      FieldSubstitution("%s", StringValueField()));
  safety_config.mutable_partial_output_checks()->set_minimum_tokens(1000);
  optimization_guide::FakeSafetyModelAsset safety_asset(
      std::move(safety_config));
  fake_broker_.UpdateSafetyModel(safety_asset.model_info());

  // Fake text safety checker looks for the string "unsafe".
  fake_broker_.settings().set_execute_result(
      {"a", "b", "c", "d", "e", "f", "g", "unsafe", "h"});
  auto session = CreateSession();

  TestStreamingResponder filtered_responder;
  session->Prompt(MakeInput("foo"), nullptr, filtered_responder.BindRemote());
  EXPECT_FALSE(filtered_responder.WaitForCompletion());
  EXPECT_EQ(filtered_responder.error_status(),
            blink::mojom::ModelStreamingResponseStatus::kErrorFiltered);
  EXPECT_TRUE(filtered_responder.responses().empty());
}
// Checks that, with partial output checks configured (minimum 3 tokens,
// interval 2), output preceding the "unsafe" chunk is streamed in pieces
// before the prompt fails with `kErrorFiltered`.
TEST_F(AILanguageModelTest, TextSafetyOutputPartial) {
  // Require text safety for the model adaptation.
  auto adaptation_config = CreateConfig();
  adaptation_config.set_can_skip_text_safety(false);
  fake_broker_.UpdateModelAdaptation(
      optimization_guide::FakeAdaptationAsset({.config = adaptation_config}));

  // Install a safety model with a raw output check over the string value
  // field and the partial output check thresholds.
  auto safety_config = CreateSafetyConfig();
  safety_config.mutable_raw_output_check()->mutable_input_template()->Add(
      FieldSubstitution("%s", StringValueField()));
  auto* partial_checks = safety_config.mutable_partial_output_checks();
  partial_checks->set_minimum_tokens(3);
  partial_checks->set_token_interval(2);
  optimization_guide::FakeSafetyModelAsset safety_asset(
      std::move(safety_config));
  fake_broker_.UpdateSafetyModel(safety_asset.model_info());

  // Fake text safety checker looks for the string "unsafe".
  fake_broker_.settings().set_execute_result(
      {"a", "b", "c", "d", "e", "f", "g", "unsafe", "h"});
  auto session = CreateSession();

  TestStreamingResponder filtered_responder;
  session->Prompt(MakeInput("foo"), nullptr, filtered_responder.BindRemote());
  EXPECT_FALSE(filtered_responder.WaitForCompletion());
  EXPECT_EQ(filtered_responder.error_status(),
            blink::mojom::ModelStreamingResponseStatus::kErrorFiltered);

  // Partial checks should still allow some output to stream.
  EXPECT_THAT(filtered_responder.responses(), ElementsAre("abc", "de", "fg"));
}
// Checks that prompts and a fork issued back-to-back on one session all
// complete, that each prompt sees the prompts issued before it, and that the
// fork sees only the prompts issued before the fork.
TEST_F(AILanguageModelTest, QueuesOperations) {
  using LanguageModelRemote = mojo::Remote<blink::mojom::AILanguageModel>;

  base::test::TestFuture<LanguageModelRemote> forked_model;
  AITestUtils::MockCreateLanguageModelClient fork_client;
  EXPECT_CALL(fork_client, OnResult(_, _))
      .WillOnce([&](auto language_model, auto info) {
        forked_model.SetValue(LanguageModelRemote(std::move(language_model)));
      });

  auto session = CreateSession();
  TestStreamingResponder first_responder;
  TestStreamingResponder second_responder;
  TestStreamingResponder third_responder;

  // Add three prompts and a fork, all these operations should complete
  // successfully and in order.
  session->Prompt(MakeInput("foo"), nullptr, first_responder.BindRemote());
  session->Prompt(MakeInput("bar"), nullptr, second_responder.BindRemote());
  session->Fork(fork_client.BindNewPipeAndPassRemote());
  session->Prompt(MakeInput("baz"), nullptr, third_responder.BindRemote());

  EXPECT_TRUE(first_responder.WaitForCompletion());
  EXPECT_THAT(first_responder.responses(),
              ElementsAreArray(FormatResponses({"UfooEM"})));

  EXPECT_TRUE(second_responder.WaitForCompletion());
  EXPECT_THAT(second_responder.responses(),
              ElementsAreArray(FormatResponses({"UfooEM", "UbarEM"})));

  EXPECT_TRUE(third_responder.WaitForCompletion());
  EXPECT_THAT(
      third_responder.responses(),
      ElementsAreArray(FormatResponses({"UfooEM", "UbarEM", "UbazEM"})));

  // The fork was requested after "bar" and before "baz".
  EXPECT_THAT(
      Prompt(*forked_model.Take(), MakeInput("fork")),
      ElementsAreArray(FormatResponses({"UfooEM", "UbarEM", "UforkEM"})));
}
// Checks that a regex response constraint passed to the prompt shows up in
// the responses ahead of the prompt output.
TEST_F(AILanguageModelTest, Constraint) {
  auto session = CreateSession();

  auto regex_constraint =
      on_device_model::mojom::ResponseConstraint::NewRegex("reg");
  const auto responses =
      Prompt(*session, MakeInput("foo"), std::move(regex_constraint));

  EXPECT_THAT(responses, ElementsAre("Constraint: regex reg", "UfooEM"));
}
// Checks how an assistant prompt flagged with `is_prefix` is rendered when it
// follows a user prompt.
TEST_F(AILanguageModelTest, Prefix) {
  auto session = CreateSession();

  std::vector<blink::mojom::AILanguageModelPromptPtr> conversation;
  conversation.push_back(MakePrompt(Role::kUser, "foo"));
  conversation.push_back(
      MakePrompt(Role::kAssistant, "bar", /*is_prefix=*/true));

  // Expect no 'bar' end token, nor separate model response start token.
  const auto responses = Prompt(*session, std::move(conversation));
  EXPECT_THAT(responses, ElementsAre("UfooEMbar"));
}
// Checks that crashing the service right after issuing a prompt fails that
// prompt with `kErrorGenericFailure`, and that a newly created session works
// afterwards.
TEST_F(AILanguageModelTest, ServiceCrash) {
  auto session = CreateSession();

  TestStreamingResponder crashed_responder;
  session->Prompt(MakeInput("bar"), nullptr, crashed_responder.BindRemote());
  fake_broker_.CrashService();

  EXPECT_FALSE(crashed_responder.WaitForCompletion());
  EXPECT_EQ(crashed_responder.error_status(),
            blink::mojom::ModelStreamingResponseStatus::kErrorGenericFailure);

  // Recreating the session should be fine.
  session = CreateSession();
  const auto responses = Prompt(*session, MakeInput("foo"));
  EXPECT_THAT(responses, ElementsAreArray(FormatResponses({"UfooEM"})));
}
// Checks that input appended before a service crash is still reflected in the
// responses to a prompt issued on the same session after the crash.
TEST_F(AILanguageModelTest, CrashRecovery) {
  auto session = CreateSession();

  Append(*session, MakeInput("foo"));
  fake_broker_.CrashService();

  const auto responses = Prompt(*session, MakeInput("bar"));
  EXPECT_THAT(responses, ElementsAre("UfooE", "UbarEM"));
}
// Checks that input appended across two separate service crashes is still
// reflected in the responses to a later prompt on the same session.
TEST_F(AILanguageModelTest, CrashRecoveryWithMultipleCrashes) {
  auto session = CreateSession();

  // Append, crash, append again, crash again.
  Append(*session, MakeInput("foo"));
  fake_broker_.CrashService();
  Append(*session, MakeInput("bar"));
  fake_broker_.CrashService();

  const auto responses = Prompt(*session, MakeInput("baz"));
  EXPECT_THAT(responses, ElementsAre("UfooEUbarE", "UbazEM"));
}
// Checks that both the initial system prompt and input appended before a
// service crash are reflected in the responses to a prompt issued afterwards.
TEST_F(AILanguageModelTest, CrashRecoveryWithInitialPrompts) {
  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->initial_prompts.push_back(MakePrompt(Role::kSystem, "hi"));
  auto session = CreateSession(std::move(create_options));

  Append(*session, MakeInput("foo"));
  fake_broker_.CrashService();

  const auto responses = Prompt(*session, MakeInput("bar"));
  EXPECT_THAT(responses, ElementsAre("ShiE", "UfooE", "UbarEM"));
}
// Checks that `MeasureInputUsage()` still reports the length of "UfooEM" for
// the "foo" input after the service has crashed.
TEST_F(AILanguageModelTest, CrashRecoveryMeasureInputUsage) {
  auto session = CreateSession();

  Append(*session, MakeInput("foo"));
  fake_broker_.CrashService();

  base::test::TestFuture<std::optional<uint32_t>> usage_future;
  session->MeasureInputUsage(MakeInput("foo"), usage_future.GetCallback());

  const auto expected_usage = std::string("UfooEM").size();
  EXPECT_EQ(usage_future.Get(), expected_usage);
}
// TODO(crbug.com/414632884): This test is flaky on Linux TSAN.
#if BUILDFLAG(IS_LINUX) && defined(THREAD_SANITIZER)
#define MAYBE_CanCreate_WaitsForEligibility \
  DISABLED_CanCreate_WaitsForEligibility
#else
#define MAYBE_CanCreate_WaitsForEligibility CanCreate_WaitsForEligibility
#endif
// Checks that `CanCreateLanguageModel()` does not report a result until the
// asynchronous eligibility callback has been run, and reports `kAvailable`
// once that callback is run with `kSuccess`.
TEST_F(AILanguageModelTest, MAYBE_CanCreate_WaitsForEligibility) {
  using EligibilityCallback = base::OnceCallback<void(
      optimization_guide::OnDeviceModelEligibilityReason)>;
  using Availability = blink::mojom::ModelAvailabilityCheckResult;

  // Capture the eligibility callback instead of running it, so the test
  // decides when eligibility is reported.
  base::test::TestFuture<EligibilityCallback> captured_callback;
  EXPECT_CALL(*mock_optimization_guide_keyed_service_,
              GetOnDeviceModelEligibilityAsync(_, _, _))
      .WillOnce([&](auto feature, auto capabilities, auto callback) {
        captured_callback.SetValue(std::move(callback));
      });

  base::test::TestFuture<Availability> availability;
  GetAIManagerInterface()->CanCreateLanguageModel({},
                                                  availability.GetCallback());

  // Session should not be ready until eligibility callback has run.
  EXPECT_FALSE(availability.IsReady());

  captured_callback.Take().Run(
      optimization_guide::OnDeviceModelEligibilityReason::kSuccess);
  EXPECT_EQ(availability.Get(), Availability::kAvailable);
}
// Checks that `CanCreateLanguageModel()` reports `kAvailable` when both the
// expected input and the expected output are text in "en".
TEST_F(AILanguageModelTest, CanCreate_SupportedLanguages) {
  // Builds a text expectation restricted to the "en" language code.
  const auto english_text = [] {
    return blink::mojom::AILanguageModelExpected::New(
        blink::mojom::AILanguageModelPromptType::kText,
        AITestUtils::ToMojoLanguageCodes({"en"}));
  };

  base::MockCallback<AIManager::CanCreateLanguageModelCallback>
      availability_callback;

  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->expected_inputs.emplace();
  create_options->expected_inputs->push_back(english_text());
  create_options->expected_outputs.emplace();
  create_options->expected_outputs->push_back(english_text());

  EXPECT_CALL(availability_callback,
              Run(blink::mojom::ModelAvailabilityCheckResult::kAvailable));
  GetAIManagerInterface()->CanCreateLanguageModel(
      std::move(create_options), availability_callback.Get());
}
// Checks that `CanCreateLanguageModel()` reports
// `kUnavailableUnsupportedLanguage` when the expected text input is in "fr".
TEST_F(AILanguageModelTest, CanCreate_UnsupportedInputLanguages) {
  using Availability = blink::mojom::ModelAvailabilityCheckResult;

  base::MockCallback<AIManager::CanCreateLanguageModelCallback>
      availability_callback;
  EXPECT_CALL(availability_callback,
              Run(Availability::kUnavailableUnsupportedLanguage));

  auto french_input = blink::mojom::AILanguageModelExpected::New(
      blink::mojom::AILanguageModelPromptType::kText,
      AITestUtils::ToMojoLanguageCodes({"fr"}));

  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->expected_inputs.emplace();
  create_options->expected_inputs->push_back(std::move(french_input));

  GetAIManagerInterface()->CanCreateLanguageModel(
      std::move(create_options), availability_callback.Get());
}
// Checks that `CanCreateLanguageModel()` reports
// `kUnavailableUnsupportedLanguage` when the expected text output is in "fr".
TEST_F(AILanguageModelTest, CanCreate_UnsupportedOutputLanguages) {
  using Availability = blink::mojom::ModelAvailabilityCheckResult;

  base::MockCallback<AIManager::CanCreateLanguageModelCallback>
      availability_callback;
  EXPECT_CALL(availability_callback,
              Run(Availability::kUnavailableUnsupportedLanguage));

  auto french_output = blink::mojom::AILanguageModelExpected::New(
      blink::mojom::AILanguageModelPromptType::kText,
      AITestUtils::ToMojoLanguageCodes({"fr"}));

  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->expected_outputs.emplace();
  create_options->expected_outputs->push_back(std::move(french_output));

  GetAIManagerInterface()->CanCreateLanguageModel(
      std::move(create_options), availability_callback.Get());
}
// Checks that an eligibility result of `kModelAdaptationNotAvailable` is
// surfaced by `CanCreateLanguageModel()` as
// `kUnavailableModelAdaptationNotAvailable`.
TEST_F(AILanguageModelTest, CanCreate_UnavailableWhenAdaptationNotAvailable) {
  using Availability = blink::mojom::ModelAvailabilityCheckResult;
  using EligibilityReason = optimization_guide::OnDeviceModelEligibilityReason;

  // Answer the eligibility query immediately with the failure reason.
  EXPECT_CALL(*mock_optimization_guide_keyed_service_,
              GetOnDeviceModelEligibilityAsync(_, _, _))
      .WillOnce([](auto feature, auto capabilities, auto callback) {
        std::move(callback).Run(
            EligibilityReason::kModelAdaptationNotAvailable);
      });

  base::test::TestFuture<Availability> availability;
  GetAIManagerInterface()->CanCreateLanguageModel({},
                                                  availability.GetCallback());

  EXPECT_EQ(availability.Get(),
            Availability::kUnavailableModelAdaptationNotAvailable);
}
// Fixture for tests that exercise `AILanguageModel::Context` directly. Each
// test starts from a context constructed with `kTestMaxContextToken`.
class AILanguageModelContextTest : public testing::Test {
public:
// Context under test, constructed with `kTestMaxContextToken`.
AILanguageModel::Context context_{kTestMaxContextToken};
};
// Checks that `GetContextString()` yields an empty string for a context that
// has had no items added.
TEST_F(AILanguageModelContextTest, TestContextOperation_Empty) {
  const auto context_string = GetContextString(context_);
  EXPECT_EQ(context_string, "");
}
// Tests `GetContextString()` when some items are added to the context. Both
// insertions are expected to fit without evicting anything, so both must
// report `kSufficientSpace` and both items must appear in the context string.
TEST_F(AILanguageModelContextTest, TestContextOperation_NonEmpty) {
  using Result = AILanguageModel::Context::SpaceReservationResult;

  EXPECT_EQ(context_.AddContextItem(SimpleContextItem("test", 1u)),
            Result::kSufficientSpace);
  EXPECT_EQ(GetContextString(context_), "S: test");

  // The result of the second insertion is checked as well, consistent with
  // the other context tests, so that an unexpected eviction is reported at the
  // insertion itself rather than only through the string comparison.
  EXPECT_EQ(context_.AddContextItem(SimpleContextItem(" test again", 2u)),
            Result::kSufficientSpace);
  EXPECT_EQ(GetContextString(context_), "S: testS: test again");
}
// Checks `GetContextString()` after adding an item that needs the whole
// `kTestMaxContextToken` budget while another item is already stored.
TEST_F(AILanguageModelContextTest, TestContextOperation_Overflow) {
  using Result = AILanguageModel::Context::SpaceReservationResult;

  const auto first_result =
      context_.AddContextItem(SimpleContextItem("test", 1u));
  EXPECT_EQ(first_result, Result::kSufficientSpace);
  EXPECT_EQ(GetContextString(context_), "S: test");

  // Since the total number of tokens will exceed `kTestMaxContextToken`, the
  // old item will be evicted.
  const auto overflow_result = context_.AddContextItem(
      SimpleContextItem("long token", kTestMaxContextToken));
  EXPECT_EQ(overflow_result, Result::kSpaceMadeAvailable);
  EXPECT_EQ(GetContextString(context_), "S: long token");
}
// Checks `GetContextString()` when a new item forces out only the older of
// two previously stored items.
TEST_F(AILanguageModelContextTest, TestContextOperation_PartialOverflow) {
  using Result = AILanguageModel::Context::SpaceReservationResult;

  const auto foo_result = context_.AddContextItem(SimpleContextItem("foo", 1u));
  EXPECT_EQ(foo_result, Result::kSufficientSpace);
  EXPECT_EQ(GetContextString(context_), "S: foo");

  const auto bar_result = context_.AddContextItem(SimpleContextItem("bar", 1u));
  EXPECT_EQ(bar_result, Result::kSufficientSpace);
  EXPECT_EQ(GetContextString(context_), "S: fooS: bar");

  // Add 1 token less than `kTestMaxContextToken` so one of the previous items
  // will be kept.
  const auto long_result = context_.AddContextItem(
      SimpleContextItem("long token", kTestMaxContextToken - 1));
  EXPECT_EQ(long_result, Result::kSpaceMadeAvailable);
  EXPECT_EQ(GetContextString(context_), "S: barS: long token");
}
// Checks that an item needing more than `kTestMaxContextToken` tokens is
// rejected with `kInsufficientSpace` and leaves the context string empty.
TEST_F(AILanguageModelContextTest, TestContextOperation_OverflowOnFirstItem) {
  using Result = AILanguageModel::Context::SpaceReservationResult;

  const auto oversized_result = context_.AddContextItem(
      SimpleContextItem("test very long token", kTestMaxContextToken + 1u));
  EXPECT_EQ(oversized_result, Result::kInsufficientSpace);
  EXPECT_EQ(GetContextString(context_), "");
}
// Checks that prompts issued while the frame's view is hidden are reported
// with "Priority: background", for both the session and a fork of it, and
// that the marker disappears once the view is shown again.
TEST_F(AILanguageModelTest, Priority) {
  fake_broker_.settings().set_execute_result({"hi"});
  auto session = CreateSession();

  // View has not been hidden yet: no priority marker.
  const auto initial_responses = Prompt(*session, MakeInput("foo"));
  EXPECT_THAT(initial_responses, ElementsAre("hi"));

  main_rfh()->GetRenderWidgetHost()->GetView()->Hide();

  const auto hidden_responses = Prompt(*session, MakeInput("bar"));
  EXPECT_THAT(hidden_responses, ElementsAre("Priority: background", "hi"));

  // A fork created while hidden reports the same marker.
  auto forked_session = Fork(*session);
  const auto fork_responses = Prompt(*forked_session, MakeInput("bar"));
  EXPECT_THAT(fork_responses, ElementsAre("Priority: background", "hi"));

  main_rfh()->GetRenderWidgetHost()->GetView()->Show();

  const auto shown_responses = Prompt(*session, MakeInput("baz"));
  EXPECT_THAT(shown_responses, ElementsAre("hi"));
}
} // namespace