// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ai/ai_language_model.h"

#include <cstdint>
#include <cstring>
#include <initializer_list>
#include <optional>
#include <ostream>
#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "base/functional/bind.h"
#include "base/functional/callback_forward.h"
#include "base/functional/callback_helpers.h"
#include "base/notreached.h"
#include "base/run_loop.h"
#include "base/strings/strcat.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h"
#include "base/task/current_thread.h"
#include "base/test/mock_callback.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/test_future.h"
#include "chrome/browser/ai/ai_test_utils.h"
#include "chrome/browser/ai/ai_utils.h"
#include "chrome/browser/ai/features.h"
#include "chrome/browser/component_updater/optimization_guide_on_device_model_installer.h"
#include "components/optimization_guide/core/mock_optimization_guide_model_executor.h"
#include "components/optimization_guide/core/model_execution/multimodal_message.h"
#include "components/optimization_guide/core/model_execution/test/fake_model_assets.h"
#include "components/optimization_guide/core/model_execution/test/fake_model_broker.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_proto_util.h"
#include "components/optimization_guide/proto/common_types.pb.h"
#include "components/optimization_guide/proto/descriptors.pb.h"
#include "components/optimization_guide/proto/features/prompt_api.pb.h"
#include "components/optimization_guide/proto/on_device_model_execution_config.pb.h"
#include "components/optimization_guide/proto/string_value.pb.h"
#include "components/update_client/update_client.h"
#include "content/public/browser/render_widget_host_view.h"
#include "services/on_device_model/public/cpp/capabilities.h"
#include "services/on_device_model/public/mojom/on_device_model.mojom.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/blink/public/common/features_generated.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom-shared.h"
#include "third_party/blink/public/mojom/ai/model_download_progress_observer.mojom-forward.h"
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom-shared.h"

namespace {

// NOTE(review): FieldSubstitution and StringValueField are not referenced in
// the part of this file shown here; confirm they are still needed.
using ::optimization_guide::FieldSubstitution;
using ::optimization_guide::ForbidUnsafe;
using ::optimization_guide::StringValueField;
using ::testing::_;
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
using ::testing::Return;
using Role = ::blink::mojom::AILanguageModelPromptRole;

// NOTE(review): not referenced in the part of this file shown here; confirm it
// is still needed.
constexpr uint32_t kTestMaxContextToken = 10u;
// Default sampling params returned by the fixture's GetSamplingParamsConfig()
// mock.
constexpr uint32_t kTestDefaultTopK = 1u;
constexpr float kTestDefaultTemperature = 0.0f;
// Maximum sampling params written into the feature config and the Prompt API
// metadata by CreateConfig().
constexpr uint32_t kTestMaxTopK = 5u;
constexpr float kTestMaxTemperature = 1.5;
// Per-session context quota (input_config max_context_tokens in CreateConfig()).
constexpr uint32_t kTestMaxTokens = 100u;
// Value of the "on_device_model_max_tokens_for_context" feature param.
constexpr uint32_t kTestModelMaxTokens = 200u;
// NOTE(review): not referenced in the part of this file shown here.
constexpr uint64_t kTestModelDownloadSize = 572u;
// The default sampling params must lie within the advertised maximums.
static_assert(kTestDefaultTopK <= kTestMaxTopK);
static_assert(kTestDefaultTemperature <= kTestMaxTemperature);

// Returns a `width` x `height` N32 bitmap whose pixels are all red.
SkBitmap CreateTestBitmap(int width, int height) {
  SkBitmap red_bitmap;
  red_bitmap.allocN32Pixels(width, height);
  red_bitmap.eraseColor(SK_ColorRED);
  return red_bitmap;
}

// Returns a default-constructed AudioData for tests that need audio content.
on_device_model::mojom::AudioDataPtr CreateTestAudio() {
  auto audio = on_device_model::mojom::AudioData::New();
  return audio;
}

// Wraps a single prompt content into a one-element vector, taking ownership.
std::vector<blink::mojom::AILanguageModelPromptContentPtr> ToVector(
    blink::mojom::AILanguageModelPromptContentPtr content) {
  std::vector<blink::mojom::AILanguageModelPromptContentPtr> single;
  single.emplace_back(std::move(content));
  return single;
}

// Convert a list of strings to a AILanguageModelPromptContentPtr vector.
std::vector<blink::mojom::AILanguageModelPromptContentPtr> ToContentVector(
    std::initializer_list<std::string> texts) {
  std::vector<blink::mojom::AILanguageModelPromptContentPtr> vector;
  for (const std::string& text : texts) {
    vector.push_back(blink::mojom::AILanguageModelPromptContent::NewText(text));
  }
  return vector;
}

// Builds a Prompt API text-safety configuration containing a single
// ForbidUnsafe() category threshold.
optimization_guide::proto::FeatureTextSafetyConfiguration CreateSafetyConfig() {
  optimization_guide::proto::FeatureTextSafetyConfiguration config;
  auto* thresholds = config.mutable_safety_category_thresholds();
  thresholds->Add(ForbidUnsafe());
  config.set_feature(
      optimization_guide::proto::MODEL_EXECUTION_FEATURE_PROMPT_API);
  return config;
}

// Builds a mojo prompt with the given `role`, a single text content `text`,
// and the given `is_prefix` flag.
blink::mojom::AILanguageModelPromptPtr MakePrompt(Role role,
                                                  const std::string& text,
                                                  bool is_prefix = false) {
  auto contents = ToContentVector({text});
  return blink::mojom::AILanguageModelPrompt::New(role, std::move(contents),
                                                  is_prefix);
}

// Builds a vector holding one non-prefix user prompt whose contents are the
// entries of `texts`, one text content each, in order.
std::vector<blink::mojom::AILanguageModelPromptPtr> MakeInput(
    std::initializer_list<std::string> texts) {
  std::vector<blink::mojom::AILanguageModelPromptPtr> prompts;
  // `texts` is passed as-is: std::initializer_list is a cheap view over const
  // elements, so std::move() would not move anything.
  prompts.push_back(blink::mojom::AILanguageModelPrompt::New(
      Role::kUser, ToContentVector(texts), /*is_prefix=*/false));
  return prompts;
}

// Build a vector with a single prompt that has a single user text content.
std::vector<blink::mojom::AILanguageModelPromptPtr> MakeInput(
    const std::string& text) {
  return MakeInput({text});
}

// Builds a ContextItem that accounts for `size` tokens and whose input is a
// system token followed by `text`.
AILanguageModel::Context::ContextItem SimpleContextItem(std::string text,
                                                        uint32_t size) {
  auto item = AILanguageModel::Context::ContextItem();
  item.tokens = size;
  item.input = on_device_model::mojom::Input::New();
  // `text` is owned by this function, so move it rather than copying it.
  item.input->pieces = {ml::Token::kSystem, std::move(text)};
  return item;
}

// Maps a role token to the short prefix used in expectation strings. Any other
// token is a test bug.
const char* FormatToken(ml::Token token) {
  if (token == ml::Token::kSystem) {
    return "S: ";
  }
  if (token == ml::Token::kUser) {
    return "U: ";
  }
  if (token == ml::Token::kModel) {
    return "M: ";
  }
  NOTREACHED();
}

// Flattens an Input into one string for expectation matching:
// - tokens become their FormatToken() prefix,
// - text is copied verbatim,
// - images become "<image>",
// - audio becomes "<audio>".
std::string FormatInput(const on_device_model::mojom::Input& input) {
  std::string flattened;
  for (const auto& piece : input.pieces) {
    if (const auto* token = std::get_if<ml::Token>(&piece)) {
      flattened += FormatToken(*token);
    } else if (const auto* text = std::get_if<std::string>(&piece)) {
      flattened += *text;
    } else if (std::holds_alternative<SkBitmap>(piece)) {
      flattened += "<image>";
    } else if (std::holds_alternative<ml::AudioBuffer>(piece)) {
      flattened += "<audio>";
    }
  }
  return flattened;
}

// Renders the non-initial prompts of `ctx` via FormatInput() for expectation
// matching.
std::string GetContextString(AILanguageModel::Context& ctx) {
  const auto& non_initial_prompts = ctx.GetNonInitialPrompts();
  return FormatInput(*non_initial_prompts);
}

class TestStreamingResponder : public blink::mojom::ModelStreamingResponder {
 public:
  TestStreamingResponder() = default;
  ~TestStreamingResponder() override = default;

  mojo::PendingRemote<blink::mojom::ModelStreamingResponder> BindRemote() {
    return receiver_.BindNewPipeAndPassRemote();
  }

  // Returns true on successful completion and false on error.
  bool WaitForCompletion() {
    run_loop_.Run();
    return !error_status_.has_value();
  }

  void WaitForQuotaOverflow() { quota_overflow_run_loop_.Run(); }

  blink::mojom::ModelStreamingResponseStatus error_status() const {
    EXPECT_TRUE(error_status_.has_value());
    return *error_status_;
  }

  blink::mojom::QuotaErrorInfo quota_error_info() const {
    return *quota_error_info_;
  }

  const std::vector<std::string> responses() const { return responses_; }
  uint64_t current_tokens() const { return current_tokens_; }

 private:
  // blink::mojom::ModelStreamingResponder:
  void OnError(blink::mojom::ModelStreamingResponseStatus status,
               blink::mojom::QuotaErrorInfoPtr quota_error_info) override {
    error_status_ = status;
    quota_error_info_ = std::move(quota_error_info);
    run_loop_.Quit();
  }

  void OnStreaming(const std::string& text) override {
    responses_.push_back(text);
  }

  void OnCompletion(
      blink::mojom::ModelExecutionContextInfoPtr context_info) override {
    current_tokens_ = context_info->current_tokens;
    run_loop_.Quit();
  }

