// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/browser/ai/ai_test_utils.h"

#include <cstdint>
#include <utility>

#include "base/check.h"
#include "chrome/browser/ai/ai_manager.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h"
#include "third_party/blink/public/mojom/ai/model_download_progress_observer.mojom.h"

AITestUtils::MockModelStreamingResponder::MockModelStreamingResponder() =
    default;
AITestUtils::MockModelStreamingResponder::~MockModelStreamingResponder() =
    default;

// Binds this mock's receiver to a fresh message pipe and returns the remote
// end, so calls made through the remote reach the mocked methods.
mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
AITestUtils::MockModelStreamingResponder::BindNewPipeAndPassRemote() {
  auto pending_remote = receiver_.BindNewPipeAndPassRemote();
  return pending_remote;
}

AITestUtils::MockModelDownloadProgressMonitor::
    MockModelDownloadProgressMonitor() = default;
AITestUtils::MockModelDownloadProgressMonitor::
    ~MockModelDownloadProgressMonitor() = default;

// Binds this mock's receiver to a fresh message pipe and hands back the
// remote end for the code under test to report progress through.
mojo::PendingRemote<blink::mojom::ModelDownloadProgressObserver>
AITestUtils::MockModelDownloadProgressMonitor::BindNewPipeAndPassRemote() {
  auto pending_remote = receiver_.BindNewPipeAndPassRemote();
  return pending_remote;
}

AITestUtils::MockCreateLanguageModelClient::MockCreateLanguageModelClient() =
    default;
AITestUtils::MockCreateLanguageModelClient::~MockCreateLanguageModelClient() =
    default;

// Binds this mock's receiver to a fresh message pipe and returns the remote
// end, to be passed where an AIManagerCreateLanguageModelClient is expected.
mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
AITestUtils::MockCreateLanguageModelClient::BindNewPipeAndPassRemote() {
  auto pending_remote = receiver_.BindNewPipeAndPassRemote();
  return pending_remote;
}

// Forwards to the wrapped mock monitor, which owns the actual receiver.
mojo::PendingRemote<blink::mojom::ModelDownloadProgressObserver>
AITestUtils::FakeMonitor::BindNewPipeAndPassRemote() {
  auto pending_remote = mock_monitor_.BindNewPipeAndPassRemote();
  return pending_remote;
}

void AITestUtils::FakeMonitor::ExpectReceivedUpdate(
    uint64_t expected_downloaded_bytes,
    uint64_t expected_total_bytes,
    base::OnceClosure callback) {
  EXPECT_CALL(mock_monitor_, OnDownloadProgressUpdate(testing::_, testing::_))
      .WillOnce([callback = std::move(callback), expected_downloaded_bytes,
                 expected_total_bytes](uint64_t downloaded_bytes,
                                       uint64_t total_bytes) mutable {
        EXPECT_EQ(downloaded_bytes, expected_downloaded_bytes);
        EXPECT_EQ(total_bytes, expected_total_bytes);
        std::move(callback).Run();
      });
}

void AITestUtils::FakeMonitor::ExpectReceivedUpdate(
    uint64_t expected_downloaded_bytes,
    uint64_t expected_total_bytes) {
  base::RunLoop download_progress_run_loop;
  ExpectReceivedUpdate(expected_downloaded_bytes, expected_total_bytes,
                       download_progress_run_loop.QuitClosure());
  download_progress_run_loop.Run();
}

// Like ExpectReceivedUpdate(), but the raw byte counts are first passed
// through AIUtils::NormalizeModelDownloadProgress(). The reported total is
// expected to be AIUtils::kNormalizedDownloadProgressMax. `callback` runs
// once the update arrives.
void AITestUtils::FakeMonitor::ExpectReceivedNormalizedUpdate(
    uint64_t expected_downloaded_bytes,
    uint64_t expected_total_bytes,
    base::OnceClosure callback) {
  const auto normalized_downloaded = AIUtils::NormalizeModelDownloadProgress(
      expected_downloaded_bytes, expected_total_bytes);
  ExpectReceivedUpdate(normalized_downloaded,
                       AIUtils::kNormalizedDownloadProgressMax,
                       std::move(callback));
}

