blob: a2ae54ad3d7ef4e50f062a6aa5de27839b3a5725 [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/new_tab_page/modules/history_clusters/ranking/history_clusters_module_ranker.h"
#include "base/feature_list.h"
#include "base/ranges/algorithm.h"
#include "chrome/browser/cart/cart_db.h"
#include "chrome/browser/cart/cart_service.h"
#include "chrome/browser/new_tab_page/modules/history_clusters/history_clusters_module_util.h"
#include "chrome/browser/new_tab_page/new_tab_page_util.h"
#include "components/history_clusters/core/history_clusters_util.h"
#include "components/optimization_guide/core/optimization_guide_model_provider.h"
#include "components/optimization_guide/machine_learning_tflite_buildflags.h"
#include "components/search/ntp_features.h"
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
#include "chrome/browser/new_tab_page/modules/history_clusters/ranking/history_clusters_module_ranking_model_handler.h"
#include "chrome/browser/new_tab_page/modules/history_clusters/ranking/history_clusters_module_ranking_signals.h"
#endif
// Builds a ranker that keeps `cart_service` and a copy of
// `category_boostlist` for later ranking calls.
//
// When the build includes the TFLite library and `model_provider` is non-null,
// a ranking model handler is created from it. Otherwise no handler is created
// and `model_provider` is not used.
HistoryClustersModuleRanker::HistoryClustersModuleRanker(
    optimization_guide::OptimizationGuideModelProvider* model_provider,
    CartService* cart_service,
    const base::flat_set<std::string>& category_boostlist)
    : cart_service_(cart_service), category_boostlist_(category_boostlist) {
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
  // No provider means there is nothing to construct the handler from.
  if (!model_provider) {
    return;
  }
  model_handler_ = std::make_unique<HistoryClustersModuleRankingModelHandler>(
      model_provider);
#endif
}
// Destructor defaulted in this translation unit; it performs no work beyond
// destroying the members.
HistoryClustersModuleRanker::~HistoryClustersModuleRanker() = default;
// Entry point for ranking `clusters`; the ranked result is eventually handed
// to `callback`.
//
// If the cart module is enabled and a cart service was supplied, the active
// carts are requested first and ranking resumes in OnAllSignalsReady() with
// whatever the cart service reports. Otherwise ranking resumes right away with
// `success` set to false and an empty cart list.
void HistoryClustersModuleRanker::RankClusters(
    std::vector<history::Cluster> clusters,
    ClustersCallback callback) {
  // Nothing to fetch: continue synchronously without cart data.
  if (!IsCartModuleEnabled() || !cart_service_) {
    OnAllSignalsReady(std::move(clusters), std::move(callback),
                      /*success=*/false, /*active_carts=*/{});
    return;
  }

  // The continuation is bound to a weak pointer to this ranker and takes
  // ownership of both the clusters and the caller's callback.
  auto on_carts_loaded = base::BindOnce(
      &HistoryClustersModuleRanker::OnAllSignalsReady,
      weak_ptr_factory_.GetWeakPtr(), std::move(clusters), std::move(callback));
  cart_service_->LoadAllActiveCarts(std::move(on_carts_loaded));
}
// Continues ranking once the cart lookup (if any) has finished.
//
// In TFLite builds, when a model handler exists and reports that its model can
// be executed, one ranking-signals entry is built per cluster from
// `active_carts`, the category boost list and the cluster itself. The batch is
// handed to the model, and OnBatchModelExecutionComplete() receives the scores.
// In every other case the heuristic fallback runs immediately.
//
// `success` is accepted to match the cart-loading callback shape but is not
// consulted in this function.
void HistoryClustersModuleRanker::OnAllSignalsReady(
    std::vector<history::Cluster> clusters,
    ClustersCallback callback,
    bool success,
    std::vector<CartDB::KeyAndValue> active_carts) {
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
  const bool model_usable =
      model_handler_ && model_handler_->CanExecuteAvailableModel();
  if (model_usable) {
    // Signals are built before `clusters` is moved into the bound callback.
    std::vector<HistoryClustersModuleRankingSignals> signals_per_cluster;
    signals_per_cluster.reserve(clusters.size());
    for (const history::Cluster& cluster : clusters) {
      signals_per_cluster.emplace_back(active_carts, category_boostlist_,
                                       cluster);
    }

    auto on_scores_ready = base::BindOnce(
        &HistoryClustersModuleRanker::OnBatchModelExecutionComplete,
        weak_ptr_factory_.GetWeakPtr(), std::move(clusters),
        std::move(callback));
    model_handler_->ExecuteBatch(signals_per_cluster,
                                 std::move(on_scores_ready));
    return;
  }
#endif

  // No usable model (or TFLite not compiled in): rank heuristically.
  RunFallbackHeuristic(std::move(clusters), std::move(callback));
}
// Ranks `clusters` without a model: the vector is passed to
// SortClustersUsingHeuristic() together with the category boost list, and the
// resulting vector is then delivered to `callback`. `callback` is run
// synchronously, before this function returns.
void HistoryClustersModuleRanker::RunFallbackHeuristic(
    std::vector<history::Cluster> clusters,
    ClustersCallback callback) {
  SortClustersUsingHeuristic(category_boostlist_, clusters);

  // Ownership of the vector passes to the callback.
  std::move(callback).Run(std::move(clusters));
}
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
// Test-only hook: replaces the current model handler (if any) with
// `model_handler`, taking ownership of it. Passing a null pointer leaves the
// ranker without a handler.
void HistoryClustersModuleRanker::OverrideModelHandlerForTesting(
    std::unique_ptr<HistoryClustersModuleRankingModelHandler> model_handler) {
  model_handler_ = std::move(model_handler);
}
// Orders `clusters` by the model scores in `outputs` and delivers the result
// to `callback`.
//
// `outputs[i]` is treated as the score of `clusters[i]`, so both vectors must
// have the same length (enforced with a CHECK). Clusters are emitted in
// ascending score order. The sort is stable, so clusters with equal scores
// keep their incoming relative order. Every cluster is forwarded; nothing is
// dropped here.
void HistoryClustersModuleRanker::OnBatchModelExecutionComplete(
    std::vector<history::Cluster> clusters,
    ClustersCallback callback,
    std::vector<float> outputs) {
  CHECK_EQ(clusters.size(), outputs.size());

  // Pair each cluster with its score so the two travel together through the
  // sort.
  std::vector<std::tuple<history::Cluster, float>> clusters_with_scores;
  clusters_with_scores.reserve(clusters.size());
  for (size_t i = 0; i < clusters.size(); i++) {
    clusters_with_scores.emplace_back(std::move(clusters[i]), outputs[i]);
  }

  // Sort clusters by model score (ascending).
  base::ranges::stable_sort(clusters_with_scores,
                            [](const auto& c1, const auto& c2) {
                              return std::get<float>(c1) < std::get<float>(c2);
                            });

  // Strip the scores off again. The final size is known up front, so reserve
  // once instead of growing the vector repeatedly.
  std::vector<history::Cluster> output_clusters;
  output_clusters.reserve(clusters_with_scores.size());
  for (auto& cluster_and_score : clusters_with_scores) {
    output_clusters.push_back(
        std::move(std::get<history::Cluster>(cluster_and_score)));
  }

  std::move(callback).Run(std::move(output_clusters));
}
#endif