  void OnQuotaOverflow() override { quota_overflow_run_loop_.Quit(); }

  std::optional<blink::mojom::ModelStreamingResponseStatus> error_status_;
  blink::mojom::QuotaErrorInfoPtr quota_error_info_;
  std::vector<std::string> responses_;
  uint64_t current_tokens_ = 0;
  base::RunLoop run_loop_;
  base::RunLoop quota_overflow_run_loop_;
  mojo::Receiver<blink::mojom::ModelStreamingResponder> receiver_{this};
};

// Builds the Prompt API on-device feature config used by the fake broker:
// - text safety may be skipped;
// - the sampling params and the metadata's max sampling params are both
//   kTestMaxTopK / kTestMaxTemperature;
// - the input context is capped at kTestMaxTokens.
optimization_guide::proto::OnDeviceModelExecutionFeatureConfig CreateConfig() {
  optimization_guide::proto::SamplingParams max_params;
  max_params.set_top_k(kTestMaxTopK);
  max_params.set_temperature(kTestMaxTemperature);

  optimization_guide::proto::PromptApiMetadata prompt_metadata;
  *prompt_metadata.mutable_max_sampling_params() = max_params;

  optimization_guide::proto::OnDeviceModelExecutionFeatureConfig result;
  result.set_feature(optimization_guide::proto::ModelExecutionFeature::
                         MODEL_EXECUTION_FEATURE_PROMPT_API);
  result.set_can_skip_text_safety(true);
  *result.mutable_sampling_params() = max_params;
  result.mutable_input_config()->set_max_context_tokens(kTestMaxTokens);
  *result.mutable_feature_metadata() =
      optimization_guide::AnyWrapProto(prompt_metadata);
  return result;
}

// Builds the chunks the fake on-device model service is expected to stream
// back. The fake echoes every input it has received in the session. The
// language model also feeds each output back into the session, followed by an
// end token "E". With the fake's default behavior a session looks like:
// - s1.Prompt("foo")
//   - "UfooEM" is added to the session.
//   - The fake streams ["UfooEM"].
//   - "UfooEME" (fake output + end token) is added to the session.
// - s1.Prompt("bar")
//   - "UbarEM" is added to the session.
//   - The fake streams ["UfooEM", "UfooEME", "UbarEM"].
//   - "UfooEMUfooEMEUbarEM" (the concatenated output) is added to the session.
// Each element of `responses` is one prompt's new input. Once anything has
// been echoed, the accumulated echo plus "E" is inserted before the next
// element. This reproduces those expectations.
// TODO(crbug.com/415808003): Simplify this in the fake service.
std::vector<std::string> FormatResponses(
    const std::vector<std::string>& responses) {
  std::vector<std::string> expected;
  // Concatenation of every chunk produced so far.
  std::string echoed_so_far;
  for (const auto& response : responses) {
    if (!echoed_so_far.empty()) {
      std::string replayed = echoed_so_far + "E";
      echoed_so_far.append(replayed);
      expected.push_back(std::move(replayed));
    }
    expected.push_back(response);
    echoed_so_far.append(response);
  }
  return expected;
}

class AILanguageModelTest : public AITestUtils::AITestBase {
 public:
  AILanguageModelTest()
      : fake_broker_(optimization_guide::FakeAdaptationAsset(
            {.config = CreateConfig()})) {
    scoped_feature_list_.InitWithFeaturesAndParameters(
        {{blink::features::kAIPromptAPIMultimodalInput, {}},
         {features::kAILanguageModelOverrideConfiguration,
          {{"ai_language_model_output_buffer", "100"}}},
         {optimization_guide::features::kOptimizationGuideOnDeviceModel,
          {{"on_device_model_max_tokens_for_execute", "0"},
           {"on_device_model_max_tokens_for_output", "0"},
           {"on_device_model_max_tokens_for_context",
            base::NumberToString(kTestModelMaxTokens)}}}},
        {});
    // Reset the adaptation to make sure the feature params get picked up.
    fake_broker_.UpdateModelAdaptation(
        optimization_guide::FakeAdaptationAsset({.config = CreateConfig()}));
  }

  void SetUp() override {
    AITestBase::SetUp();
    SetupMockOptimizationGuideKeyedService();
    ai_manager_ =
        std::make_unique<AIManager>(main_rfh()->GetBrowserContext(),
                                    &component_update_service_, main_rfh());
  }

 protected:
  void SetupMockOptimizationGuideKeyedService() override {
    AITestUtils::AITestBase::SetupMockOptimizationGuideKeyedService();
    ON_CALL(*mock_optimization_guide_keyed_service_, CreateModelBrokerClient())
        .WillByDefault([&]() {
          return std::make_unique<optimization_guide::ModelBrokerClient>(
              fake_broker_.BindAndPassRemote(),
              optimization_guide::CreateSessionArgs(nullptr, {}));
        });
    ON_CALL(*mock_optimization_guide_keyed_service_,
            GetSamplingParamsConfig(
                optimization_guide::ModelBasedCapabilityKey::kPromptApi))
        .WillByDefault([]() {
          return optimization_guide::SamplingParamsConfig{
              .default_top_k = kTestDefaultTopK,
              .default_temperature = kTestDefaultTemperature};
        });
    ON_CALL(*mock_optimization_guide_keyed_service_,
            GetFeatureMetadata(
                optimization_guide::ModelBasedCapabilityKey::kPromptApi))
        .WillByDefault([]() { return CreateConfig().feature_metadata(); });
    ON_CALL(*mock_optimization_guide_keyed_service_, GetOnDeviceCapabilities())
        .WillByDefault(Return(on_device_model::Capabilities(
            {on_device_model::CapabilityFlags::kImageInput,
             on_device_model::CapabilityFlags::kAudioInput})));
    ON_CALL(*mock_optimization_guide_keyed_service_,
            GetOnDeviceModelEligibility(_))
        .WillByDefault(Return(
            optimization_guide::OnDeviceModelEligibilityReason::kSuccess));
  }

  mojo::Remote<blink::mojom::AILanguageModel> CreateSession(
      blink::mojom::AILanguageModelCreateOptionsPtr options =
          blink::mojom::AILanguageModelCreateOptions::New()) {
    base::test::TestFuture<mojo::Remote<blink::mojom::AILanguageModel>> future;
    AITestUtils::MockCreateLanguageModelClient language_model_client;
    EXPECT_CALL(language_model_client, OnResult(_, _))
        .WillOnce([&](mojo::PendingRemote<blink::mojom::AILanguageModel>
                          language_model,
                      blink::mojom::AILanguageModelInstanceInfoPtr info) {
          future.SetValue(mojo::Remote<blink::mojom::AILanguageModel>(
              std::move(language_model)));
        });

    GetAIManagerRemote()->CreateLanguageModel(
        language_model_client.BindNewPipeAndPassRemote(), std::move(options));
    return future.Take();
  }

  std::vector<std::string> Prompt(
      blink::mojom::AILanguageModel& model,
      std::vector<blink::mojom::AILanguageModelPromptPtr> input,
      on_device_model::mojom::ResponseConstraintPtr constraint = nullptr) {
    TestStreamingResponder responder;
    model.Prompt(std::move(input), std::move(constraint),
                 responder.BindRemote());
    EXPECT_TRUE(responder.WaitForCompletion());
    return responder.responses();
  }

  void Append(blink::mojom::AILanguageModel& model,
              std::vector<blink::mojom::AILanguageModelPromptPtr> input) {
    TestStreamingResponder responder;
    model.Append(std::move(input), responder.BindRemote());
    EXPECT_TRUE(responder.WaitForCompletion());
  }

  mojo::Remote<blink::mojom::AILanguageModel> Fork(
      blink::mojom::AILanguageModel& model) {
    base::test::TestFuture<mojo::Remote<blink::mojom::AILanguageModel>> future;
    AITestUtils::MockCreateLanguageModelClient language_model_client;
    EXPECT_CALL(language_model_client, OnResult(_, _))
        .WillOnce([&](mojo::PendingRemote<blink::mojom::AILanguageModel>
                          language_model,
                      blink::mojom::AILanguageModelInstanceInfoPtr info) {
          future.SetValue(mojo::Remote<blink::mojom::AILanguageModel>(
              std::move(language_model)));
        });

    model.Fork(language_model_client.BindNewPipeAndPassRemote());
    return future.Take();
  }

