blob: 733a3da837221e90378de0232c831a13e19f6261 [file] [log] [blame]
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/passage_embeddings/passage_embedder_model_observer.h"
#include <memory>
#include "base/memory/raw_ptr.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "components/optimization_guide/core/delivery/test_optimization_guide_model_provider.h"
#include "components/passage_embeddings/passage_embeddings_service_controller.h"
#include "components/passage_embeddings/passage_embeddings_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace passage_embeddings {
class FakePassageEmbeddingsServiceController
: public passage_embeddings::PassageEmbeddingsServiceController {
public:
explicit FakePassageEmbeddingsServiceController(
base::test::TestFuture<bool>* model_info_future)
: model_info_received_future_(model_info_future) {}
~FakePassageEmbeddingsServiceController() override = default;
// passage_embeddings::PassageEmbeddingsServiceController:
bool MaybeUpdateModelInfo(
base::optional_ref<const optimization_guide::ModelInfo> model_info)
override {
const bool received_model_info = model_info.has_value();
model_info_received_future_->SetValue(received_model_info);
return received_model_info;
}
void MaybeLaunchService() override {}
void ResetServiceRemote() override {}
protected:
raw_ptr<base::test::TestFuture<bool>> model_info_received_future_;
};
class TestOptimizationGuideModelProvider
: public optimization_guide::TestOptimizationGuideModelProvider {
public:
explicit TestOptimizationGuideModelProvider(
base::test::TestFuture<bool>* target_observed_future)
: target_observed_future_(target_observed_future) {}
// optimization_guide::OptimizationGuideModelProvider:
void AddObserverForOptimizationTargetModel(
optimization_guide::proto::OptimizationTarget optimization_target,
const std::optional<optimization_guide::proto::Any>& model_metadata,
scoped_refptr<base::SequencedTaskRunner> model_task_runner,
optimization_guide::OptimizationTargetModelObserver* observer) override {
target_observed_future_->SetValue(
optimization_target ==
optimization_guide::proto::OPTIMIZATION_TARGET_PASSAGE_EMBEDDER);
observer_list_.AddObserver(observer);
NotifyObservers();
}
void RemoveObserverForOptimizationTargetModel(
optimization_guide::proto::OptimizationTarget optimization_target,
optimization_guide::OptimizationTargetModelObserver* observer) override {
observer_list_.RemoveObserver(observer);
}
// Set the model info to be sent to the observer.
void SetModelInfo(std::unique_ptr<optimization_guide::ModelInfo> model_info) {
model_info_ = std::move(model_info);
NotifyObservers();
}
private:
void NotifyObservers() {
if (model_info_) {
observer_list_.Notify(
&optimization_guide::OptimizationTargetModelObserver::OnModelUpdated,
optimization_guide::proto::OPTIMIZATION_TARGET_PASSAGE_EMBEDDER,
*model_info_);
} else {
observer_list_.Notify(
&optimization_guide::OptimizationTargetModelObserver::OnModelUpdated,
optimization_guide::proto::OPTIMIZATION_TARGET_PASSAGE_EMBEDDER,
std::nullopt);
}
}
base::test::TaskEnvironment task_environment_;
raw_ptr<base::test::TestFuture<bool>> target_observed_future_;
base::ObserverList<optimization_guide::OptimizationTargetModelObserver>
observer_list_;
std::unique_ptr<optimization_guide::ModelInfo> model_info_;
};
// Test fixture owning the futures that the fake model provider and the fake
// service controller fulfill, so each test can inspect what the observer under
// test did.
class PassageEmbedderModelObserverTest : public testing::Test {
 protected:
  // Set by the fake provider to whether the observed optimization target was
  // OPTIMIZATION_TARGET_PASSAGE_EMBEDDER.
  base::test::TestFuture<bool> target_observed_future_{};

  // Set by the fake service controller to whether the latest model-info
  // update carried a value.
  base::test::TestFuture<bool> model_info_received_future_{};
};
// Checks that the observer registers for the passage embedder target, first
// forwards an empty model-info update, and then forwards a valid model once
// the provider supplies one.
TEST_F(PassageEmbedderModelObserverTest, ObservesTargetAndNotifiesObserver) {
  auto provider = std::make_unique<TestOptimizationGuideModelProvider>(
      &target_observed_future_);
  // Creating the provider alone must not report any observed target.
  EXPECT_FALSE(target_observed_future_.IsReady());

  auto controller = std::make_unique<FakePassageEmbeddingsServiceController>(
      &model_info_received_future_);
  // Creating the controller alone must not report any model-info update.
  EXPECT_FALSE(model_info_received_future_.IsReady());

  // Constructing the observer is expected to register it with `provider`,
  // which answers synchronously with the (still empty) model info.
  auto observer = std::make_unique<PassageEmbedderModelObserver>(
      provider.get(), controller.get(), false);

  EXPECT_TRUE(target_observed_future_.IsReady());
  const bool observed_passage_embedder_target = target_observed_future_.Take();
  EXPECT_TRUE(observed_passage_embedder_target);

  EXPECT_TRUE(model_info_received_future_.IsReady());
  const bool had_initial_model_info = model_info_received_future_.Take();
  EXPECT_FALSE(had_initial_model_info);

  // Supplying a valid model must reach the controller synchronously.
  provider->SetModelInfo(GetBuilderWithValidModelInfo().Build());
  EXPECT_TRUE(model_info_received_future_.IsReady());
  const bool received_valid_model_info = model_info_received_future_.Take();
  EXPECT_TRUE(received_valid_model_info);
}
} // namespace passage_embeddings