// Source blob: 29934afa83f10f3bcf4a43779a8d192b7ff7c504
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_TAIL_MODEL_SERVICE_H_
#define COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_TAIL_MODEL_SERVICE_H_
#include <memory>
#include <vector>

#include "base/functional/callback.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/task/sequenced_task_runner.h"
#include "base/timer/timer.h"
#include "base/types/optional_ref.h"
#include "components/keyed_service/core/keyed_service.h"
#include "components/omnibox/browser/on_device_tail_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_model_provider.h"
// Keyed service that holds the on-device tail model executor and implements
// optimization_guide::OptimizationTargetModelObserver so that it can be
// notified of model updates.
//
// The class is neither copyable nor assignable.
class OnDeviceTailModelService
: public KeyedService,
public optimization_guide::OptimizationTargetModelObserver {
public:
// One-shot callback that receives the predictions produced for an input.
using ResultCallback = base::OnceCallback<void(
std::vector<OnDeviceTailModelExecutor::Prediction>)>;
// `model_provider` is the Optimization Guide provider that supplies model
// files for this service. It is held as a non-owning pointer.
// NOTE(review): presumably the service registers itself as an observer with
// `model_provider` here -- confirm against the .cc file.
explicit OnDeviceTailModelService(
optimization_guide::OptimizationGuideModelProvider* model_provider);
~OnDeviceTailModelService() override;
// Disallow copy/assign.
OnDeviceTailModelService(const OnDeviceTailModelService&) = delete;
OnDeviceTailModelService& operator=(const OnDeviceTailModelService&) = delete;
// optimization_guide::OptimizationTargetModelObserver implementation:
// `model_info` is an optional reference and may therefore carry no value.
void OnModelUpdated(
optimization_guide::proto::OptimizationTarget optimization_target,
base::optional_ref<const optimization_guide::ModelInfo> model_info)
override;
// Calls the model executor to generate predictions for `input`. The
// predictions are handed to `result_callback`.
// NOTE(review): what is passed to the callback when no model is available
// is decided in the .cc file -- confirm there.
void GetPredictionsForInput(
const OnDeviceTailModelExecutor::ModelInput& input,
ResultCallback result_callback);
private:
// Test helpers that need access to the private constructor and members.
friend class OnDeviceTailModelServiceTest;
friend class FakeOnDeviceTailModelService;
// Default constructor, used by tests only. It leaves all private members as
// nullptr so that tests can initialize them later on demand.
OnDeviceTailModelService();
// Checks whether the model executor is idle and, if so, may unload it from
// memory.
void CheckIfModelExecutorIdle();
// The task runner on which the tail model executor runs.
scoped_refptr<base::SequencedTaskRunner> model_executor_task_runner_ =
nullptr;
// Owning pointer to the executor whose deleter is base::OnTaskRunnerDeleter.
// NOTE(review): presumably the deleter is bound to
// `model_executor_task_runner_` -- confirm in the constructor.
using ExecutorUniquePtr =
std::unique_ptr<OnDeviceTailModelExecutor, base::OnTaskRunnerDeleter>;
// The executor to run the tail suggest model.
ExecutorUniquePtr tail_model_executor_;
// Optimization Guide Service that provides model files for this service.
// Not owned.
raw_ptr<optimization_guide::OptimizationGuideModelProvider> model_provider_ =
nullptr;
// A periodic timer which checks the model executor.
// NOTE(review): presumably this drives CheckIfModelExecutorIdle() -- confirm
// in the .cc file.
base::RepeatingTimer timer_;
// Keep this as the last data member: members are destroyed in reverse
// declaration order, so outstanding weak pointers are invalidated before any
// other member is torn down.
base::WeakPtrFactory<OnDeviceTailModelService> weak_ptr_factory_{this};
};
#endif // COMPONENTS_OMNIBOX_BROWSER_ON_DEVICE_TAIL_MODEL_SERVICE_H_