| // Copyright 2025 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "chrome/browser/ai/ai_crx_component.h" |
| |
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <utility>

#include "base/barrier_closure.h"
#include "base/run_loop.h"
#include "base/task/current_thread.h"
#include "base/test/gtest_util.h"
#include "base/test/task_environment.h"
#include "base/time/time.h"
#include "chrome/browser/ai/ai_model_download_progress_manager.h"
#include "chrome/browser/ai/ai_test_utils.h"
#include "chrome/browser/ai/ai_utils.h"
#include "testing/gmock/include/gmock/gmock-matchers.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
| |
| namespace on_device_ai { |
| |
| using component_updater::CrxUpdateItem; |
| using testing::_; |
| using update_client::ComponentState; |
| |
| class AICrxComponentTest : public testing::Test { |
| public: |
| AICrxComponentTest() = default; |
| ~AICrxComponentTest() override = default; |
| |
| protected: |
| // Send a download update. |
| void SendUpdate(const AITestUtils::FakeComponent& component, |
| ComponentState state, |
| uint64_t downloaded_bytes) { |
| component_update_service_.SendUpdate( |
| component.CreateUpdateItem(state, downloaded_bytes)); |
| } |
| |
| void FastForwardBy(base::TimeDelta delta) { |
| task_environment_.FastForwardBy(delta); |
| } |
| |
| AITestUtils::FakeComponent& CreateComponent(std::string id, |
| uint64_t total_bytes) { |
| auto [iter, emplaced] = fake_components_.try_emplace(id, id, total_bytes); |
| CHECK(emplaced); |
| return iter->second; |
| } |
| |
| AITestUtils::MockComponentUpdateService component_update_service_; |
| |
| private: |
| void SetUp() override { |
| EXPECT_CALL(component_update_service_, GetComponentDetails(_, _)) |
| .WillRepeatedly([&](const std::string& id, CrxUpdateItem* item) { |
| auto iter = fake_components_.find(id); |
| if (iter == fake_components_.end()) { |
| return false; |
| } |
| |
| *item = iter->second.CreateUpdateItem( |
| update_client::ComponentState::kNew, 0); |
| |
| return true; |
| }); |
| } |
| |
| std::map<std::string, AITestUtils::FakeComponent> fake_components_; |
| |
| base::test::SingleThreadTaskEnvironment task_environment_{ |
| base::test::TaskEnvironment::TimeSource::MOCK_TIME}; |
| }; |
| |
| TEST_F(AICrxComponentTest, DoesntReceiveUpdatesForNonDownloadEvents) { |
| AIModelDownloadProgressManager manager; |
| AITestUtils::FakeMonitor monitor; |
| AITestUtils::FakeComponent& component = CreateComponent("component_id", 100); |
| |
| manager.AddObserver(monitor.BindNewPipeAndPassRemote(), |
| AICrxComponent::FromComponentIds( |
| &component_update_service_, {component.id()})); |
| |
| // Doesn't receive any update for these event states. |
| for (const auto state : { |
| ComponentState::kNew, |
| ComponentState::kChecking, |
| ComponentState::kCanUpdate, |
| ComponentState::kUpdated, |
| ComponentState::kUpdateError, |
| ComponentState::kRun, |
| }) { |
| SendUpdate(component, state, 10); |
| monitor.ExpectNoUpdate(); |
| FastForwardBy(base::Milliseconds(51)); |
| } |
| } |
| |
| TEST_F(AICrxComponentTest, |
| DoesntReceiveUpdatesForEventsWithNegativeDownloadedBytes) { |
| AIModelDownloadProgressManager manager; |
| AITestUtils::FakeMonitor monitor; |
| AITestUtils::FakeComponent& component = CreateComponent("component_id", 100); |
| |
| manager.AddObserver(monitor.BindNewPipeAndPassRemote(), |
| AICrxComponent::FromComponentIds( |
| &component_update_service_, {component.id()})); |
| |
| // Doesn't receive an update when the downloaded bytes are negative. |
| SendUpdate(component, ComponentState::kDownloading, -1); |
| monitor.ExpectNoUpdate(); |
| FastForwardBy(base::Milliseconds(51)); |
| } |
| |
| TEST_F(AICrxComponentTest, |
| DoesntReceiveUpdatesForEventsWithNegativeTotalBytes) { |
| AIModelDownloadProgressManager manager; |
| AITestUtils::FakeMonitor monitor; |
| AITestUtils::FakeComponent& component = CreateComponent("component_id", -1); |
| |
| manager.AddObserver(monitor.BindNewPipeAndPassRemote(), |
| AICrxComponent::FromComponentIds( |
| &component_update_service_, {component.id()})); |
| |
| // Doesn't receive an update when the total bytes are negative. |
| SendUpdate(component, ComponentState::kDownloading, 0); |
| monitor.ExpectNoUpdate(); |
| FastForwardBy(base::Milliseconds(51)); |
| } |
| |
| TEST_F(AICrxComponentTest, DoesntReceiveUpdatesForComponentsNotObserving) { |
| AIModelDownloadProgressManager manager; |
| AITestUtils::FakeMonitor monitor; |
| AITestUtils::FakeComponent& component_observed = |
| CreateComponent("component_id1", 100); |
| AITestUtils::FakeComponent& component_not_observed = |
| CreateComponent("component_id2", 100); |
| |
| manager.AddObserver( |
| monitor.BindNewPipeAndPassRemote(), |
| AICrxComponent::FromComponentIds(&component_update_service_, |
| {component_observed.id()})); |
| |
| // Doesn't receive any update for these event states. |
| SendUpdate(component_not_observed, ComponentState::kDownloading, 10); |
| monitor.ExpectNoUpdate(); |
| FastForwardBy(base::Milliseconds(51)); |
| } |
| |
| TEST_F(AICrxComponentTest, ObservesComponentsMidDownload) { |
| AIModelDownloadProgressManager manager; |
| AITestUtils::FakeMonitor monitor1; |
| AITestUtils::FakeMonitor monitor2; |
| AITestUtils::FakeComponent& component = CreateComponent("component_id", 100); |
| |
| // First, `monitor1` observes `component`. |
| { |
| manager.AddObserver(monitor1.BindNewPipeAndPassRemote(), |
| AICrxComponent::FromComponentIds( |
| &component_update_service_, {component.id()})); |
| } |
| |
| // Only `monitor1` will receive this update since `monitor2` is not observing. |
| SendUpdate(component, ComponentState::kDownloading, 0); |
| monitor1.ExpectReceivedNormalizedUpdate(0, component.total_bytes()); |
| monitor2.ExpectNoUpdate(); |
| |
| // Now both `monitor1` and `monitor2` are observing `component`. |
| { |
| manager.AddObserver(monitor2.BindNewPipeAndPassRemote(), |
| AICrxComponent::FromComponentIds( |
| &component_update_service_, {component.id()})); |
| } |
| |
| // Send the first update to for `monitor2` waiting more than 50ms so that both |
| // monitors receive it. |
| constexpr int64_t update1_for_monitor2 = 60; |
| FastForwardBy(base::Milliseconds(51)); |
| SendUpdate(component, ComponentState::kDownloading, update1_for_monitor2); |
| { |
| base::RunLoop run_loop; |
| base::RepeatingClosure update_callback = |
| base::BarrierClosure(2, run_loop.QuitClosure()); |
| |
| // `monitor1` should still be normalized against the total bytes of the |
| // component. |
| monitor1.ExpectReceivedNormalizedUpdate( |
| update1_for_monitor2, component.total_bytes(), update_callback); |
| |
| // This is `monitor2`'s first update so it should receive zero and be |
| // normalized against the remaining bytes. |
| monitor2.ExpectReceivedNormalizedUpdate( |
| 0, component.total_bytes() - update1_for_monitor2, update_callback); |
| |
| run_loop.Run(); |
| } |
| |
| // Send a second update to for `monitor2` waiting more than 50ms so that both |
| // monitors receive it. |
| constexpr int64_t update2_for_monitor2 = 75; |
| FastForwardBy(base::Milliseconds(51)); |
| SendUpdate(component, ComponentState::kDownloading, update2_for_monitor2); |
| { |
| base::RunLoop run_loop; |
| base::RepeatingClosure update_callback = |
| base::BarrierClosure(2, run_loop.QuitClosure()); |
| |
| // `monitor1` should still be normalized against the total bytes of the |
| // component. |
| monitor1.ExpectReceivedNormalizedUpdate( |
| update2_for_monitor2, component.total_bytes(), update_callback); |
| |
| // `monitor2` should still be normalized against the remaining bytes it |
| // observed on its first update. The downloaded bytes should also not |
| // include any bytes that were downloaded before `monitor2` started |
| // observing. |
| monitor2.ExpectReceivedNormalizedUpdate( |
| update2_for_monitor2 - update1_for_monitor2, |
| component.total_bytes() - update1_for_monitor2, update_callback); |
| |
| run_loop.Run(); |
| } |
| } |
| |
| } // namespace on_device_ai |