blob: f7c34a2be83bced362c456057dec4ef07c113da2 [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/new_tab_page/modules/history_clusters/ranking/history_clusters_module_ranker.h"
#include "base/run_loop.h"
#include "base/test/gmock_callback_support.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/task_environment.h"
#include "chrome/browser/cart/cart_service.h"
#include "chrome/browser/history/history_service_factory.h"
#include "chrome/test/base/testing_profile.h"
#include "components/history_clusters/core/clustering_test_utils.h"
#include "components/history_clusters/core/history_clusters_util.h"
#include "components/optimization_guide/core/test_optimization_guide_model_provider.h"
#include "components/optimization_guide/machine_learning_tflite_buildflags.h"
#include "components/search/ntp_features.h"
#include "content/public/test/browser_task_environment.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
#include "chrome/browser/new_tab_page/modules/history_clusters/ranking/history_clusters_module_ranking_model_handler.h"
#include "chrome/browser/new_tab_page/modules/history_clusters/ranking/history_clusters_module_ranking_signals.h"
#endif
namespace {
using ::testing::ElementsAre;
class MockCartService : public CartService {
public:
explicit MockCartService(Profile* profile) : CartService(profile) {}
MOCK_METHOD1(LoadAllActiveCarts, void(CartDB::LoadCallback callback));
};
// Test fixture that owns the task environment and offers a blocking wrapper
// around HistoryClustersModuleRanker::RankClusters().
class HistoryClustersModuleRankerTest : public testing::Test {
 public:
  HistoryClustersModuleRankerTest() = default;
  ~HistoryClustersModuleRankerTest() override = default;

  // Stable-sorts the visits of every cluster in `in_clusters`, hands the
  // clusters to `ranker`, spins a RunLoop until the ranker invokes its
  // callback, and returns the clusters that the callback received.
  std::vector<history::Cluster> RankClusters(
      HistoryClustersModuleRanker* ranker,
      std::vector<history::Cluster> in_clusters) {
    // Visits inside each cluster are sorted before ranking.
    for (history::Cluster& candidate : in_clusters) {
      history_clusters::StableSortVisits(candidate.visits);
    }

    std::vector<history::Cluster> ranked;
    base::RunLoop waiter;
    ranker->RankClusters(
        std::move(in_clusters),
        base::BindOnce(
            [](std::vector<history::Cluster>* sink, base::RunLoop* loop,
               std::vector<history::Cluster> result) {
              // Stash the result for the caller, then unblock Run() below.
              *sink = std::move(result);
              loop->Quit();
            },
            &ranked, &waiter));
    waiter.Run();
    return ranked;
  }

 private:
  content::BrowserTaskEnvironment task_environment_;
};
// Ranks two nearly identical clusters with no model provider, no cart service
// and no boosted categories, then checks the resulting cluster and visit
// order.
TEST_F(HistoryClustersModuleRankerTest, RecencyOnly) {
  namespace hc_test = history_clusters::testing;

  // Visit 1: synced page visit with an image and two categories.
  history::AnnotatedVisit page_visit =
      hc_test::CreateDefaultAnnotatedVisit(1, GURL("https://github.com/"));
  page_visit.content_annotations.model_annotations.categories = {
      {"category1", 90}, {"category2", 84}};
  page_visit.content_annotations.has_url_keyed_image = true;
  page_visit.visit_row.is_known_to_sync = true;

  // Visit 2: search visit with an explicit visit time.
  history::AnnotatedVisit search_visit =
      hc_test::CreateDefaultAnnotatedVisit(2, GURL("https://search.com/"));
  search_visit.content_annotations.related_searches = {"relsearch1",
                                                       "relsearch2"};
  search_visit.content_annotations.search_terms = u"search";
  search_visit.visit_row.visit_time = base::Time::FromTimeT(3);

  // Visit 4: second synced page visit with an image and two categories.
  history::AnnotatedVisit other_page_visit =
      hc_test::CreateDefaultAnnotatedVisit(4, GURL("https://github.com/2"));
  other_page_visit.visit_row.is_known_to_sync = true;
  other_page_visit.content_annotations.has_url_keyed_image = true;
  other_page_visit.content_annotations.model_annotations.categories = {
      {"category1", 85}, {"category3", 82}};

  history::Cluster first_cluster;
  first_cluster.cluster_id = 1;
  first_cluster.visits.push_back(hc_test::CreateClusterVisit(
      page_visit, /*normalized_url=*/absl::nullopt, 0.1));
  first_cluster.visits.push_back(hc_test::CreateClusterVisit(
      search_visit, /*normalized_url=*/absl::nullopt, 1.0));
  first_cluster.visits.push_back(hc_test::CreateClusterVisit(
      other_page_visit, /*normalized_url=*/absl::nullopt, 0.3));

  // The second cluster is a copy whose search visit gets a distinct ID (123)
  // and a later visit time (10 vs. 3), so the two clusters can be told apart
  // in the output. The expectation below lists this cluster first.
  history::Cluster second_cluster = first_cluster;
  auto& distinguishing_row = second_cluster.visits[1].annotated_visit.visit_row;
  distinguishing_row.visit_id = 123;
  distinguishing_row.visit_time = base::Time::FromTimeT(10);

  base::flat_set<std::string> boosted_categories = {};
  HistoryClustersModuleRanker ranker(
      /*optimization_guide_model_provider=*/nullptr, /*cart_service=*/nullptr,
      boosted_categories);

  std::vector<history::Cluster> ranked =
      RankClusters(&ranker, {first_cluster, second_cluster});

  EXPECT_THAT(
      hc_test::ToVisitResults(ranked),
      ElementsAre(
          ElementsAre(hc_test::VisitResult(123, 1.0, {}, u"search"),
                      hc_test::VisitResult(4, 0.3),
                      hc_test::VisitResult(1, 0.1)),
          ElementsAre(hc_test::VisitResult(2, 1.0, {}, u"search"),
                      hc_test::VisitResult(4, 0.3),
                      hc_test::VisitResult(1, 0.1))));
}
// Ranks three clusters with a non-empty boosted-category set and checks that
// the two clusters carrying the "boosted" category are listed ahead of the
// first cluster.
TEST_F(HistoryClustersModuleRankerTest, WithCategoryBoosting) {
  namespace hc_test = history_clusters::testing;

  // --- Cluster 1 -----------------------------------------------------------
  // Its only boosted category ("boostedbuthidden") sits on the visit that is
  // given a score of 0.0 below.
  history::AnnotatedVisit page_visit =
      hc_test::CreateDefaultAnnotatedVisit(1, GURL("https://github.com/"));
  page_visit.content_annotations.model_annotations.categories = {
      {"category1", 90}, {"boostedbuthidden", 84}};
  page_visit.content_annotations.has_url_keyed_image = true;
  page_visit.visit_row.is_known_to_sync = true;

  history::AnnotatedVisit search_visit =
      hc_test::CreateDefaultAnnotatedVisit(2, GURL("https://search.com/"));
  search_visit.content_annotations.related_searches = {"relsearch1",
                                                       "relsearch2"};
  search_visit.content_annotations.search_terms = u"search";
  search_visit.visit_row.visit_time = base::Time::FromTimeT(100);

  history::AnnotatedVisit other_page_visit =
      hc_test::CreateDefaultAnnotatedVisit(4, GURL("https://github.com/2"));
  other_page_visit.visit_row.is_known_to_sync = true;
  other_page_visit.content_annotations.has_url_keyed_image = true;
  other_page_visit.content_annotations.model_annotations.categories = {
      {"category1", 85}, {"category3", 82}};

  history::Cluster first_cluster;
  first_cluster.cluster_id = 1;
  first_cluster.visits.push_back(hc_test::CreateClusterVisit(
      page_visit, /*normalized_url=*/absl::nullopt, 0.0));
  first_cluster.visits.push_back(hc_test::CreateClusterVisit(
      search_visit, /*normalized_url=*/absl::nullopt, 1.0));
  first_cluster.visits.push_back(hc_test::CreateClusterVisit(
      other_page_visit, /*normalized_url=*/absl::nullopt, 0.3));

  // --- Cluster 2 -----------------------------------------------------------
  // Carries the "boosted" category; its search visit is older (3) than
  // cluster 1's search visit (100).
  history::AnnotatedVisit boosted_page_visit =
      hc_test::CreateDefaultAnnotatedVisit(111, GURL("https://github.com/"));
  boosted_page_visit.content_annotations.model_annotations.categories = {
      {"category1", 90}, {"boosted", 84}};
  boosted_page_visit.content_annotations.has_url_keyed_image = true;
  boosted_page_visit.visit_row.is_known_to_sync = true;

  history::AnnotatedVisit boosted_search_visit =
      hc_test::CreateDefaultAnnotatedVisit(222, GURL("https://search.com/"));
  boosted_search_visit.content_annotations.related_searches = {"relsearch1",
                                                               "relsearch2"};
  boosted_search_visit.content_annotations.search_terms = u"search";
  boosted_search_visit.visit_row.visit_time = base::Time::FromTimeT(3);

  history::AnnotatedVisit boosted_other_visit =
      hc_test::CreateDefaultAnnotatedVisit(444, GURL("https://github.com/2"));
  boosted_other_visit.visit_row.is_known_to_sync = true;
  boosted_other_visit.content_annotations.has_url_keyed_image = true;
  boosted_other_visit.content_annotations.model_annotations.categories = {
      {"category1", 85}, {"category3", 82}};

  history::Cluster second_cluster;
  second_cluster.cluster_id = 2;
  second_cluster.visits.push_back(hc_test::CreateClusterVisit(
      boosted_page_visit, /*normalized_url=*/absl::nullopt, 0.8));
  second_cluster.visits.push_back(hc_test::CreateClusterVisit(
      boosted_search_visit, /*normalized_url=*/absl::nullopt, 1.0));
  second_cluster.visits.push_back(hc_test::CreateClusterVisit(
      boosted_other_visit, /*normalized_url=*/absl::nullopt, 0.6));

  // --- Cluster 3 -----------------------------------------------------------
  // Copy of cluster 2 with every visit ID bumped by one (to tell the clusters
  // apart) and every visit moved to an earlier time.
  history::Cluster third_cluster = second_cluster;
  third_cluster.cluster_id = 3;
  for (auto& entry : third_cluster.visits) {
    auto& row = entry.annotated_visit.visit_row;
    ++row.visit_id;
    row.visit_time = base::Time::FromTimeT(1);
  }

  base::flat_set<std::string> boosted_categories = {"boosted",
                                                    "boostedbuthidden"};
  HistoryClustersModuleRanker ranker(
      /*optimization_guide_model_provider=*/nullptr, /*cart_service=*/nullptr,
      boosted_categories);

  std::vector<history::Cluster> ranked =
      RankClusters(&ranker, {first_cluster, second_cluster, third_cluster});

  // Clusters 2 and 3 are expected first because they contain a boosted
  // category, even though their visits are older than cluster 1's. Between
  // those two, time is still the tiebreaker, and visits within a cluster are
  // ordered by score.
  EXPECT_THAT(
      hc_test::ToVisitResults(ranked),
      ElementsAre(
          ElementsAre(hc_test::VisitResult(222, 1.0, {}, u"search"),
                      hc_test::VisitResult(111, 0.8),
                      hc_test::VisitResult(444, 0.6)),
          ElementsAre(hc_test::VisitResult(223, 1.0, {}, u"search"),
                      hc_test::VisitResult(112, 0.8),
                      hc_test::VisitResult(445, 0.6)),
          ElementsAre(hc_test::VisitResult(2, 1.0, {}, u"search"),
                      hc_test::VisitResult(4, 0.3),
                      hc_test::VisitResult(1, 0.0))));
}
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
// Fake ranking model handler for tests. It always reports that a model can be
// executed and computes each output directly from two fields of the input
// signals, running the callback synchronously.
class FakeModelHandler : public HistoryClustersModuleRankingModelHandler {
 public:
  explicit FakeModelHandler(
      optimization_guide::OptimizationGuideModelProvider* provider)
      : HistoryClustersModuleRankingModelHandler(provider) {}
  ~FakeModelHandler() override = default;

  // Always claims a model is available.
  bool CanExecuteAvailableModel() override { return true; }

  // Produces one output per entry in `inputs`, in the same order:
  //   (1 if the input belongs to a boosted category, else 0)
  //     + whole minutes since the input's most recent visit.
  // `callback` is run with the outputs before this method returns.
  void ExecuteBatch(
      const std::vector<HistoryClustersModuleRankingSignals>& inputs,
      ExecuteBatchCallback callback) override {
    std::vector<float> outputs;
    outputs.reserve(inputs.size());
    for (const auto& input : inputs) {
      const float boost_term = input.belongs_to_boosted_category ? 1.0f : 0.0f;
      const float minutes_term = static_cast<float>(
          input.duration_since_most_recent_visit.InMinutes());
      outputs.push_back(boost_term + minutes_term);
    }
    std::move(callback).Run(std::move(outputs));
  }
};
// Fixture variant that enables the model-ranking and Chrome Cart module
// features for as long as the fixture is alive.
class HistoryClustersModuleRankerWithModelTest
    : public HistoryClustersModuleRankerTest {
 public:
  HistoryClustersModuleRankerWithModelTest() {
    scoped_feature_list_.InitWithFeatures(
        /*enabled_features=*/
        {ntp_features::kNtpHistoryClustersModuleUseModelRanking,
         ntp_features::kNtpChromeCartModule},
        /*disabled_features=*/{});
  }

 private:
  base::test::ScopedFeatureList scoped_feature_list_;
};
// With the model-ranking feature enabled and a test model provider supplied,
// but no model handler override installed, the expected ordering is the same
// as in the RecencyOnly test.
TEST_F(HistoryClustersModuleRankerWithModelTest,
       ModelNotAvailableUsesFallback) {
  namespace hc_test = history_clusters::testing;

  // Visit 1: synced page visit with an image and two categories.
  history::AnnotatedVisit page_visit =
      hc_test::CreateDefaultAnnotatedVisit(1, GURL("https://github.com/"));
  page_visit.content_annotations.model_annotations.categories = {
      {"category1", 90}, {"category2", 84}};
  page_visit.content_annotations.has_url_keyed_image = true;
  page_visit.visit_row.is_known_to_sync = true;

  // Visit 2: search visit with an explicit visit time.
  history::AnnotatedVisit search_visit =
      hc_test::CreateDefaultAnnotatedVisit(2, GURL("https://search.com/"));
  search_visit.content_annotations.related_searches = {"relsearch1",
                                                       "relsearch2"};
  search_visit.content_annotations.search_terms = u"search";
  search_visit.visit_row.visit_time = base::Time::FromTimeT(3);

  // Visit 4: second synced page visit with an image and two categories.
  history::AnnotatedVisit other_page_visit =
      hc_test::CreateDefaultAnnotatedVisit(4, GURL("https://github.com/2"));
  other_page_visit.visit_row.is_known_to_sync = true;
  other_page_visit.content_annotations.has_url_keyed_image = true;
  other_page_visit.content_annotations.model_annotations.categories = {
      {"category1", 85}, {"category3", 82}};

  history::Cluster first_cluster;
  first_cluster.cluster_id = 1;
  first_cluster.visits.push_back(hc_test::CreateClusterVisit(
      page_visit, /*normalized_url=*/absl::nullopt, 0.1));
  first_cluster.visits.push_back(hc_test::CreateClusterVisit(
      search_visit, /*normalized_url=*/absl::nullopt, 1.0));
  first_cluster.visits.push_back(hc_test::CreateClusterVisit(
      other_page_visit, /*normalized_url=*/absl::nullopt, 0.3));

  // The second cluster is a copy whose search visit gets a distinct ID (123)
  // and a later visit time (10 vs. 3), so the two clusters can be told apart
  // in the output. The expectation below lists this cluster first.
  history::Cluster second_cluster = first_cluster;
  auto& distinguishing_row = second_cluster.visits[1].annotated_visit.visit_row;
  distinguishing_row.visit_id = 123;
  distinguishing_row.visit_time = base::Time::FromTimeT(10);

  optimization_guide::TestOptimizationGuideModelProvider model_provider;
  base::flat_set<std::string> boosted_categories = {};
  HistoryClustersModuleRanker ranker(&model_provider, /*cart_service=*/nullptr,
                                     boosted_categories);

  std::vector<history::Cluster> ranked =
      RankClusters(&ranker, {first_cluster, second_cluster});

  EXPECT_THAT(
      hc_test::ToVisitResults(ranked),
      ElementsAre(
          ElementsAre(hc_test::VisitResult(123, 1.0, {}, u"search"),
                      hc_test::VisitResult(4, 0.3),
                      hc_test::VisitResult(1, 0.1)),
          ElementsAre(hc_test::VisitResult(2, 1.0, {}, u"search"),
                      hc_test::VisitResult(4, 0.3),
                      hc_test::VisitResult(1, 0.1))));
}
// Installs a FakeModelHandler on the ranker, supplies a mocked cart service
// that reports no carts, and checks the resulting cluster and visit order.
TEST_F(HistoryClustersModuleRankerWithModelTest, ModelAvailable) {
  namespace hc_test = history_clusters::testing;

  // Reference point for the relative visit times used below.
  const base::Time now = base::Time::Now();

  // --- Cluster 1 -----------------------------------------------------------
  // Its search visit uses a fixed timestamp rather than one relative to `now`.
  history::AnnotatedVisit page_visit =
      hc_test::CreateDefaultAnnotatedVisit(1, GURL("https://github.com/"));
  page_visit.content_annotations.model_annotations.categories = {
      {"category1", 90}, {"boostedbuthidden", 84}};
  page_visit.content_annotations.has_url_keyed_image = true;
  page_visit.visit_row.is_known_to_sync = true;

  history::AnnotatedVisit search_visit =
      hc_test::CreateDefaultAnnotatedVisit(2, GURL("https://search.com/"));
  search_visit.content_annotations.related_searches = {"relsearch1",
                                                       "relsearch2"};
  search_visit.content_annotations.search_terms = u"search";
  search_visit.visit_row.visit_time = base::Time::FromTimeT(100);

  history::AnnotatedVisit other_page_visit =
      hc_test::CreateDefaultAnnotatedVisit(4, GURL("https://github.com/2"));
  other_page_visit.visit_row.is_known_to_sync = true;
  other_page_visit.content_annotations.has_url_keyed_image = true;
  other_page_visit.content_annotations.model_annotations.categories = {
      {"category1", 85}, {"category3", 82}};

  history::Cluster first_cluster;
  first_cluster.cluster_id = 1;
  first_cluster.visits.push_back(hc_test::CreateClusterVisit(
      page_visit, /*normalized_url=*/absl::nullopt, 0.0));
  first_cluster.visits.push_back(hc_test::CreateClusterVisit(
      search_visit, /*normalized_url=*/absl::nullopt, 1.0));
  first_cluster.visits.push_back(hc_test::CreateClusterVisit(
      other_page_visit, /*normalized_url=*/absl::nullopt, 0.3));

  // --- Cluster 2 -----------------------------------------------------------
  history::AnnotatedVisit boosted_page_visit =
      hc_test::CreateDefaultAnnotatedVisit(111, GURL("https://github.com/"));
  boosted_page_visit.content_annotations.model_annotations.categories = {
      {"category1", 90}, {"boosted", 84}};
  boosted_page_visit.content_annotations.has_url_keyed_image = true;
  boosted_page_visit.visit_row.is_known_to_sync = true;

  history::AnnotatedVisit boosted_search_visit =
      hc_test::CreateDefaultAnnotatedVisit(222, GURL("https://search.com/"));
  boosted_search_visit.content_annotations.related_searches = {"relsearch1",
                                                               "relsearch2"};
  boosted_search_visit.content_annotations.search_terms = u"search";
  // This visit happened 3 minutes ago.
  boosted_search_visit.visit_row.visit_time = now - base::Minutes(3);

  history::AnnotatedVisit boosted_other_visit =
      hc_test::CreateDefaultAnnotatedVisit(444, GURL("https://github.com/2"));
  boosted_other_visit.visit_row.is_known_to_sync = true;
  boosted_other_visit.content_annotations.has_url_keyed_image = true;
  boosted_other_visit.content_annotations.model_annotations.categories = {
      {"category1", 85}, {"category3", 82}};

  history::Cluster second_cluster;
  second_cluster.cluster_id = 2;
  second_cluster.visits.push_back(hc_test::CreateClusterVisit(
      boosted_page_visit, /*normalized_url=*/absl::nullopt, 0.8));
  second_cluster.visits.push_back(hc_test::CreateClusterVisit(
      boosted_search_visit, /*normalized_url=*/absl::nullopt, 1.0));
  second_cluster.visits.push_back(hc_test::CreateClusterVisit(
      boosted_other_visit, /*normalized_url=*/absl::nullopt, 0.6));

  // --- Cluster 3 -----------------------------------------------------------
  // Copy of cluster 2 with every visit ID bumped by one (to tell the clusters
  // apart) and every visit moved to 1 minute ago.
  history::Cluster third_cluster = second_cluster;
  third_cluster.cluster_id = 3;
  for (auto& entry : third_cluster.visits) {
    auto& row = entry.annotated_visit.visit_row;
    ++row.visit_id;
    row.visit_time = now - base::Minutes(1);
  }

  base::flat_set<std::string> boosted_categories = {"boosted",
                                                    "boostedbuthidden"};
  optimization_guide::TestOptimizationGuideModelProvider model_provider;

  // The cart service needs a profile; the mock answers the single expected
  // LoadAllActiveCarts() call with success and an empty result.
  TestingProfile::Builder profile_builder;
  profile_builder.AddTestingFactory(HistoryServiceFactory::GetInstance(),
                                    HistoryServiceFactory::GetDefaultFactory());
  auto testing_profile = profile_builder.Build();
  auto cart_service = std::make_unique<MockCartService>(testing_profile.get());
  EXPECT_CALL(*cart_service,
              LoadAllActiveCarts(base::test::IsNotNullCallback()))
      .WillOnce(testing::WithArgs<0>(
          testing::Invoke([&](CartDB::LoadCallback callback) -> void {
            std::move(callback).Run(true, {});
          })));

  HistoryClustersModuleRanker ranker(&model_provider, cart_service.get(),
                                     boosted_categories);
  ranker.OverrideModelHandlerForTesting(
      std::make_unique<FakeModelHandler>(&model_provider));

  std::vector<history::Cluster> ranked =
      RankClusters(&ranker, {first_cluster, second_cluster, third_cluster});

  // Expected order: cluster 3 (visits 1 minute ago), then cluster 2 (search
  // visit 3 minutes ago), then cluster 1 (fixed, much older timestamp).
  EXPECT_THAT(
      hc_test::ToVisitResults(ranked),
      ElementsAre(
          ElementsAre(hc_test::VisitResult(223, 1.0, {}, u"search"),
                      hc_test::VisitResult(112, 0.8),
                      hc_test::VisitResult(445, 0.6)),
          ElementsAre(hc_test::VisitResult(222, 1.0, {}, u"search"),
                      hc_test::VisitResult(111, 0.8),
                      hc_test::VisitResult(444, 0.6)),
          ElementsAre(hc_test::VisitResult(2, 1.0, {}, u"search"),
                      hc_test::VisitResult(4, 0.3),
                      hc_test::VisitResult(1, 0.0))));
}
#endif // BUILDFLAG(BUILD_WITH_TFLITE_LIB)
} // namespace