blob: e07f164989041a7c6d0f1124dbbb8fe3ac553f9d [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef CHROME_BROWSER_AI_AI_TEST_UTILS_H_
#define CHROME_BROWSER_AI_AI_TEST_UTILS_H_
#include <cstdint>
#include <vector>
#include "base/functional/callback_forward.h"
#include "base/supports_user_data.h"
#include "chrome/browser/ai/ai_manager.h"
#include "chrome/browser/optimization_guide/mock_optimization_guide_keyed_service.h"
#include "chrome/test/base/chrome_render_view_host_test_harness.h"
#include "components/component_updater/mock_component_updater_service.h"
#include "components/optimization_guide/core/mock_optimization_guide_model_executor.h"
#include "components/update_client/crx_update_item.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/blink/public/mojom/ai/ai_common.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom.h"
#include "third_party/blink/public/mojom/ai/model_download_progress_observer.mojom.h"
#include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom.h"
class AITestUtils {
public:
// gMock-backed `blink::mojom::ModelStreamingResponder`. The mock holds the
// `mojo::Receiver` it is bound through, so a test only needs to keep the mock
// alive and hand out the remote end.
class MockModelStreamingResponder
    : public blink::mojom::ModelStreamingResponder {
 public:
  MockModelStreamingResponder();
  // Not copyable or movable.
  MockModelStreamingResponder(const MockModelStreamingResponder&) = delete;
  MockModelStreamingResponder& operator=(const MockModelStreamingResponder&) =
      delete;
  ~MockModelStreamingResponder() override;

  // Returns the remote end for this responder.
  // NOTE(review): presumably binds `receiver_` to a new pipe; the definition
  // lives outside this header.
  mojo::PendingRemote<blink::mojom::ModelStreamingResponder>
  BindNewPipeAndPassRemote();

  // `blink::mojom::ModelStreamingResponder` implementation.
  MOCK_METHOD(void, OnStreaming, (const std::string& text), (override));
  MOCK_METHOD(void,
              OnError,
              (blink::mojom::ModelStreamingResponseStatus status,
               blink::mojom::QuotaErrorInfoPtr quota_error_info),
              (override));
  MOCK_METHOD(void,
              OnCompletion,
              (blink::mojom::ModelExecutionContextInfoPtr context_info),
              (override));
  MOCK_METHOD(void, OnQuotaOverflow, (), (override));

 private:
  // Receiver whose implementation pointer is this mock.
  mojo::Receiver<blink::mojom::ModelStreamingResponder> receiver_{this};
};
// gMock-backed `blink::mojom::ModelDownloadProgressObserver` that carries its
// own `mojo::Receiver`.
class MockModelDownloadProgressMonitor
    : public blink::mojom::ModelDownloadProgressObserver {
 public:
  MockModelDownloadProgressMonitor();
  // Not copyable or movable.
  MockModelDownloadProgressMonitor(const MockModelDownloadProgressMonitor&) =
      delete;
  MockModelDownloadProgressMonitor& operator=(
      const MockModelDownloadProgressMonitor&) = delete;
  ~MockModelDownloadProgressMonitor() override;

  // Returns the remote end for this observer.
  // NOTE(review): presumably binds `receiver_` to a new pipe; the definition
  // lives outside this header.
  mojo::PendingRemote<blink::mojom::ModelDownloadProgressObserver>
  BindNewPipeAndPassRemote();

  // blink::mojom::ModelDownloadProgressObserver:
  MOCK_METHOD(void,
              OnDownloadProgressUpdate,
              (uint64_t downloaded_bytes, uint64_t total_bytes),
              (override));

 private:
  // Receiver whose implementation pointer is this mock.
  mojo::Receiver<blink::mojom::ModelDownloadProgressObserver> receiver_{this};
};
// gMock-backed `blink::mojom::AIManagerCreateLanguageModelClient` that carries
// its own `mojo::Receiver`.
class MockCreateLanguageModelClient
    : public blink::mojom::AIManagerCreateLanguageModelClient {
 public:
  MockCreateLanguageModelClient();
  // Not copyable or movable.
  MockCreateLanguageModelClient(const MockCreateLanguageModelClient&) = delete;
  MockCreateLanguageModelClient& operator=(
      const MockCreateLanguageModelClient&) = delete;
  ~MockCreateLanguageModelClient() override;

  // Returns the remote end for this client.
  // NOTE(review): presumably binds `receiver_` to a new pipe; the definition
  // lives outside this header.
  mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient>
  BindNewPipeAndPassRemote();

  // blink::mojom::AIManagerCreateLanguageModelClient:
  MOCK_METHOD(
      void,
      OnResult,
      (mojo::PendingRemote<blink::mojom::AILanguageModel> language_model,
       blink::mojom::AILanguageModelInstanceInfoPtr info),
      (override));
  MOCK_METHOD(void,
              OnError,
              (blink::mojom::AIManagerCreateClientError error,
               blink::mojom::QuotaErrorInfoPtr quota_error_info),
              (override));

 private:
  // Receiver whose implementation pointer is this mock.
  mojo::Receiver<blink::mojom::AIManagerCreateLanguageModelClient> receiver_{
      this};
};
// Test helper wrapping a `MockModelDownloadProgressMonitor` and exposing
// expectation helpers for download progress updates.
class FakeMonitor {
 public:
  mojo::PendingRemote<blink::mojom::ModelDownloadProgressObserver>
  BindNewPipeAndPassRemote();

  // Sets the expectation that the next `OnDownloadProgressUpdate` carries
  // exactly `expected_downloaded_bytes` and `expected_total_bytes`; `callback`
  // is run once an update has been received.
  void ExpectReceivedUpdate(uint64_t expected_downloaded_bytes,
                            uint64_t expected_total_bytes,
                            base::OnceClosure callback);
  // Blocking variant: does not return until the update has been received.
  void ExpectReceivedUpdate(uint64_t expected_downloaded_bytes,
                            uint64_t expected_total_bytes);

  // Like `ExpectReceivedUpdate`, but `expected_downloaded_bytes` and
  // `expected_total_bytes` are normalized first.
  void ExpectReceivedNormalizedUpdate(uint64_t expected_downloaded_bytes,
                                      uint64_t expected_total_bytes,
                                      base::OnceClosure callback);
  // Blocking variant: does not return until the update has been received.
  void ExpectReceivedNormalizedUpdate(uint64_t expected_downloaded_bytes,
                                      uint64_t expected_total_bytes);

  // NOTE(review): presumably expects that no progress update is delivered;
  // confirm against the definition, which is not in this header.
  void ExpectNoUpdate();

 private:
  // The mock observer that the expectations above are placed on.
  AITestUtils::MockModelDownloadProgressMonitor mock_monitor_;
};
// Value-like description of a fake component, identified by `id` and with a
// fixed `total_bytes` size, used to fabricate `CrxUpdateItem`s in tests.
class FakeComponent {
 public:
  FakeComponent(std::string id, uint64_t total_bytes);

  // Builds a `CrxUpdateItem` for the given `state` and `downloaded_bytes`.
  // NOTE(review): presumably the item also carries this component's id and
  // total size; the definition is not in this header.
  component_updater::CrxUpdateItem CreateUpdateItem(
      update_client::ComponentState state,
      uint64_t downloaded_bytes) const;

  // Read-only accessors. They are const-qualified so they can be called on a
  // const `FakeComponent` (or through a const reference), consistent with
  // `CreateUpdateItem()` above.
  const std::string& id() const { return id_; }
  uint64_t total_bytes() const { return total_bytes_; }

 private:
  std::string id_;
  uint64_t total_bytes_;
};
// Specialization of `component_updater::MockComponentUpdateService` that
// overrides observer registration and keeps its own observer list, so tests
// can push update items via `SendUpdate()`.
class MockComponentUpdateService
    : public component_updater::MockComponentUpdateService {
 public:
  MockComponentUpdateService();
  // Not copyable or movable.
  MockComponentUpdateService(const MockComponentUpdateService&) = delete;
  MockComponentUpdateService& operator=(const MockComponentUpdateService&) =
      delete;
  ~MockComponentUpdateService() override;

  // Observer registration overrides.
  void AddObserver(Observer* observer) override;
  void RemoveObserver(Observer* observer) override;

  // NOTE(review): presumably delivers `item` to the observers held in
  // `observer_list_`; the definition is not in this header.
  void SendUpdate(const component_updater::CrxUpdateItem& item);

 private:
  base::ObserverList<Observer>::Unchecked observer_list_;
};
// Common fixture for AI tests. Extends `ChromeRenderViewHostTestHarness` and
// holds the mocks and the `AIManager` instance that derived tests work with.
class AITestBase : public ChromeRenderViewHostTestHarness {
 public:
  AITestBase();
  ~AITestBase() override;

  // Test harness overrides.
  void SetUp() override;
  void TearDown() override;

 protected:
  // Virtual hooks for installing the optimization guide keyed service;
  // derived fixtures may override them.
  virtual void SetupMockOptimizationGuideKeyedService();
  virtual void SetupNullOptimizationGuideKeyedService();

  // Precondition: the optimization guide keyed service has already been set
  // up before this is called.
  void SetupMockSession();

  // Accessors for the AI manager under test.
  blink::mojom::AIManager* GetAIManagerInterface();
  mojo::Remote<blink::mojom::AIManager> GetAIManagerRemote();
  size_t GetAIManagerContextBoundObjectSetSize();
  size_t GetAIManagerDownloadProgressObserversSize();

  // Fixture state. Declaration order is kept as-is because it fixes the
  // construction and destruction order of these members.
  raw_ptr<MockOptimizationGuideKeyedService>
      mock_optimization_guide_keyed_service_;
  testing::NiceMock<optimization_guide::MockSession> session_;
  AITestUtils::MockComponentUpdateService component_update_service_;
  std::unique_ptr<AIManager> ai_manager_;
};
// Static fixtures shared by AI tests. Both getters hand back const references
// to fake data.
// NOTE(review): the referenced objects presumably have static storage
// duration; confirm against the definitions, which are not in this header.
static const optimization_guide::TokenLimits& GetFakeTokenLimits();
static const optimization_guide::proto::Any& GetFakeFeatureMetadata();

// Builds the `AILanguageCode` mojo structs corresponding to the string
// language codes in `language_codes`.
static std::vector<blink::mojom::AILanguageCodePtr> ToMojoLanguageCodes(
    const std::vector<std::string>& language_codes);
};
#endif // CHROME_BROWSER_AI_AI_TEST_UTILS_H_