 protected:
  optimization_guide::FakeModelBroker fake_broker_;
  base::test::ScopedFeatureList scoped_feature_list_;
};

// A single prompt streams back the formatted user turn.
TEST_F(AILanguageModelTest, Prompt) {
  mojo::Remote<blink::mojom::AILanguageModel> session = CreateSession();
  const std::vector<std::string> expected = FormatResponses({"UfooEM"});
  EXPECT_THAT(Prompt(*session, MakeInput("foo")), ElementsAreArray(expected));
}

// Successive prompts on one session accumulate history: each response includes
// everything echoed by the earlier prompts (see FormatResponses()).
TEST_F(AILanguageModelTest, MultiplePrompts) {
  auto session = CreateSession();
  EXPECT_THAT(Prompt(*session, MakeInput("foo")),
              ElementsAreArray(FormatResponses({"UfooEM"})));
  EXPECT_THAT(Prompt(*session, MakeInput("bar")),
              ElementsAreArray(FormatResponses({"UfooEM", "UbarEM"})));
  EXPECT_THAT(
      Prompt(*session, MakeInput("baz")),
      ElementsAreArray(FormatResponses({"UfooEM", "UbarEM", "UbazEM"})));
}

// Multiple text contents in one user prompt are concatenated into one turn.
TEST_F(AILanguageModelTest, PromptMultipleContents) {
  auto session = CreateSession();
  const std::vector<std::string> responses =
      Prompt(*session, MakeInput({"foo", "bar"}));
  EXPECT_THAT(responses, ElementsAreArray(FormatResponses({"UfoobarEM"})));
}

// Appended input becomes history (terminated with "E", without a model turn)
// and shows up before the next prompt.
TEST_F(AILanguageModelTest, Append) {
  auto session = CreateSession();
  Append(*session, MakeInput("foo"));
  const std::vector<std::string> responses = Prompt(*session, MakeInput("bar"));
  EXPECT_THAT(responses, ElementsAre("UfooE", "UbarEM"));
}

// Multiple text contents in one appended prompt are concatenated into one turn.
TEST_F(AILanguageModelTest, AppendMultipleContents) {
  auto session = CreateSession();
  Append(*session, MakeInput({"foo", "bar"}));
  const std::vector<std::string> responses = Prompt(*session, MakeInput("baz"));
  EXPECT_THAT(responses, ElementsAre("UfoobarE", "UbazEM"));
}

// current_tokens reported on completion tracks the accumulated session
// transcript across prompts, and a fork continues from its parent's count.
// The expectations compare against string lengths, i.e. the fake appears to
// count one token per character.
TEST_F(AILanguageModelTest, PromptTokenCounts) {
  // Fixed output so each turn adds a predictable "UxxxEMhiE".
  fake_broker_.settings().set_execute_result({"hi"});
  auto session = CreateSession();

  std::string expected_tokens = "UfooEMhiE";
  {
    TestStreamingResponder responder;
    session->Prompt(MakeInput("foo"), nullptr, responder.BindRemote());
    EXPECT_TRUE(responder.WaitForCompletion());
    EXPECT_EQ(responder.current_tokens(), expected_tokens.size());
  }
  expected_tokens += "UbarEMhiE";
  {
    TestStreamingResponder responder;
    session->Prompt(MakeInput("bar"), nullptr, responder.BindRemote());
    EXPECT_TRUE(responder.WaitForCompletion());
    EXPECT_EQ(responder.current_tokens(), expected_tokens.size());
  }
  auto fork = Fork(*session);
  // The fork inherits the parent's history, so its count keeps accumulating.
  expected_tokens += "UbazEMhiE";
  {
    TestStreamingResponder responder;
    fork->Prompt(MakeInput("baz"), nullptr, responder.BindRemote());
    EXPECT_TRUE(responder.WaitForCompletion());
    EXPECT_EQ(responder.current_tokens(), expected_tokens.size());
  }
}

// Same as PromptTokenCounts, but for Append(): each append adds "UxxxE" (no
// model turn), and a fork continues from its parent's count.
TEST_F(AILanguageModelTest, AppendTokenCounts) {
  auto session = CreateSession();

  std::string expected_tokens = "UfooE";
  {
    TestStreamingResponder responder;
    session->Append(MakeInput("foo"), responder.BindRemote());
    EXPECT_TRUE(responder.WaitForCompletion());
    EXPECT_EQ(responder.current_tokens(), expected_tokens.size());
  }
  expected_tokens += "UbarE";
  {
    TestStreamingResponder responder;
    session->Append(MakeInput("bar"), responder.BindRemote());
    EXPECT_TRUE(responder.WaitForCompletion());
    EXPECT_EQ(responder.current_tokens(), expected_tokens.size());
  }
  auto fork = Fork(*session);
  expected_tokens += "UbazE";
  {
    TestStreamingResponder responder;
    fork->Append(MakeInput("baz"), responder.BindRemote());
    EXPECT_TRUE(responder.WaitForCompletion());
    EXPECT_EQ(responder.current_tokens(), expected_tokens.size());
  }
}

// Each prompt in a multi-role input appears with its role marker (U/S/M) and
// "E" terminator, in the order given.
TEST_F(AILanguageModelTest, Roles) {
  std::vector<blink::mojom::AILanguageModelPromptPtr> input;
  input.push_back(MakePrompt(Role::kUser, "user"));
  input.push_back(MakePrompt(Role::kSystem, "system"));
  input.push_back(MakePrompt(Role::kAssistant, "model"));

  auto session = CreateSession();
  const std::vector<std::string> expected =
      FormatResponses({"UuserESsystemEMmodelEM"});
  EXPECT_THAT(Prompt(*session, std::move(input)), ElementsAreArray(expected));
}

// A fork snapshots its parent's history at fork time. Later prompts on the
// parent and on the fork do not affect each other, and forks can be forked
// again.
TEST_F(AILanguageModelTest, Fork) {
  auto session = CreateSession();
  // Forked before any prompt: empty history.
  auto fork1 = Fork(*session);

  EXPECT_THAT(Prompt(*session, MakeInput("foo")),
              ElementsAreArray(FormatResponses({"UfooEM"})));
  // Has the "foo" turn.
  auto fork2 = Fork(*session);

  EXPECT_THAT(Prompt(*session, MakeInput("bar")),
              ElementsAreArray(FormatResponses({"UfooEM", "UbarEM"})));
  // Has the "foo" and "bar" turns.
  auto fork3 = Fork(*session);

  EXPECT_THAT(Prompt(*fork1, MakeInput("fork")),
              ElementsAreArray(FormatResponses({"UforkEM"})));
  EXPECT_THAT(Prompt(*fork2, MakeInput("fork")),
              ElementsAreArray(FormatResponses({"UfooEM", "UforkEM"})));
  // Fork of a fork: has fork2's "foo" and "fork" turns.
  auto fork4 = Fork(*fork2);
  EXPECT_THAT(
      Prompt(*fork3, MakeInput("fork")),
      ElementsAreArray(FormatResponses({"UfooEM", "UbarEM", "UforkEM"})));
  // The parent is unaffected by prompts on its forks.
  EXPECT_THAT(
      Prompt(*session, MakeInput("baz")),
      ElementsAreArray(FormatResponses({"UfooEM", "UbarEM", "UbazEM"})));

  EXPECT_THAT(
      Prompt(*fork4, MakeInput("more")),
      ElementsAreArray(FormatResponses({"UfooEM", "UforkEM", "UmoreEM"})));
}

// In-range sampling params passed at creation are used for the session and
// carried over to forks. With these params the fake streams the input
// followed by a "TopK: ..., Temp: ..." chunk.
TEST_F(AILanguageModelTest, SamplingParams) {
  auto sampling_params = blink::mojom::AILanguageModelSamplingParams::New();
  sampling_params->top_k = 2;
  sampling_params->temperature = 1.0;

  auto options = blink::mojom::AILanguageModelCreateOptions::New();
  options->sampling_params = std::move(sampling_params);
  auto session = CreateSession(std::move(options));
  auto fork = Fork(*session);

  EXPECT_THAT(Prompt(*session, MakeInput("foo")),
              ElementsAre("UfooEM", "TopK: 2, Temp: 1"));
  EXPECT_THAT(Prompt(*fork, MakeInput("bar")),
              ElementsAre("UbarEM", "TopK: 2, Temp: 1"));
}

// A top_k of 0 is expected to be replaced by 1 (which equals kTestDefaultTopK).
// The in-range temperature is kept as given.
TEST_F(AILanguageModelTest, SamplingParamsTopKOutOfRange) {
  auto sampling_params = blink::mojom::AILanguageModelSamplingParams::New();
  sampling_params->top_k = 0;
  sampling_params->temperature = 1.5f;

  auto options = blink::mojom::AILanguageModelCreateOptions::New();
  options->sampling_params = std::move(sampling_params);
  auto session = CreateSession(std::move(options));

  EXPECT_THAT(Prompt(*session, MakeInput("foo")),
              ElementsAre("UfooEM", "TopK: 1, Temp: 1.5"));
}

// A negative temperature is expected to be replaced by 0 (which equals
// kTestDefaultTemperature). The in-range top_k is kept as given.
TEST_F(AILanguageModelTest, SamplingParamsTemperatureOutOfRange) {
  auto sampling_params = blink::mojom::AILanguageModelSamplingParams::New();
  sampling_params->top_k = 2;
  sampling_params->temperature = -1.0f;

  auto options = blink::mojom::AILanguageModelCreateOptions::New();
  options->sampling_params = std::move(sampling_params);
  auto session = CreateSession(std::move(options));

  EXPECT_THAT(Prompt(*session, MakeInput("foo")),
              ElementsAre("UfooEM", "TopK: 2, Temp: 0"));
}

// Sampling params above the configured maximums are clamped to kTestMaxTopK
// (5) and kTestMaxTemperature (1.5).
TEST_F(AILanguageModelTest, MaxSamplingParams) {
  auto sampling_params = blink::mojom::AILanguageModelSamplingParams::New();
  sampling_params->top_k = kTestMaxTopK + 1;
  sampling_params->temperature = kTestMaxTemperature + 1;

  auto options = blink::mojom::AILanguageModelCreateOptions::New();
  options->sampling_params = std::move(sampling_params);
  auto session = CreateSession(std::move(options));

  EXPECT_THAT(Prompt(*session, MakeInput("foo")),
              ElementsAre("UfooEM", "TopK: 5, Temp: 1.5"));
}

// Initial prompts of mixed roles are replayed, as one chunk, ahead of the
// first prompt.
TEST_F(AILanguageModelTest, InitialPrompts) {
  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  auto& initial = create_options->initial_prompts;
  initial.push_back(MakePrompt(Role::kSystem, "hi"));
  initial.push_back(MakePrompt(Role::kUser, "bye"));
  auto session = CreateSession(std::move(create_options));

  const std::vector<std::string> responses = Prompt(*session, MakeInput("foo"));
  EXPECT_THAT(responses, ElementsAre("ShiEUbyeE", "UfooEM"));
}

// An initial prompt with multiple text contents is concatenated into one turn.
TEST_F(AILanguageModelTest, InitialPromptsMultipleContents) {
  auto options = blink::mojom::AILanguageModelCreateOptions::New();
  options->initial_prompts = MakeInput({"foo", "bar"});
  auto session = CreateSession(std::move(options));

  EXPECT_THAT(Prompt(*session, MakeInput("baz")),
              ElementsAre("UfoobarE", "UbazEM"));
}

// Creating a session with an initial system prompt reports the context quota
// and the usage already consumed by that prompt.
TEST_F(AILanguageModelTest, InitialPromptsInstanceInfo) {
  base::test::TestFuture<blink::mojom::AILanguageModelInstanceInfoPtr> future;
  AITestUtils::MockCreateLanguageModelClient language_model_client;
  EXPECT_CALL(language_model_client, OnResult(_, _))
      .WillOnce(
          [&](mojo::PendingRemote<blink::mojom::AILanguageModel> language_model,
              blink::mojom::AILanguageModelInstanceInfoPtr info) {
            future.SetValue(std::move(info));
          });

  auto options = blink::mojom::AILanguageModelCreateOptions::New();
  options->initial_prompts.push_back(MakePrompt(Role::kSystem, "hi"));
  GetAIManagerRemote()->CreateLanguageModel(
      language_model_client.BindNewPipeAndPassRemote(), std::move(options));
  auto info = future.Take();
  // Fail cleanly instead of dereferencing a null InstanceInfoPtr.
  ASSERT_TRUE(info);
  EXPECT_EQ(info->input_quota, kTestMaxTokens);
  // The fake appears to count one token per character of the formatted
  // initial prompt "ShiE".
  EXPECT_EQ(info->input_usage, std::strlen("ShiE"));
}

// An initial prompt longer than the session quota fails creation with
// kInitialInputTooLarge. The quota error reports a request larger than
// kTestMaxTokens against a quota of kTestMaxTokens.
TEST_F(AILanguageModelTest, InitialPromptsTooLarge) {
  base::test::TestFuture<blink::mojom::AIManagerCreateClientError> error_future;
  base::test::TestFuture<blink::mojom::QuotaErrorInfoPtr>
      quota_error_info_future;
  AITestUtils::MockCreateLanguageModelClient language_model_client;
  EXPECT_CALL(language_model_client, OnError(_, _))
      .WillOnce([&](auto error, auto quota_error_info) {
        error_future.SetValue(error);
        quota_error_info_future.SetValue(std::move(quota_error_info));
      });

  auto options = blink::mojom::AILanguageModelCreateOptions::New();
  options->initial_prompts.push_back(
      MakePrompt(Role::kSystem, std::string(kTestMaxTokens + 1, 'a')));

  GetAIManagerRemote()->CreateLanguageModel(
      language_model_client.BindNewPipeAndPassRemote(), std::move(options));
  EXPECT_EQ(error_future.Take(),
            blink::mojom::AIManagerCreateClientError::kInitialInputTooLarge);
  auto quota_error_info = quota_error_info_future.Take();
  ASSERT_TRUE(quota_error_info);
  ASSERT_GT(quota_error_info->requested, kTestMaxTokens);
  ASSERT_EQ(quota_error_info->quota, kTestMaxTokens);
}

// A single prompt larger than the session quota fails with
// kErrorInputTooLarge and reports the requested size and the quota.
TEST_F(AILanguageModelTest, InputTooLarge) {
  auto session = CreateSession();

  TestStreamingResponder responder;
  session->Prompt(MakeInput(std::string(kTestMaxTokens + 1, 'a')), nullptr,
                  responder.BindRemote());
  EXPECT_FALSE(responder.WaitForCompletion());
  EXPECT_EQ(responder.error_status(),
            blink::mojom::ModelStreamingResponseStatus::kErrorInputTooLarge);
  ASSERT_GT(responder.quota_error_info().requested, kTestMaxTokens);
  ASSERT_EQ(responder.quota_error_info().quota, kTestMaxTokens);
}

// When prompt input overflows the quota, older non-initial history is evicted
// while the initial prompts are kept.
TEST_F(AILanguageModelTest, QuotaOverflowOnPromptInput) {
  // Set the execute result so the long prompt is not echoed back as the
  // response.
  fake_broker_.settings().set_execute_result({"hi"});
  // Initial prompt should be kept on overflow.
  auto options = blink::mojom::AILanguageModelCreateOptions::New();
  options->initial_prompts.push_back(MakePrompt(Role::kSystem, "init"));
  auto session = CreateSession(std::move(options));
  // Send a prompt that is close to the max token length. The final expectation
  // below checks that this prompt has been evicted from the history, while the
  // initial prompt, the "foo" turn and `long_prompt` remain.
  EXPECT_THAT(
      Prompt(*session, MakeInput(std::string(kTestMaxTokens - 20, 'a'))),
      ElementsAre("hi"));
  EXPECT_THAT(Prompt(*session, MakeInput("foo")), ElementsAre("hi"));

  // Clear execute result so we can verify the input by checking the response.
  fake_broker_.settings().set_execute_result({});
  std::string long_prompt(kTestMaxTokens / 3, 'a');
  TestStreamingResponder responder;
  session->Prompt(MakeInput(long_prompt), nullptr, responder.BindRemote());
  responder.WaitForQuotaOverflow();
  EXPECT_TRUE(responder.WaitForCompletion());
  // Response should include input/output of previous prompt with the original
  // long prompt not present.
  EXPECT_THAT(responder.responses(),
              ElementsAre("SinitE", "UfooEMhiE", "U" + long_prompt + "EM"));
}

// When an Append() overflows the quota, the earlier appended input is evicted.
// The initial prompt and the newly appended input are kept.
TEST_F(AILanguageModelTest, QuotaOverflowOnAppend) {
  // Initial prompt should be kept on overflow.
  auto options = blink::mojom::AILanguageModelCreateOptions::New();
  options->initial_prompts.push_back(MakePrompt(Role::kSystem, "init"));
  auto session = CreateSession(std::move(options));
  // Set a prompt that is close to max token length.
  Append(*session, MakeInput(std::string(kTestMaxTokens - 20, 'a')));

  std::string long_prompt(kTestMaxTokens / 3, 'a');
  TestStreamingResponder responder;
  session->Append(MakeInput(long_prompt), responder.BindRemote());
  responder.WaitForQuotaOverflow();
  EXPECT_TRUE(responder.WaitForCompletion());

  // The first appended input is gone; the initial prompt and `long_prompt`
  // remain ahead of the new prompt.
  EXPECT_THAT(Prompt(*session, MakeInput("foo")),
              ElementsAre("SinitE", "U" + long_prompt + "E", "UfooEM"));
}

// When a model response overflows the quota, the earlier history is evicted
// while the current turn (input and output) is kept.
TEST_F(AILanguageModelTest, QuotaOverflowOnOutput) {
  // Set the execute result so the long prompt is not echoed back as the
  // response.
  fake_broker_.settings().set_execute_result({"hi"});
  auto session = CreateSession();
  // Set a prompt that is close to max token length. This string should be
  // stripped from the prompt history, while the next prompt's input and output
  // will be kept.
  EXPECT_THAT(
      Prompt(*session, MakeInput(std::string(kTestMaxTokens - 20, 'a'))),
      ElementsAre("hi"));

  // Reset result to a long response that should cause overflow. `long_response`
  // should be kept, but the previous prompt will be removed.
  std::string long_response(kTestMaxTokens / 3, 'a');
  fake_broker_.settings().set_execute_result({long_response});

  TestStreamingResponder responder;
  session->Prompt(MakeInput("foo"), nullptr, responder.BindRemote());
  responder.WaitForQuotaOverflow();
  EXPECT_TRUE(responder.WaitForCompletion());
  EXPECT_THAT(responder.responses(), ElementsAre(long_response));

  // Verify the original long prompt was removed. The response should contain:
  // - the "foo" turn and its `long_response` output from the previous call
  // - the "bar" input from the current call
  fake_broker_.settings().set_execute_result({});
  EXPECT_THAT(Prompt(*session, MakeInput("bar")),
              ElementsAre("UfooEM" + long_response + "E", "UbarEM"));
}

// A response as large as kTestModelMaxTokens fails the prompt with
// kErrorGenericFailure, and the failed turn is not kept in the history.
TEST_F(AILanguageModelTest, OutputOverflowsModelMaxTokens) {
  auto session = CreateSession();
  // Add a prompt to start, this should be kept after the overflow.
  EXPECT_THAT(Prompt(*session, MakeInput("foo")),
              ElementsAreArray(FormatResponses({"UfooEM"})));

  // Set a fake response that will overrun the max model tokens.
  fake_broker_.settings().set_execute_result(
      {std::string(kTestModelMaxTokens, 'a')});
  TestStreamingResponder responder;
  session->Prompt(MakeInput("bar"), nullptr, responder.BindRemote());
  EXPECT_FALSE(responder.WaitForCompletion());
  EXPECT_EQ(responder.error_status(),
            blink::mojom::ModelStreamingResponseStatus::kErrorGenericFailure);

  // Now prompt again, the failed prompt should not be present.
  fake_broker_.settings().set_execute_result({});
  EXPECT_THAT(Prompt(*session, MakeInput("baz")),
              ElementsAreArray(FormatResponses({"UfooEM", "UbazEM"})));
}

// With "ai_language_model_output_buffer" lowered to 10, a response that pushes
// a nearly full context past that buffer fails with kErrorGenericFailure.
TEST_F(AILanguageModelTest, OutputOverflowsAdditionalBuffer) {
  // Layered on top of the fixture's feature list; overrides only this param.
  base::test::ScopedFeatureList scoped_feature_list;
  // Use a smaller output buffer to test the value is used correctly.
  scoped_feature_list.InitWithFeaturesAndParameters(
      {{features::kAILanguageModelOverrideConfiguration,
        {{"ai_language_model_output_buffer", "10"}}}},
      {});
  auto session = CreateSession();
  // Append an input that is just below max tokens, the next output should
  // overflow the buffer and cause an error.
  Append(*session, MakeInput(std::string(kTestMaxTokens - 5, 'a')));

  // Create a response that will be just larger than the output buffer.
  fake_broker_.settings().set_execute_result({std::string(15, 'a')});
  TestStreamingResponder responder;
  session->Prompt(MakeInput(""), nullptr, responder.BindRemote());
  EXPECT_FALSE(responder.WaitForCompletion());
  EXPECT_EQ(responder.error_status(),
            blink::mojom::ModelStreamingResponseStatus::kErrorGenericFailure);
}

// A response as large as the whole session quota (kTestMaxTokens) fails the
// prompt with kErrorGenericFailure, and the failed turn is not kept.
TEST_F(AILanguageModelTest, OutputOverflowsContextMaxTokens) {
  auto session = CreateSession();
  // Add a prompt to start, this should be kept after the overflow.
  EXPECT_THAT(Prompt(*session, MakeInput("foo")),
              ElementsAreArray(FormatResponses({"UfooEM"})));

  // Set a fake response that will overflow the maximum context size.
  fake_broker_.settings().set_execute_result(
      {std::string(kTestMaxTokens, 'a')});
  TestStreamingResponder responder;
  session->Prompt(MakeInput("bar"), nullptr, responder.BindRemote());
  EXPECT_FALSE(responder.WaitForCompletion());
  EXPECT_EQ(responder.error_status(),
            blink::mojom::ModelStreamingResponseStatus::kErrorGenericFailure);

  // Now prompt again, the failed prompt should not be present.
  fake_broker_.settings().set_execute_result({});
  EXPECT_THAT(Prompt(*session, MakeInput("baz")),
              ElementsAreArray(FormatResponses({"UfooEM", "UbazEM"})));
}

// Destroy() disconnects the session remote; the test waits for the disconnect
// handler to run.
TEST_F(AILanguageModelTest, Destroy) {
  auto session = CreateSession();
  base::RunLoop run_loop;
  session.set_disconnect_handler(run_loop.QuitClosure());
  EXPECT_THAT(Prompt(*session, MakeInput("foo")),
              ElementsAreArray(FormatResponses({"UfooEM"})));
  session->Destroy();
  run_loop.Run();
}

// Destroying a session while a prompt is still executing disconnects the
// session and ends the pending prompt with OnError() (WaitForCompletion()
// returns false).
TEST_F(AILanguageModelTest, DestroyWithActivePrompt) {
  // Long delay so the prompt is still in flight when Destroy() is called.
  fake_broker_.settings().set_execute_delay(base::Minutes(1));
  auto session = CreateSession();
  base::RunLoop run_loop;
  session.set_disconnect_handler(run_loop.QuitClosure());

  TestStreamingResponder responder;
  session->Prompt(MakeInput("foo"), nullptr, responder.BindRemote());
  session->Destroy();
  run_loop.Run();

  EXPECT_FALSE(responder.WaitForCompletion());
}

// One case of `AILanguageModelTestWithLanguageParams`.
struct LanguageParams {
  // Value given to the `langs` parameter of `kAIPromptAPI`, e.g. "en,es,ja",
  // "*" or "".
  std::string enabled_languages;
  // Language codes requested through the session's expected inputs.
  std::vector<std::string> expected_input_language;
  // Whether creation should fail with `kUnsupportedLanguage`.
  bool expect_error = false;
};

// Prints a `LanguageParams` so gtest can describe the failing case of the
// parameterized test. The language list is bracketed so its ", " separator
// cannot be confused with the separator between fields, and the bool is
// spelled out without changing the stream's `boolalpha` state.
std::ostream& operator<<(std::ostream& os, const LanguageParams& params) {
  return os << "enabled_languages:" << params.enabled_languages
            << ", expected_input_language:["
            << base::JoinString(params.expected_input_language, ", ") << "]"
            << ", expect_error:" << (params.expect_error ? "true" : "false");
}

// Parameterized fixture: enables `kAIPromptAPI` with its `langs` parameter
// taken from the current `LanguageParams`.
class AILanguageModelTestWithLanguageParams
    : public AILanguageModelTest,
      public testing::WithParamInterface<LanguageParams> {
 public:
  AILanguageModelTestWithLanguageParams() {
    const LanguageParams& params = GetParam();
    features_.InitWithFeaturesAndParameters(
        /*enabled_features=*/{{blink::features::kAIPromptAPI,
                               {{"langs", params.enabled_languages}}}},
        /*disabled_features=*/{});
  }

  base::test::ScopedFeatureList features_;
};

TEST_P(AILanguageModelTestWithLanguageParams, PromptWithEnabledLanguages) {
  const LanguageParams& param = GetParam();

  base::test::TestFuture<blink::mojom::AIManagerCreateClientError> error_future;
  base::test::TestFuture<mojo::Remote<blink::mojom::AILanguageModel>>
      result_future;

  // Exactly one of OnResult()/OnError() must fire, depending on the case.
  AITestUtils::MockCreateLanguageModelClient client;
  if (param.expect_error) {
    EXPECT_CALL(client, OnResult(_, _)).Times(0);
    EXPECT_CALL(client, OnError(_, _))
        .WillOnce([&](auto error, auto quota_error_info) {
          error_future.SetValue(error);
        });
  } else {
    EXPECT_CALL(client, OnResult(_, _))
        .WillOnce([&](mojo::PendingRemote<blink::mojom::AILanguageModel>
                          remote,
                      blink::mojom::AILanguageModelInstanceInfoPtr info) {
          result_future.SetValue(
              mojo::Remote<blink::mojom::AILanguageModel>(std::move(remote)));
        });
    EXPECT_CALL(client, OnError(_, _)).Times(0);
  }

  // Ask for a session whose expected inputs use the case's languages.
  auto expected_text = blink::mojom::AILanguageModelExpected::New();
  expected_text->languages.emplace();
  for (const std::string& code : param.expected_input_language) {
    expected_text->languages->push_back(
        blink::mojom::AILanguageCode::New(code));
  }
  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->expected_inputs.emplace();
  create_options->expected_inputs->push_back(std::move(expected_text));
  GetAIManagerRemote()->CreateLanguageModel(client.BindNewPipeAndPassRemote(),
                                            std::move(create_options));

  if (!param.expect_error) {
    EXPECT_TRUE(result_future.Wait());
    return;
  }
  EXPECT_EQ(error_future.Take(),
            blink::mojom::AIManagerCreateClientError::kUnsupportedLanguage);
}

// Each case pairs the `langs` feature parameter with the languages requested
// as expected inputs, and whether creation should fail with
// `kUnsupportedLanguage`. Note "*" still rejects "fr", so it presumably means
// "every language the model supports" rather than any language; TODO confirm.
INSTANTIATE_TEST_SUITE_P(
    /* no prefix */,
    AILanguageModelTestWithLanguageParams,
    ::testing::Values(
        LanguageParams{.enabled_languages = "en,es,ja",
                       .expected_input_language = {"en"},
                       .expect_error = false},
        LanguageParams{.enabled_languages = "*",
                       .expected_input_language = {"en"},
                       .expect_error = false},
        LanguageParams{.enabled_languages = "*",
                       .expected_input_language = {"fr"},
                       .expect_error = true},
        LanguageParams{.enabled_languages = "",
                       .expected_input_language = {"en"},
                       .expect_error = true},
        LanguageParams{.enabled_languages = "",
                       .expected_input_language = {"fr"},
                       .expect_error = true},
        LanguageParams{.enabled_languages = "es,ja",
                       .expected_input_language = {"es"},
                       .expect_error = false},
        LanguageParams{.enabled_languages = "en,es,ja",
                       .expected_input_language = {"ja"},
                       .expect_error = false},
        LanguageParams{.enabled_languages = "en,es,ja",
                       .expected_input_language = {"ja", "es"},
                       .expect_error = false},
        LanguageParams{.enabled_languages = "en,es,ja",
                       .expected_input_language = {"ja", "fr"},
                       .expect_error = true}));

TEST_F(AILanguageModelTest, UnsupportedInputCapability) {
  // Report default-constructed (presumably empty) on-device capabilities, so
  // an image input request cannot be satisfied.
  ON_CALL(*mock_optimization_guide_keyed_service_, GetOnDeviceCapabilities())
      .WillByDefault(Return(on_device_model::Capabilities()));

  base::test::TestFuture<blink::mojom::AIManagerCreateClientError> error_future;
  AITestUtils::MockCreateLanguageModelClient client;
  EXPECT_CALL(client, OnError(_, _))
      .WillOnce([&](auto error, auto quota_error_info) {
        error_future.SetValue(error);
      });

  auto image_input = blink::mojom::AILanguageModelExpected::New();
  image_input->type = blink::mojom::AILanguageModelPromptType::kImage;
  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->expected_inputs.emplace();
  create_options->expected_inputs->push_back(std::move(image_input));

  GetAIManagerRemote()->CreateLanguageModel(client.BindNewPipeAndPassRemote(),
                                            std::move(create_options));
  EXPECT_EQ(error_future.Take(),
            blink::mojom::AIManagerCreateClientError::kUnableToCreateSession);
}

TEST_F(AILanguageModelTest, UnsupportedOutputCapability) {
  // Report default-constructed (presumably empty) on-device capabilities, so
  // an image output request cannot be satisfied.
  ON_CALL(*mock_optimization_guide_keyed_service_, GetOnDeviceCapabilities())
      .WillByDefault(Return(on_device_model::Capabilities()));

  base::test::TestFuture<blink::mojom::AIManagerCreateClientError> error_future;
  AITestUtils::MockCreateLanguageModelClient client;
  EXPECT_CALL(client, OnError(_, _))
      .WillOnce([&](auto error, auto quota_error_info) {
        error_future.SetValue(error);
      });

  auto image_output = blink::mojom::AILanguageModelExpected::New();
  image_output->type = blink::mojom::AILanguageModelPromptType::kImage;
  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->expected_outputs.emplace();
  create_options->expected_outputs->push_back(std::move(image_output));

  GetAIManagerRemote()->CreateLanguageModel(client.BindNewPipeAndPassRemote(),
                                            std::move(create_options));
  EXPECT_EQ(error_future.Take(),
            blink::mojom::AIManagerCreateClientError::kUnableToCreateSession);
}

TEST_F(AILanguageModelTest, MultimodalInputImageNotSpecified) {
  // The session declares audio, but not image, as an expected input...
  auto audio_only = blink::mojom::AILanguageModelExpected::New();
  audio_only->type = blink::mojom::AILanguageModelPromptType::kAudio;
  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->expected_inputs.emplace();
  create_options->expected_inputs->push_back(std::move(audio_only));
  auto session = CreateSession(std::move(create_options));

  // ...so every operation given an image must reject it as invalid.
  auto input_with_image = [] {
    std::vector<blink::mojom::AILanguageModelPromptPtr> prompts =
        MakeInput("foo");
    prompts.push_back(blink::mojom::AILanguageModelPrompt::New(
        Role::kUser,
        ToVector(blink::mojom::AILanguageModelPromptContent::NewBitmap(
            CreateTestBitmap(10, 10))),
        /*is_prefix=*/false));
    return prompts;
  };

  {
    TestStreamingResponder prompt_responder;
    session->Prompt(input_with_image(), nullptr,
                    prompt_responder.BindRemote());
    EXPECT_FALSE(prompt_responder.WaitForCompletion());
    EXPECT_EQ(prompt_responder.error_status(),
              blink::mojom::ModelStreamingResponseStatus::kErrorInvalidRequest);
  }
  {
    TestStreamingResponder append_responder;
    session->Append(input_with_image(), append_responder.BindRemote());
    EXPECT_FALSE(append_responder.WaitForCompletion());
    EXPECT_EQ(append_responder.error_status(),
              blink::mojom::ModelStreamingResponseStatus::kErrorInvalidRequest);
  }

  // Measuring such input yields no value.
  base::test::TestFuture<std::optional<uint32_t>> usage_future;
  session->MeasureInputUsage(input_with_image(), usage_future.GetCallback());
  EXPECT_EQ(usage_future.Get(), std::nullopt);
}

TEST_F(AILanguageModelTest, MultimodalInputAudioNotSpecified) {
  // The session declares image, but not audio, as an expected input...
  auto image_only = blink::mojom::AILanguageModelExpected::New();
  image_only->type = blink::mojom::AILanguageModelPromptType::kImage;
  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->expected_inputs.emplace();
  create_options->expected_inputs->push_back(std::move(image_only));
  auto session = CreateSession(std::move(create_options));

  // ...so every operation given audio must reject it as invalid.
  auto input_with_audio = [] {
    std::vector<blink::mojom::AILanguageModelPromptPtr> prompts =
        MakeInput("foo");
    prompts.push_back(blink::mojom::AILanguageModelPrompt::New(
        Role::kUser,
        ToVector(blink::mojom::AILanguageModelPromptContent::NewAudio(
            CreateTestAudio())),
        /*is_prefix=*/false));
    return prompts;
  };

  {
    TestStreamingResponder prompt_responder;
    session->Prompt(input_with_audio(), nullptr,
                    prompt_responder.BindRemote());
    EXPECT_FALSE(prompt_responder.WaitForCompletion());
    EXPECT_EQ(prompt_responder.error_status(),
              blink::mojom::ModelStreamingResponseStatus::kErrorInvalidRequest);
  }
  {
    TestStreamingResponder append_responder;
    session->Append(input_with_audio(), append_responder.BindRemote());
    EXPECT_FALSE(append_responder.WaitForCompletion());
    EXPECT_EQ(append_responder.error_status(),
              blink::mojom::ModelStreamingResponseStatus::kErrorInvalidRequest);
  }

  // Measuring such input yields no value.
  base::test::TestFuture<std::optional<uint32_t>> usage_future;
  session->MeasureInputUsage(input_with_audio(), usage_future.GetCallback());
  EXPECT_EQ(usage_future.Get(), std::nullopt);
}

TEST_F(AILanguageModelTest, MultimodalInput) {
  // Declare audio and then image as expected inputs.
  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->expected_inputs.emplace();
  for (auto prompt_type : {blink::mojom::AILanguageModelPromptType::kAudio,
                           blink::mojom::AILanguageModelPromptType::kImage}) {
    auto expected = blink::mojom::AILanguageModelExpected::New();
    expected->type = prompt_type;
    create_options->expected_inputs->push_back(std::move(expected));
  }
  auto session = CreateSession(std::move(create_options));

  // Text, then an image, then audio, each in its own user turn.
  std::vector<blink::mojom::AILanguageModelPromptPtr> prompts =
      MakeInput("foo");
  prompts.push_back(blink::mojom::AILanguageModelPrompt::New(
      Role::kUser,
      ToVector(blink::mojom::AILanguageModelPromptContent::NewBitmap(
          CreateTestBitmap(10, 10))),
      /*is_prefix=*/false));
  prompts.push_back(blink::mojom::AILanguageModelPrompt::New(
      Role::kUser,
      ToVector(blink::mojom::AILanguageModelPromptContent::NewAudio(
          CreateTestAudio())),
      /*is_prefix=*/false));

  // The fake output shows placeholders where the non-text content was.
  EXPECT_THAT(Prompt(*session, std::move(prompts)),
              ElementsAreArray(FormatResponses({"UfooEU<image>EU<audio>EM"})));
}

TEST_F(AILanguageModelTest, ModelDownload) {
  // Download progress of the on-device model component should reach the
  // observers registered with the `AIManager`.
  std::string model_component_id =
      component_updater::OptimizationGuideOnDeviceModelInstallerPolicy::
          GetOnDeviceModelExtensionId();
  AITestUtils::FakeComponent model_component(model_component_id,
                                             kTestModelDownloadSize);

  EXPECT_EQ(GetAIManagerDownloadProgressObserversSize(), 0u);
  AITestUtils::FakeMonitor progress_monitor;

  // Describe the component as new, with 0 bytes downloaded, when queried.
  EXPECT_CALL(component_update_service_, GetComponentDetails(_, _))
      .WillOnce(
          [&](const std::string& id, component_updater::CrxUpdateItem* item) {
            *item = model_component.CreateUpdateItem(
                update_client::ComponentState::kNew, 0);
            return true;
          });
  GetAIManagerRemote()->AddModelDownloadProgressObserver(
      progress_monitor.BindNewPipeAndPassRemote());

  // Registration goes through the remote; wait for it to land.
  ASSERT_TRUE(base::test::RunUntil(
      [this] { return GetAIManagerDownloadProgressObserversSize() == 1u; }));

  const auto downloading_item = model_component.CreateUpdateItem(
      update_client::ComponentState::kDownloading, kTestModelDownloadSize);
  component_update_service_.SendUpdate(downloading_item);
  progress_monitor.ExpectReceivedNormalizedUpdate(0, kTestModelDownloadSize);
}

TEST_F(AILanguageModelTest, MeasureInputUsage) {
  auto session = CreateSession();

  // The fake apparently counts one token per character of the formatted input.
  const std::string formatted_input("UfooEM");
  base::test::TestFuture<std::optional<uint32_t>> usage_future;
  session->MeasureInputUsage(MakeInput("foo"), usage_future.GetCallback());
  EXPECT_EQ(usage_future.Get(), formatted_input.size());
}

TEST_F(AILanguageModelTest, TextSafetyInitialPrompts) {
  // Make the adaptation require text safety checks.
  auto model_config = CreateConfig();
  model_config.set_can_skip_text_safety(false);
  fake_broker_.UpdateModelAdaptation(
      optimization_guide::FakeAdaptationAsset({.config = model_config}));

  // Install a safety model with a request check over the request text.
  auto safety_config = CreateSafetyConfig();
  safety_config.add_request_check()->mutable_input_template()->Add(
      FieldSubstitution("%s", StringValueField()));
  optimization_guide::FakeSafetyModelAsset safety_asset(
      std::move(safety_config));
  fake_broker_.UpdateSafetyModel(safety_asset.model_info());

  base::test::TestFuture<blink::mojom::AIManagerCreateClientError> error_future;
  AITestUtils::MockCreateLanguageModelClient client;
  EXPECT_CALL(client, OnError(_, _))
      .WillOnce([&](auto error, auto quota_error_info) {
        error_future.SetValue(error);
      });

  // Fake text safety checker looks for the string "unsafe", so an unsafe
  // initial prompt must prevent the session from being created.
  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->initial_prompts.push_back(
      MakePrompt(Role::kSystem, "unsafe"));
  GetAIManagerRemote()->CreateLanguageModel(client.BindNewPipeAndPassRemote(),
                                            std::move(create_options));
  EXPECT_EQ(error_future.Take(),
            blink::mojom::AIManagerCreateClientError::kUnableToCreateSession);
}

TEST_F(AILanguageModelTest, TextSafetyInput) {
  // Make the adaptation require text safety checks.
  auto model_config = CreateConfig();
  model_config.set_can_skip_text_safety(false);
  fake_broker_.UpdateModelAdaptation(
      optimization_guide::FakeAdaptationAsset({.config = model_config}));

  // Install a safety model with a request check over the request text.
  auto safety_config = CreateSafetyConfig();
  safety_config.add_request_check()->mutable_input_template()->Add(
      FieldSubstitution("%s", StringValueField()));
  optimization_guide::FakeSafetyModelAsset safety_asset(
      std::move(safety_config));
  fake_broker_.UpdateSafetyModel(safety_asset.model_info());

  fake_broker_.settings().set_execute_result({"hi"});
  auto session = CreateSession();

  // Safe input passes through to the model.
  EXPECT_THAT(Prompt(*session, MakeInput("safe")), ElementsAre("hi"));

  // Fake text safety checker looks for the string "unsafe".
  TestStreamingResponder filtered_responder;
  session->Prompt(MakeInput("unsafe"), nullptr,
                  filtered_responder.BindRemote());
  EXPECT_FALSE(filtered_responder.WaitForCompletion());
  EXPECT_EQ(filtered_responder.error_status(),
            blink::mojom::ModelStreamingResponseStatus::kErrorFiltered);
}

TEST_F(AILanguageModelTest, TextSafetyOutput) {
  // Make the adaptation require text safety checks.
  auto model_config = CreateConfig();
  model_config.set_can_skip_text_safety(false);
  fake_broker_.UpdateModelAdaptation(
      optimization_guide::FakeAdaptationAsset({.config = model_config}));

  // Check the raw output. The partial-check threshold is far above the
  // response length, so no partial output is released before the final check.
  auto safety_config = CreateSafetyConfig();
  safety_config.mutable_raw_output_check()->mutable_input_template()->Add(
      FieldSubstitution("%s", StringValueField()));
  safety_config.mutable_partial_output_checks()->set_minimum_tokens(1000);
  optimization_guide::FakeSafetyModelAsset safety_asset(
      std::move(safety_config));
  fake_broker_.UpdateSafetyModel(safety_asset.model_info());

  // Fake text safety checker looks for the string "unsafe".
  fake_broker_.settings().set_execute_result(
      {"a", "b", "c", "d", "e", "f", "g", "unsafe", "h"});
  auto session = CreateSession();
  TestStreamingResponder filtered_responder;
  session->Prompt(MakeInput("foo"), nullptr, filtered_responder.BindRemote());
  EXPECT_FALSE(filtered_responder.WaitForCompletion());
  EXPECT_EQ(filtered_responder.error_status(),
            blink::mojom::ModelStreamingResponseStatus::kErrorFiltered);
  EXPECT_TRUE(filtered_responder.responses().empty());
}

TEST_F(AILanguageModelTest, TextSafetyOutputPartial) {
  // Make the adaptation require text safety checks.
  auto model_config = CreateConfig();
  model_config.set_can_skip_text_safety(false);
  fake_broker_.UpdateModelAdaptation(
      optimization_guide::FakeAdaptationAsset({.config = model_config}));

  // Check the output after 3 tokens and then every 2 tokens.
  auto safety_config = CreateSafetyConfig();
  safety_config.mutable_raw_output_check()->mutable_input_template()->Add(
      FieldSubstitution("%s", StringValueField()));
  auto* partial_checks = safety_config.mutable_partial_output_checks();
  partial_checks->set_minimum_tokens(3);
  partial_checks->set_token_interval(2);
  optimization_guide::FakeSafetyModelAsset safety_asset(
      std::move(safety_config));
  fake_broker_.UpdateSafetyModel(safety_asset.model_info());

  // Fake text safety checker looks for the string "unsafe".
  fake_broker_.settings().set_execute_result(
      {"a", "b", "c", "d", "e", "f", "g", "unsafe", "h"});
  auto session = CreateSession();
  TestStreamingResponder filtered_responder;
  session->Prompt(MakeInput("foo"), nullptr, filtered_responder.BindRemote());
  EXPECT_FALSE(filtered_responder.WaitForCompletion());
  EXPECT_EQ(filtered_responder.error_status(),
            blink::mojom::ModelStreamingResponseStatus::kErrorFiltered);

  // Chunks that passed a partial check before "unsafe" were still streamed.
  EXPECT_THAT(filtered_responder.responses(), ElementsAre("abc", "de", "fg"));
}

TEST_F(AILanguageModelTest, QueuesOperations) {
  // Capture the forked session once the fork completes.
  base::test::TestFuture<mojo::Remote<blink::mojom::AILanguageModel>>
      forked_session;
  AITestUtils::MockCreateLanguageModelClient fork_client;
  EXPECT_CALL(fork_client, OnResult(_, _))
      .WillOnce([&](auto language_model, auto info) {
        forked_session.SetValue(mojo::Remote<blink::mojom::AILanguageModel>(
            std::move(language_model)));
      });

  auto session = CreateSession();
  TestStreamingResponder foo_responder;
  TestStreamingResponder bar_responder;
  TestStreamingResponder baz_responder;

  // Issue everything back to back without waiting; every operation must
  // still succeed and observe the ones issued before it.
  session->Prompt(MakeInput("foo"), nullptr, foo_responder.BindRemote());
  session->Prompt(MakeInput("bar"), nullptr, bar_responder.BindRemote());
  session->Fork(fork_client.BindNewPipeAndPassRemote());
  session->Prompt(MakeInput("baz"), nullptr, baz_responder.BindRemote());

  EXPECT_TRUE(foo_responder.WaitForCompletion());
  EXPECT_THAT(foo_responder.responses(),
              ElementsAreArray(FormatResponses({"UfooEM"})));

  EXPECT_TRUE(bar_responder.WaitForCompletion());
  EXPECT_THAT(bar_responder.responses(),
              ElementsAreArray(FormatResponses({"UfooEM", "UbarEM"})));

  EXPECT_TRUE(baz_responder.WaitForCompletion());
  EXPECT_THAT(
      baz_responder.responses(),
      ElementsAreArray(FormatResponses({"UfooEM", "UbarEM", "UbazEM"})));

  // The fork was taken after "bar" and before "baz".
  EXPECT_THAT(
      Prompt(*forked_session.Take(), MakeInput("fork")),
      ElementsAreArray(FormatResponses({"UfooEM", "UbarEM", "UforkEM"})));
}

TEST_F(AILanguageModelTest, Constraint) {
  auto session = CreateSession();
  // The fake reports the constraint it received ahead of the echoed input.
  auto regex_constraint =
      on_device_model::mojom::ResponseConstraint::NewRegex("reg");
  EXPECT_THAT(Prompt(*session, MakeInput("foo"), std::move(regex_constraint)),
              ElementsAre("Constraint: regex reg", "UfooEM"));
}

TEST_F(AILanguageModelTest, Prefix) {
  auto session = CreateSession();

  // A user turn followed by the beginning of the assistant's reply.
  std::vector<blink::mojom::AILanguageModelPromptPtr> turns;
  turns.push_back(MakePrompt(Role::kUser, "foo"));
  turns.push_back(MakePrompt(Role::kAssistant, "bar", /*is_prefix=*/true));

  // The prefix is left open: no end token after "bar" and no separate
  // model-response start token.
  EXPECT_THAT(Prompt(*session, std::move(turns)), ElementsAre("UfooEMbar"));
}

TEST_F(AILanguageModelTest, ServiceCrash) {
  auto session = CreateSession();

  // A crash while a prompt is in flight fails that prompt.
  TestStreamingResponder crashed_responder;
  session->Prompt(MakeInput("bar"), nullptr, crashed_responder.BindRemote());
  fake_broker_.CrashService();
  EXPECT_FALSE(crashed_responder.WaitForCompletion());
  EXPECT_EQ(crashed_responder.error_status(),
            blink::mojom::ModelStreamingResponseStatus::kErrorGenericFailure);

  // A session created after the crash works normally.
  session = CreateSession();
  EXPECT_THAT(Prompt(*session, MakeInput("foo")),
              ElementsAreArray(FormatResponses({"UfooEM"})));
}

TEST_F(AILanguageModelTest, CrashRecovery) {
  auto session = CreateSession();
  Append(*session, MakeInput("foo"));
  fake_broker_.CrashService();

  // The appended turn is replayed into the fresh service before the prompt.
  EXPECT_THAT(Prompt(*session, MakeInput("bar")),
              ElementsAre("UfooE", "UbarEM"));
}

TEST_F(AILanguageModelTest, CrashRecoveryWithMultipleCrashes) {
  auto session = CreateSession();

  // Each appended turn is followed by a crash of the service.
  for (const char* turn : {"foo", "bar"}) {
    Append(*session, MakeInput(turn));
    fake_broker_.CrashService();
  }

  // Both appended turns are replayed together before the new prompt.
  EXPECT_THAT(Prompt(*session, MakeInput("baz")),
              ElementsAre("UfooEUbarE", "UbazEM"));
}

TEST_F(AILanguageModelTest, CrashRecoveryWithInitialPrompts) {
  auto create_options = blink::mojom::AILanguageModelCreateOptions::New();
  create_options->initial_prompts.push_back(MakePrompt(Role::kSystem, "hi"));
  auto session = CreateSession(std::move(create_options));
  Append(*session, MakeInput("foo"));

  fake_broker_.CrashService();

  // Recovery replays the system prompt, then the appended turn.
  EXPECT_THAT(Prompt(*session, MakeInput("bar")),
              ElementsAre("ShiE", "UfooE", "UbarEM"));
}

TEST_F(AILanguageModelTest, CrashRecoveryMeasureInputUsage) {
  auto session = CreateSession();
  Append(*session, MakeInput("foo"));
  fake_broker_.CrashService();

  // Measuring still works after a crash, and only the measured input is
  // counted (not the appended history).
  const std::string formatted_input("UfooEM");
  base::test::TestFuture<std::optional<uint32_t>> usage_future;
  session->MeasureInputUsage(MakeInput("foo"), usage_future.GetCallback());
  EXPECT_EQ(usage_future.Get(), formatted_input.size());
}

TEST_F(AILanguageModelTest, CanCreate_WaitsForEligibility) {
  // Hold on to the eligibility callback instead of answering right away.
  base::test::TestFuture<base::OnceCallback<void(
      optimization_guide::OnDeviceModelEligibilityReason)>>
      pending_eligibility;
  EXPECT_CALL(*mock_optimization_guide_keyed_service_,
              GetOnDeviceModelEligibilityAsync(_, _, _))
      .WillOnce([&](auto feature, auto capabilities, auto callback) {
        pending_eligibility.SetValue(std::move(callback));
      });

  base::test::TestFuture<blink::mojom::ModelAvailabilityCheckResult>
      availability;
  GetAIManagerInterface()->CanCreateLanguageModel({},
                                                  availability.GetCallback());

  // No answer may be produced until eligibility is known.
  EXPECT_FALSE(availability.IsReady());

  auto eligibility_reply = pending_eligibility.Take();
  std::move(eligibility_reply)
      .Run(optimization_guide::OnDeviceModelEligibilityReason::kSuccess);
  EXPECT_EQ(availability.Get(),
            blink::mojom::ModelAvailabilityCheckResult::kAvailable);
}

TEST_F(AILanguageModelTest, CanCreate_SupportedLanguages) {
  base::MockCallback<AIManager::CanCreateLanguageModelCallback> callback;
  EXPECT_CALL(callback,
              Run(blink::mojom::ModelAvailabilityCheckResult::kAvailable));

  // English text both in and out is supported.
  auto make_english_text = [] {
    return blink::mojom::AILanguageModelExpected::New(
        blink::mojom::AILanguageModelPromptType::kText,
        AITestUtils::ToMojoLanguageCodes({"en"}));
  };
  auto options = blink::mojom::AILanguageModelCreateOptions::New();
  options->expected_inputs.emplace();
  options->expected_inputs->push_back(make_english_text());
  options->expected_outputs.emplace();
  options->expected_outputs->push_back(make_english_text());

  GetAIManagerInterface()->CanCreateLanguageModel(std::move(options),
                                                  callback.Get());
}

TEST_F(AILanguageModelTest, CanCreate_UnsupportedInputLanguages) {
  base::MockCallback<AIManager::CanCreateLanguageModelCallback> callback;
  EXPECT_CALL(callback, Run(blink::mojom::ModelAvailabilityCheckResult::
                                kUnavailableUnsupportedLanguage));

  // French text input is rejected.
  auto french_text = blink::mojom::AILanguageModelExpected::New(
      blink::mojom::AILanguageModelPromptType::kText,
      AITestUtils::ToMojoLanguageCodes({"fr"}));
  auto options = blink::mojom::AILanguageModelCreateOptions::New();
  options->expected_inputs.emplace();
  options->expected_inputs->push_back(std::move(french_text));

  GetAIManagerInterface()->CanCreateLanguageModel(std::move(options),
                                                  callback.Get());
}

TEST_F(AILanguageModelTest, CanCreate_UnsupportedOutputLanguages) {
  base::MockCallback<AIManager::CanCreateLanguageModelCallback> callback;
  EXPECT_CALL(callback, Run(blink::mojom::ModelAvailabilityCheckResult::
                                kUnavailableUnsupportedLanguage));

  // French text output is rejected.
  auto french_text = blink::mojom::AILanguageModelExpected::New(
      blink::mojom::AILanguageModelPromptType::kText,
      AITestUtils::ToMojoLanguageCodes({"fr"}));
  auto options = blink::mojom::AILanguageModelCreateOptions::New();
  options->expected_outputs.emplace();
  options->expected_outputs->push_back(std::move(french_text));

  GetAIManagerInterface()->CanCreateLanguageModel(std::move(options),
                                                  callback.Get());
}

TEST_F(AILanguageModelTest, CanCreate_UnavailableWhenAdaptationNotAvailable) {
  // Answer the eligibility query with "model adaptation not available".
  EXPECT_CALL(*mock_optimization_guide_keyed_service_,
              GetOnDeviceModelEligibilityAsync(_, _, _))
      .WillOnce([](auto feature, auto capabilities, auto callback) {
        const auto reason = optimization_guide::
            OnDeviceModelEligibilityReason::kModelAdaptationNotAvailable;
        std::move(callback).Run(reason);
      });

  base::test::TestFuture<blink::mojom::ModelAvailabilityCheckResult>
      availability;
  GetAIManagerInterface()->CanCreateLanguageModel({},
                                                  availability.GetCallback());

  // The eligibility reason is surfaced as the matching availability result.
  EXPECT_EQ(availability.Get(), blink::mojom::ModelAvailabilityCheckResult::
                                    kUnavailableModelAdaptationNotAvailable);
}

// Fixture for exercising `AILanguageModel::Context` directly, without creating
// a session. Each test starts from a fresh, empty context whose token limit is
// `kTestMaxContextToken`; adding items beyond that limit evicts older items.
class AILanguageModelContextTest : public testing::Test {
 public:
  AILanguageModel::Context context_{kTestMaxContextToken};
};

// A freshly constructed context renders as an empty string.
TEST_F(AILanguageModelContextTest, TestContextOperation_Empty) {
  const auto rendered = GetContextString(context_);
  EXPECT_EQ(rendered, "");
}

// Tests `GetContextString()` when items that fit within the token limit are
// added: every item is kept, rendered in insertion order.
TEST_F(AILanguageModelContextTest, TestContextOperation_NonEmpty) {
  EXPECT_EQ(context_.AddContextItem(SimpleContextItem("test", 1u)),
            AILanguageModel::Context::SpaceReservationResult::kSufficientSpace);
  EXPECT_EQ(GetContextString(context_), "S: test");

  // The second item also fits, so nothing should have to be evicted. Check
  // the reservation result like every other AddContextItem() call does.
  EXPECT_EQ(context_.AddContextItem(SimpleContextItem(" test again", 2u)),
            AILanguageModel::Context::SpaceReservationResult::kSufficientSpace);
  EXPECT_EQ(GetContextString(context_), "S: testS:  test again");
}

// Tests `GetContextString()` when an item needing the whole token budget is
// added after a smaller one: the older item is evicted to make room.
TEST_F(AILanguageModelContextTest, TestContextOperation_Overflow) {
  using SpaceReservationResult =
      AILanguageModel::Context::SpaceReservationResult;

  EXPECT_EQ(context_.AddContextItem(SimpleContextItem("test", 1u)),
            SpaceReservationResult::kSufficientSpace);
  EXPECT_EQ(GetContextString(context_), "S: test");

  // 1 + `kTestMaxContextToken` tokens exceed the limit, so "test" is dropped.
  EXPECT_EQ(context_.AddContextItem(
                SimpleContextItem("long token", kTestMaxContextToken)),
            SpaceReservationResult::kSpaceMadeAvailable);
  EXPECT_EQ(GetContextString(context_), "S: long token");
}

// Tests that eviction removes only as many of the oldest items as needed.
TEST_F(AILanguageModelContextTest, TestContextOperation_PartialOverflow) {
  using SpaceReservationResult =
      AILanguageModel::Context::SpaceReservationResult;

  EXPECT_EQ(context_.AddContextItem(SimpleContextItem("foo", 1u)),
            SpaceReservationResult::kSufficientSpace);
  EXPECT_EQ(GetContextString(context_), "S: foo");

  EXPECT_EQ(context_.AddContextItem(SimpleContextItem("bar", 1u)),
            SpaceReservationResult::kSufficientSpace);
  EXPECT_EQ(GetContextString(context_), "S: fooS: bar");

  // An item one token short of the limit leaves room for exactly one of the
  // 1-token items: the oldest ("foo") goes, the newer ("bar") stays.
  EXPECT_EQ(context_.AddContextItem(
                SimpleContextItem("long token", kTestMaxContextToken - 1)),
            SpaceReservationResult::kSpaceMadeAvailable);
  EXPECT_EQ(GetContextString(context_), "S: barS: long token");
}

// Tests that an item larger than the whole token limit is rejected outright,
// leaving an empty context behind.
TEST_F(AILanguageModelContextTest, TestContextOperation_OverflowOnFirstItem) {
  using SpaceReservationResult =
      AILanguageModel::Context::SpaceReservationResult;

  const auto result = context_.AddContextItem(
      SimpleContextItem("test very long token", kTestMaxContextToken + 1u));
  EXPECT_EQ(result, SpaceReservationResult::kInsufficientSpace);
  EXPECT_EQ(GetContextString(context_), "");
}

TEST_F(AILanguageModelTest, Priority) {
  fake_broker_.settings().set_execute_result({"hi"});
  auto session = CreateSession();
  auto page_view = [this] {
    return main_rfh()->GetRenderWidgetHost()->GetView();
  };

  // Visible page: the fake reports no priority.
  EXPECT_THAT(Prompt(*session, MakeInput("foo")), ElementsAre("hi"));

  // Hidden page: prompts, including those on a fork, run at background
  // priority.
  page_view()->Hide();
  EXPECT_THAT(Prompt(*session, MakeInput("bar")),
              ElementsAre("Priority: background", "hi"));
  auto forked = Fork(*session);
  EXPECT_THAT(Prompt(*forked, MakeInput("bar")),
              ElementsAre("Priority: background", "hi"));

  // Visible again: the background annotation disappears.
  page_view()->Show();
  EXPECT_THAT(Prompt(*session, MakeInput("baz")), ElementsAre("hi"));
}

}  // namespace