// Blocking variant of the normalized expectation. It delegates to the
// two-argument ExpectReceivedUpdate(), which waits on a RunLoop until the
// update arrives.
void AITestUtils::FakeMonitor::ExpectReceivedNormalizedUpdate(
    uint64_t expected_downloaded_bytes,
    uint64_t expected_total_bytes) {
  const auto normalized_downloaded = AIUtils::NormalizeModelDownloadProgress(
      expected_downloaded_bytes, expected_total_bytes);
  ExpectReceivedUpdate(normalized_downloaded,
                       AIUtils::kNormalizedDownloadProgressMax);
}

void AITestUtils::FakeMonitor::ExpectNoUpdate() {
  EXPECT_CALL(mock_monitor_, OnDownloadProgressUpdate(testing::_, testing::_))
      .Times(0);
}

AITestUtils::FakeComponent::FakeComponent(std::string id, uint64_t total_bytes)
    : id_(std::move(id)), total_bytes_(total_bytes) {}

// Builds a CrxUpdateItem for this fake component. The item carries the given
// `state` and `downloaded_bytes`, plus the id and total size captured at
// construction.
component_updater::CrxUpdateItem AITestUtils::FakeComponent::CreateUpdateItem(
    update_client::ComponentState state,
    uint64_t downloaded_bytes) const {
  component_updater::CrxUpdateItem item;
  item.id = id_;
  item.total_bytes = total_bytes_;
  item.state = state;
  item.downloaded_bytes = downloaded_bytes;
  return item;
}

AITestUtils::MockComponentUpdateService::MockComponentUpdateService() = default;
AITestUtils::MockComponentUpdateService::~MockComponentUpdateService() =
    default;

// Registers `observer` so that it is notified by subsequent SendUpdate() calls.
void AITestUtils::MockComponentUpdateService::AddObserver(Observer* observer) {
  observer_list_.AddObserver(observer);
}

// Unregisters `observer`; it no longer receives SendUpdate() notifications.
void AITestUtils::MockComponentUpdateService::RemoveObserver(
    Observer* observer) {
  observer_list_.RemoveObserver(observer);
}

// Test hook: delivers `item` to every currently registered observer via
// OnEvent().
void AITestUtils::MockComponentUpdateService::SendUpdate(
    const component_updater::CrxUpdateItem& item) {
  for (auto& registered : observer_list_) {
    registered.OnEvent(item);
  }
}

// Construction and destruction are trivial; real setup happens in SetUp().
AITestUtils::AITestBase::AITestBase() = default;
AITestUtils::AITestBase::~AITestBase() = default;

void AITestUtils::AITestBase::SetUp() {
  ChromeRenderViewHostTestHarness::SetUp();
  // Create the AIManager under test. It receives the main frame's browser
  // context, the fixture's mock component update service, and the main frame
  // itself.
  auto* const frame = main_rfh();
  ai_manager_ = std::make_unique<AIManager>(frame->GetBrowserContext(),
                                            &component_update_service_, frame);
}

void AITestUtils::AITestBase::TearDown() {
  // Forget the pointer to the mock keyed service; the service object itself
  // was produced by the testing factory, not owned here.
  mock_optimization_guide_keyed_service_ = nullptr;
  // Destroy the AIManager before the harness's own TearDown() runs. This is
  // presumably because it was built with the frame and browser context from
  // SetUp() (verify if reordering).
  ai_manager_ = nullptr;
  ChromeRenderViewHostTestHarness::TearDown();
}

void AITestUtils::AITestBase::SetupMockOptimizationGuideKeyedService() {
  mock_optimization_guide_keyed_service_ =
      static_cast<MockOptimizationGuideKeyedService*>(
          OptimizationGuideKeyedServiceFactory::GetInstance()
              ->SetTestingFactoryAndUse(
                  profile(),
                  base::BindRepeating([](content::BrowserContext* context)
                                          -> std::unique_ptr<KeyedService> {
                    return std::make_unique<
                        testing::NiceMock<MockOptimizationGuideKeyedService>>();
                  })));
  ON_CALL(*mock_optimization_guide_keyed_service_,
          GetOnDeviceModelEligibilityAsync(testing::_, testing::_, testing::_))
      .WillByDefault([](auto feature, auto capabilities, auto callback) {
        std::move(callback).Run(
            optimization_guide::OnDeviceModelEligibilityReason::kSuccess);
      });
}

