| // Copyright 2024 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #ifndef CHROME_BROWSER_AI_AI_TEST_UTILS_H_ |
| #define CHROME_BROWSER_AI_AI_TEST_UTILS_H_ |
| |
| #include <cstdint> |
| #include <vector> |
| |
| #include "base/functional/callback_forward.h" |
| #include "base/supports_user_data.h" |
| #include "chrome/browser/ai/ai_manager.h" |
| #include "chrome/browser/optimization_guide/mock_optimization_guide_keyed_service.h" |
| #include "chrome/test/base/chrome_render_view_host_test_harness.h" |
| #include "components/component_updater/mock_component_updater_service.h" |
| #include "components/optimization_guide/core/mock_optimization_guide_model_executor.h" |
| #include "components/update_client/crx_update_item.h" |
| #include "mojo/public/cpp/bindings/receiver.h" |
| #include "mojo/public/cpp/bindings/remote.h" |
| #include "testing/gmock/include/gmock/gmock.h" |
| #include "testing/gtest/include/gtest/gtest.h" |
| #include "third_party/blink/public/mojom/ai/ai_common.mojom.h" |
| #include "third_party/blink/public/mojom/ai/ai_language_model.mojom.h" |
| #include "third_party/blink/public/mojom/ai/ai_manager.mojom.h" |
| #include "third_party/blink/public/mojom/ai/model_download_progress_observer.mojom.h" |
| #include "third_party/blink/public/mojom/ai/model_streaming_responder.mojom.h" |
| |
| class AITestUtils { |
| public: |
| class MockModelStreamingResponder |
| : public blink::mojom::ModelStreamingResponder { |
| public: |
| MockModelStreamingResponder(); |
| ~MockModelStreamingResponder() override; |
| MockModelStreamingResponder(const MockModelStreamingResponder&) = delete; |
| MockModelStreamingResponder& operator=(const MockModelStreamingResponder&) = |
| delete; |
| |
| mojo::PendingRemote<blink::mojom::ModelStreamingResponder> |
| BindNewPipeAndPassRemote(); |
| |
| MOCK_METHOD(void, OnStreaming, (const std::string& text), (override)); |
| MOCK_METHOD(void, |
| OnError, |
| (blink::mojom::ModelStreamingResponseStatus status, |
| blink::mojom::QuotaErrorInfoPtr quota_error_info), |
| (override)); |
| MOCK_METHOD(void, |
| OnCompletion, |
| (blink::mojom::ModelExecutionContextInfoPtr context_info), |
| (override)); |
| MOCK_METHOD(void, OnQuotaOverflow, (), (override)); |
| |
| private: |
| mojo::Receiver<blink::mojom::ModelStreamingResponder> receiver_{this}; |
| }; |
| |
| class MockModelDownloadProgressMonitor |
| : public blink::mojom::ModelDownloadProgressObserver { |
| public: |
| MockModelDownloadProgressMonitor(); |
| ~MockModelDownloadProgressMonitor() override; |
| MockModelDownloadProgressMonitor(const MockModelDownloadProgressMonitor&) = |
| delete; |
| MockModelDownloadProgressMonitor& operator=( |
| const MockModelDownloadProgressMonitor&) = delete; |
| |
| mojo::PendingRemote<blink::mojom::ModelDownloadProgressObserver> |
| BindNewPipeAndPassRemote(); |
| |
| // `blink::mojom::ModelDownloadProgressObserver` implementation. |
| MOCK_METHOD(void, |
| OnDownloadProgressUpdate, |
| (uint64_t downloaded_bytes, uint64_t total_bytes), |
| (override)); |
| |
| private: |
| mojo::Receiver<blink::mojom::ModelDownloadProgressObserver> receiver_{this}; |
| }; |
| |
| class MockCreateLanguageModelClient |
| : public blink::mojom::AIManagerCreateLanguageModelClient { |
| public: |
| MockCreateLanguageModelClient(); |
| ~MockCreateLanguageModelClient() override; |
| MockCreateLanguageModelClient(const MockCreateLanguageModelClient&) = |
| delete; |
| MockCreateLanguageModelClient& operator=( |
| const MockCreateLanguageModelClient&) = delete; |
| |
| mojo::PendingRemote<blink::mojom::AIManagerCreateLanguageModelClient> |
| BindNewPipeAndPassRemote(); |
| |
| MOCK_METHOD( |
| void, |
| OnResult, |
| (mojo::PendingRemote<blink::mojom::AILanguageModel> language_model, |
| blink::mojom::AILanguageModelInstanceInfoPtr info), |
| (override)); |
| |
| MOCK_METHOD(void, |
| OnError, |
| (blink::mojom::AIManagerCreateClientError error, |
| blink::mojom::QuotaErrorInfoPtr quota_error_info), |
| (override)); |
| |
| private: |
| mojo::Receiver<blink::mojom::AIManagerCreateLanguageModelClient> receiver_{ |
| this}; |
| }; |
| |
| class FakeMonitor { |
| public: |
| mojo::PendingRemote<blink::mojom::ModelDownloadProgressObserver> |
| BindNewPipeAndPassRemote(); |
| |
| // Expects that the next `OnDownloadProgressUpdate` is called with |
| // `expected_downloaded_bytes` and `expected_total_bytes`. Once it receives |
| // an update, calls `callback`. |
| void ExpectReceivedUpdate(uint64_t expected_downloaded_bytes, |
| uint64_t expected_total_bytes, |
| base::OnceClosure callback); |
| |
| // Overload that waits until the update is received. |
| void ExpectReceivedUpdate(uint64_t expected_downloaded_bytes, |
| uint64_t expected_total_bytes); |
| |
| // Same as `ExpectReceivedUpdate` except it normalizes |
| // `expected_downloaded_bytes` and `expected_total_bytes`. |
| void ExpectReceivedNormalizedUpdate(uint64_t expected_downloaded_bytes, |
| uint64_t expected_total_bytes, |
| base::OnceClosure callback); |
| |
| // Overload that waits until the update is received. |
| void ExpectReceivedNormalizedUpdate(uint64_t expected_downloaded_bytes, |
| uint64_t expected_total_bytes); |
| |
| void ExpectNoUpdate(); |
| |
| private: |
| AITestUtils::MockModelDownloadProgressMonitor mock_monitor_; |
| }; |
| |
| class FakeComponent { |
| public: |
| FakeComponent(std::string id, uint64_t total_bytes); |
| |
| component_updater::CrxUpdateItem CreateUpdateItem( |
| update_client::ComponentState state, |
| uint64_t downloaded_bytes) const; |
| |
| const std::string& id() { return id_; } |
| uint64_t total_bytes() { return total_bytes_; } |
| |
| private: |
| std::string id_; |
| uint64_t total_bytes_; |
| }; |
| |
| class MockComponentUpdateService |
| : public component_updater::MockComponentUpdateService { |
| public: |
| MockComponentUpdateService(); |
| ~MockComponentUpdateService() override; |
| |
| void AddObserver(Observer* observer) override; |
| |
| void RemoveObserver(Observer* observer) override; |
| |
| void SendUpdate(const component_updater::CrxUpdateItem& item); |
| |
| // Not copyable or movable. |
| MockComponentUpdateService(const MockComponentUpdateService&) = delete; |
| MockComponentUpdateService& operator=(const MockComponentUpdateService&) = |
| delete; |
| |
| private: |
| base::ObserverList<Observer>::Unchecked observer_list_; |
| }; |
| |
| class AITestBase : public ChromeRenderViewHostTestHarness { |
| public: |
| AITestBase(); |
| ~AITestBase() override; |
| |
| void SetUp() override; |
| void TearDown() override; |
| |
| protected: |
| virtual void SetupMockOptimizationGuideKeyedService(); |
| virtual void SetupNullOptimizationGuideKeyedService(); |
| |
| // Optimization guide keyed service should be set up before calling this |
| // method. |
| void SetupMockSession(); |
| |
| blink::mojom::AIManager* GetAIManagerInterface(); |
| mojo::Remote<blink::mojom::AIManager> GetAIManagerRemote(); |
| size_t GetAIManagerContextBoundObjectSetSize(); |
| size_t GetAIManagerDownloadProgressObserversSize(); |
| |
| raw_ptr<MockOptimizationGuideKeyedService> |
| mock_optimization_guide_keyed_service_; |
| testing::NiceMock<optimization_guide::MockSession> session_; |
| AITestUtils::MockComponentUpdateService component_update_service_; |
| |
| std::unique_ptr<AIManager> ai_manager_; |
| }; |
| |
| static const optimization_guide::TokenLimits& GetFakeTokenLimits(); |
| static const optimization_guide::proto::Any& GetFakeFeatureMetadata(); |
| |
| // Converts string language codes to AILanguageCode mojo struct. |
| static std::vector<blink::mojom::AILanguageCodePtr> ToMojoLanguageCodes( |
| const std::vector<std::string>& language_codes); |
| }; |
| |
| #endif // CHROME_BROWSER_AI_AI_TEST_UTILS_H_ |