blob: b0a1ea3067b3dd4beb19a90ed755ea164cbdb912 [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/navigation_predictor/preloading_model_keyed_service.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h"
#include "components/optimization_guide/machine_learning_tflite_buildflags.h"
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
#include "chrome/browser/navigation_predictor/preloading_model_handler.h"
#endif
// Out-of-line definition of the defaulted default constructor for the nested
// Inputs type; no custom initialization is performed here.
PreloadingModelKeyedService::Inputs::Inputs() = default;
// Constructs the service from |optimization_guide_keyed_service|.
//
// In builds with BUILD_WITH_TFLITE_LIB, a PreloadingModelHandler is created
// when the supplied service is non-null; with a null service the handler is
// left unset. In builds without TFLite the constructor does nothing.
PreloadingModelKeyedService::PreloadingModelKeyedService(
    OptimizationGuideKeyedService* optimization_guide_keyed_service) {
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
  // The handler is constructed from the model-provider view of the service.
  optimization_guide::OptimizationGuideModelProvider* provider =
      static_cast<optimization_guide::OptimizationGuideModelProvider*>(
          optimization_guide_keyed_service);
  if (provider == nullptr) {
    return;
  }
  preloading_model_handler_ =
      std::make_unique<PreloadingModelHandler>(provider);
#endif
}
// Out-of-line definition of the defaulted destructor; no custom teardown is
// performed beyond the compiler-generated member destruction.
PreloadingModelKeyedService::~PreloadingModelKeyedService() = default;
// Forwards |callback| to the model handler's AddOnModelUpdatedCallback().
//
// The callback is dropped without being run when no handler exists, and in
// builds without BUILD_WITH_TFLITE_LIB.
void PreloadingModelKeyedService::AddOnModelUpdatedCallback(
    base::OnceClosure callback) {
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
  if (preloading_model_handler_) {
    preloading_model_handler_->AddOnModelUpdatedCallback(std::move(callback));
  }
#endif
}
// Packs |inputs| into the model's flat float feature vector and asks the
// model handler to execute it, delivering the outcome via |result_callback|.
//
// |tracker| is handed through to the handler's ExecuteModelWithInput().
// |result_callback| is run immediately with absl::nullopt when there is no
// handler, when the handler reports that no model is available, or when the
// build does not include TFLite.
void PreloadingModelKeyedService::Score(base::CancelableTaskTracker* tracker,
                                        const Inputs& inputs,
                                        ResultCallback result_callback) {
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
  const bool can_execute = preloading_model_handler_ &&
                           preloading_model_handler_->ModelAvailable();
  if (can_execute) {
    // Durations are passed to the model as (fractional) milliseconds.
    const float navigation_start_to_link_logged_ms = static_cast<float>(
        inputs.navigation_start_to_link_logged.InMillisecondsF());
    const float entered_viewport_to_left_viewport_ms = static_cast<float>(
        inputs.entered_viewport_to_left_viewport.InMillisecondsF());
    const float hover_dwell_time_ms =
        static_cast<float>(inputs.hover_dwell_time.InMillisecondsF());

    // Each feature sits at the `input N` index noted beside it; boolean
    // features are encoded as 1.0 / 0.0.
    std::vector<float> features{
        /* input 0 */ inputs.contains_image ? 1.0f : 0.0f,
        /* input 1 */ static_cast<float>(inputs.font_size),
        /* input 2 */ inputs.has_text_sibling ? 1.0f : 0.0f,
        /* input 3 */ inputs.is_bold ? 1.0f : 0.0f,
        /* input 4 */ inputs.is_in_iframe ? 1.0f : 0.0f,
        /* input 5 */ inputs.is_url_incremented_by_one ? 1.0f : 0.0f,
        /* input 6 */ navigation_start_to_link_logged_ms,
        /* input 7 */ static_cast<float>(inputs.path_depth),
        /* input 8 */ static_cast<float>(inputs.path_length),
        /* input 9 */ static_cast<float>(inputs.percent_clickable_area),
        /* input 10*/ static_cast<float>(inputs.percent_vertical_distance),
        /* input 11*/ inputs.is_same_origin ? 1.0f : 0.0f,
        /* input 12*/ inputs.is_in_viewport ? 1.0f : 0.0f,
        /* input 13*/ inputs.is_pointer_hovering_over ? 1.0f : 0.0f,
        /* input 14*/ entered_viewport_to_left_viewport_ms,
        /* input 15*/ hover_dwell_time_ms,
        /* input 16*/ static_cast<float>(inputs.pointer_hovering_over_count)};

    preloading_model_handler_->ExecuteModelWithInput(
        tracker, std::move(result_callback), features);
    return;
  }
#endif
  // No model can be executed: report "no score" to the caller right away.
  std::move(result_callback).Run(absl::nullopt);
}