// Installs a testing factory that produces no service at all. The profile then
// has a null OptimizationGuideKeyedService.
void AITestUtils::AITestBase::SetupNullOptimizationGuideKeyedService() {
  auto null_factory = base::BindRepeating(
      [](content::BrowserContext* context) -> std::unique_ptr<KeyedService> {
        return nullptr;
      });
  OptimizationGuideKeyedServiceFactory::GetInstance()->SetTestingFactoryAndUse(
      profile(), std::move(null_factory));
}

void AITestUtils::AITestBase::SetupMockSession() {
  ON_CALL(*mock_optimization_guide_keyed_service_,
          StartSession(testing::_, testing::_, testing::_))
      .WillByDefault([&] {
        return std::make_unique<
            testing::NiceMock<optimization_guide::MockSession>>(&session_);
      });
  ON_CALL(session_, GetExecutionInputSizeInTokens(testing::_, testing::_))
      .WillByDefault(
          [&](optimization_guide::MultimodalMessageReadView request_metadata,
              optimization_guide::OptimizationGuideModelSizeInTokenCallback
                  callback) {
            std::move(callback).Run(
                blink::mojom::kWritingAssistanceMaxInputTokenSize);
          });
}

// Returns the AIManager under test viewed through its mojom interface, for
// direct (non-IPC) calls. The fixture retains ownership.
blink::mojom::AIManager* AITestUtils::AITestBase::GetAIManagerInterface() {
  return static_cast<blink::mojom::AIManager*>(ai_manager_.get());
}

// Creates a new mojo connection to the AIManager under test and returns the
// bound remote; the receiving end is registered via AddReceiver().
mojo::Remote<blink::mojom::AIManager>
AITestUtils::AITestBase::GetAIManagerRemote() {
  mojo::Remote<blink::mojom::AIManager> remote;
  auto pending_receiver = remote.BindNewPipeAndPassReceiver();
  ai_manager_->AddReceiver(std::move(pending_receiver));
  return remote;
}

// Test-only passthrough: number of download progress observers currently
// held by the AIManager under test.
size_t AITestUtils::AITestBase::GetAIManagerDownloadProgressObserversSize() {
  const size_t observer_count =
      ai_manager_->GetDownloadProgressObserversSizeForTesting();
  return observer_count;
}

// Test-only passthrough: size of the AIManager's context-bound object set.
size_t AITestUtils::AITestBase::GetAIManagerContextBoundObjectSetSize() {
  const size_t object_count =
      ai_manager_->GetContextBoundObjectSetSizeForTesting();
  return object_count;
}

// static
// Returns a fixed set of token limits for tests. The object is created on
// first use and the same instance is returned on every call.
const optimization_guide::TokenLimits& AITestUtils::GetFakeTokenLimits() {
  static const optimization_guide::TokenLimits kFakeLimits{
      .max_tokens = 4096,
      .max_context_tokens = 2048,
      .max_execute_tokens = 1024,
      .max_output_tokens = 1024,
  };
  return kFakeLimits;
}

// static
// Returns a default-constructed proto::Any to use as feature metadata in
// tests. NoDestructor keeps the instance alive (and never destroyed) for the
// life of the process.
const optimization_guide::proto::Any& AITestUtils::GetFakeFeatureMetadata() {
  static base::NoDestructor<optimization_guide::proto::Any> empty_metadata;
  return *empty_metadata;
}

// static
std::vector<blink::mojom::AILanguageCodePtr> AITestUtils::ToMojoLanguageCodes(
    const std::vector<std::string>& language_codes) {
  std::vector<blink::mojom::AILanguageCodePtr> result;
  result.reserve(language_codes.size());
  std::ranges::transform(
      language_codes, std::back_inserter(result),
      [](const std::string& language_code) {
        return blink::mojom::AILanguageCode::New(language_code);
      });
  return result;
}
