blob: f68542827cee7bb29310cda543507c69ec127259 [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/new_tab_page/modules/history_clusters/ranking/history_clusters_module_ranking_model_handler.h"
#include "base/path_service.h"
#include "base/run_loop.h"
#include "base/test/task_environment.h"
#include "chrome/browser/new_tab_page/modules/history_clusters/ranking/history_clusters_module_ranking_model_metadata.pb.h"
#include "chrome/browser/new_tab_page/modules/history_clusters/ranking/history_clusters_module_ranking_signals.h"
#include "components/optimization_guide/core/test_model_info_builder.h"
#include "components/optimization_guide/core/test_optimization_guide_model_provider.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace {
using ::testing::ElementsAre;
class FakeModelProvider
: public optimization_guide::TestOptimizationGuideModelProvider {
public:
void AddObserverForOptimizationTargetModel(
optimization_guide::proto::OptimizationTarget optimization_target,
const absl::optional<optimization_guide::proto::Any>& model_metadata,
optimization_guide::OptimizationTargetModelObserver* observer) override {
CHECK_EQ(
optimization_target,
optimization_guide::proto::
OPTIMIZATION_TARGET_NEW_TAB_PAGE_HISTORY_CLUSTERS_MODULE_RANKING);
was_registered_ = true;
}
bool was_registered() const { return was_registered_; }
private:
bool was_registered_ = false;
};
class HistoryClustersModuleRankingModelHandlerTest : public testing::Test {
public:
HistoryClustersModuleRankingModelHandlerTest() {
model_provider_ = std::make_unique<FakeModelProvider>();
model_handler_ = std::make_unique<HistoryClustersModuleRankingModelHandler>(
model_provider_.get());
EXPECT_TRUE(model_provider_->was_registered());
// Just use the omnibox fake model since it just adds two floats which
// is all we need right now.
base::FilePath source_root_dir;
base::PathService::Get(base::DIR_SOURCE_ROOT, &source_root_dir);
model_file_path_ = source_root_dir.AppendASCII("components")
.AppendASCII("test")
.AppendASCII("data")
.AppendASCII("omnibox")
.AppendASCII("adder.tflite");
}
~HistoryClustersModuleRankingModelHandlerTest() override {
// Make sure everything is deleted on the bg thread the underlying executor
// is on.
model_handler_.reset();
task_environment_.RunUntilIdle();
}
HistoryClustersModuleRankingModelHandler* model_handler() {
return model_handler_.get();
}
void PushModelFileToModelExecutor(
absl::optional<
new_tab_page::proto::HistoryClustersModuleRankingModelMetadata>
metadata) {
absl::optional<optimization_guide::proto::Any> any;
// Craft a correct Any proto in the case we passed in metadata.
if (metadata.has_value()) {
std::string serialized_metadata;
(*metadata).SerializeToString(&serialized_metadata);
optimization_guide::proto::Any any_proto;
any = absl::make_optional(any_proto);
any->set_value(serialized_metadata);
any->set_type_url(
"type.googleapis.com/"
"new_tab_page.protos.HistoryClustersModuleRankingModelMetadata");
}
auto model_metadata = optimization_guide::TestModelInfoBuilder()
.SetModelMetadata(any)
.SetModelFilePath(model_file_path_)
.SetVersion(123)
.Build();
model_handler()->OnModelUpdated(
optimization_guide::proto::
OPTIMIZATION_TARGET_NEW_TAB_PAGE_HISTORY_CLUSTERS_MODULE_RANKING,
*model_metadata);
task_environment_.RunUntilIdle();
}
std::vector<float> GetOutputs(
std::vector<HistoryClustersModuleRankingSignals> inputs) {
std::vector<float> outputs;
base::RunLoop run_loop;
model_handler()->ExecuteBatch(
inputs, base::BindOnce(
[](base::RunLoop* run_loop, std::vector<float>* out_outputs,
std::vector<float> outputs) {
*out_outputs = std::move(outputs);
run_loop->Quit();
},
&run_loop, &outputs));
run_loop.Run();
return outputs;
}
private:
std::unique_ptr<FakeModelProvider> model_provider_;
std::unique_ptr<HistoryClustersModuleRankingModelHandler> model_handler_;
base::FilePath model_file_path_;
base::test::TaskEnvironment task_environment_;
};
// Before any model update is pushed, the handler has nothing to execute.
TEST_F(HistoryClustersModuleRankingModelHandlerTest, ModelNotAvailable) {
  const bool can_execute = model_handler()->CanExecuteAvailableModel();
  EXPECT_FALSE(can_execute);
}
// A model update that carries no metadata must be rejected, and executing
// against it yields a zero score for every input.
TEST_F(HistoryClustersModuleRankingModelHandlerTest, ModelUpdatedBadMetadata) {
  PushModelFileToModelExecutor(/*metadata=*/absl::nullopt);
  EXPECT_FALSE(model_handler()->CanExecuteAvailableModel());

  std::vector<HistoryClustersModuleRankingSignals> batch(2);
  batch[0].duration_since_most_recent_visit = base::Minutes(2);
  batch[0].belongs_to_boosted_category = false;
  batch[1].duration_since_most_recent_visit = base::Minutes(5);
  batch[1].belongs_to_boosted_category = true;

  // Should return 0 if tried to execute model.
  EXPECT_THAT(GetOutputs(batch), ElementsAre(0, 0));
}
// Metadata whose version is newer than the client's version must be rejected.
TEST_F(HistoryClustersModuleRankingModelHandlerTest,
       ModelUpdatedVersionTooHigh) {
  const auto newer_than_client =
      HistoryClustersModuleRankingSignals::kClientVersion + 1;

  new_tab_page::proto::HistoryClustersModuleRankingModelMetadata metadata;
  metadata.set_version(newer_than_client);
  PushModelFileToModelExecutor(metadata);

  EXPECT_FALSE(model_handler()->CanExecuteAvailableModel());
}
// Metadata that matches the client's version makes the model executable.
TEST_F(HistoryClustersModuleRankingModelHandlerTest,
       ModelUpdatedVersionCorrectVersion) {
  const auto matching_version =
      HistoryClustersModuleRankingSignals::kClientVersion;

  new_tab_page::proto::HistoryClustersModuleRankingModelMetadata metadata;
  metadata.set_version(matching_version);
  PushModelFileToModelExecutor(metadata);

  EXPECT_TRUE(model_handler()->CanExecuteAvailableModel());
}
// With valid metadata listing two signals, the adder test model scores each
// input as the sum of those two signals.
TEST_F(HistoryClustersModuleRankingModelHandlerTest,
       ModelExecutedMultipleInputs) {
  new_tab_page::proto::HistoryClustersModuleRankingModelMetadata metadata;
  metadata.set_version(HistoryClustersModuleRankingSignals::kClientVersion);
  for (auto signal :
       {new_tab_page::proto::
            HISTORY_CLUSTERS_MODULE_RANKING_MINUTES_SINCE_MOST_RECENT_VISIT,
        new_tab_page::proto::
            HISTORY_CLUSTERS_MODULE_RANKING_BELONGS_TO_BOOSTED_CATEGORY}) {
    metadata.add_signals(signal);
  }
  PushModelFileToModelExecutor(metadata);
  EXPECT_TRUE(model_handler()->CanExecuteAvailableModel());

  std::vector<HistoryClustersModuleRankingSignals> batch(2);
  batch[0].duration_since_most_recent_visit = base::Minutes(2);
  batch[0].belongs_to_boosted_category = false;
  batch[1].duration_since_most_recent_visit = base::Minutes(5);
  batch[1].belongs_to_boosted_category = true;

  // Expected: minutes since visit plus the boosted-category flag (as 0 or 1).
  EXPECT_THAT(GetOutputs(batch), ElementsAre(2, 5 + 1));
}
} // namespace