blob: eb8511ee1f6a708d28d8dfe83f3a19e2b18a9253 [file] [log] [blame]
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_OPTIMIZATION_GUIDE_CONTENT_BROWSER_MODEL_EXECUTOR_H_
#define COMPONENTS_OPTIMIZATION_GUIDE_CONTENT_BROWSER_MODEL_EXECUTOR_H_
#include "base/bind.h"
#include "base/callback_forward.h"
#include "base/files/memory_mapped_file.h"
#include "base/logging.h"
#include "base/memory/weak_ptr.h"
#include "base/metrics/histogram_functions.h"
#include "base/sequence_checker.h"
#include "base/threading/sequenced_task_runner_handle.h"
#include "base/time/time.h"
#include "components/optimization_guide/content/browser/optimization_guide_decider.h"
#include "components/optimization_guide/core/optimization_guide_enums.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/core/optimization_target_model_observer.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
#include "third_party/tflite-support/src/tensorflow_lite_support/cc/task/core/base_task_api.h"
#include "third_party/tflite/src/tensorflow/lite/c/common.h"
namespace optimization_guide {
namespace {
// Util class for recording the result of loading the detection model. The
// result is recorded when it goes out of scope and its destructor is called.
class ScopedModelExecutorLoadingResultRecorder {
public:
ScopedModelExecutorLoadingResultRecorder(
proto::OptimizationTarget optimization_target,
ModelExecutorLoadingState model_loading_state)
: optimization_target_(optimization_target),
model_loading_state_(model_loading_state),
start_time_(base::TimeTicks::Now()) {}
~ScopedModelExecutorLoadingResultRecorder() {
base::UmaHistogramEnumeration(
"OptimizationGuide.ModelExecutor.ModelLoadingResult." +
optimization_guide::GetStringNameForOptimizationTarget(
optimization_target_),
model_loading_state_);
base::UmaHistogramTimes(
"OptimizationGuide.ModelExecutor.ModelLoadingDuration." +
optimization_guide::GetStringNameForOptimizationTarget(
optimization_target_),
base::TimeTicks::Now() - start_time_);
}
void set_model_loading_state(ModelExecutorLoadingState model_executor_state) {
model_loading_state_ = model_executor_state;
}
private:
proto::OptimizationTarget optimization_target_;
ModelExecutorLoadingState model_loading_state_;
// The time at which this instance was constructed.
const base::TimeTicks start_time_;
};
} // namespace
// This class handles the execution, loading, unloading, and associated metrics
// of machine learning models in Optimization Guide on a background thread. This
// class is meant to be used and owned by an instance of |ModelHandler|. A
// ModelExecutor must be passed to a ModelHandler's constructor, this design
// allows the implementer of a ModelExecutor to define how the model is built
// and executed.. See also base_model_executor.h and
// base_model_executor_helpers.h in this directory for helpful derived classes.
//
// Lifetime: This class can be constructed on any thread but cannot do anything
// useful until |InitializeAndMoveToBackgroundThread| is called. After that
// method is called, all subsequent calls to this class must be made through the
// |background_task_runner| that was passed to initialize. Furthermore, all
// WeakPointers of this class must only be dereferenced on the background thread
// as well. This in turn means that this class must be destroyed on the
// background thread as well.
template <class OutputType, class... InputTypes>
class ModelExecutor {
public:
ModelExecutor() = default;
virtual ~ModelExecutor() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}
// Should be called on the same sequence as the ctor, but once called |this|
// must only be used from a background thread/sequence.
void InitializeAndMoveToBackgroundThread(
proto::OptimizationTarget optimization_target,
scoped_refptr<base::SequencedTaskRunner> background_task_runner,
scoped_refptr<base::SequencedTaskRunner> reply_task_runner) {
DCHECK(!background_task_runner_);
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK_NE(optimization_target,
proto::OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN);
DETACH_FROM_SEQUENCE(sequence_checker_);
optimization_target_ = optimization_target;
background_task_runner_ = background_task_runner;
reply_task_runner_ = reply_task_runner;
}
// Called when a model file is available to load. Depending on feature flags,
// the model may or may not be immediately loaded.
void UpdateModelFile(const base::FilePath& file_path) {
DCHECK(background_task_runner_->RunsTasksInCurrentSequence());
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
ResetLoadedModel();
model_file_path_ = file_path;
if (!features::LoadModelFileForEachExecution()) {
LoadModelFile();
}
}
// Starts the execution of the model. When complete, |ui_callback_on_complete|
// will be run on the UI thread with the output of the model.
using ExecutionCallback =
base::OnceCallback<void(const absl::optional<OutputType>&)>;
void SendForExecution(ExecutionCallback ui_callback_on_complete,
base::TimeTicks start_time,
InputTypes... args) {
DCHECK(background_task_runner_->RunsTasksInCurrentSequence());
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
DCHECK(reply_task_runner_);
base::TimeDelta task_scheduling_latency =
base::TimeTicks::Now() - start_time;
base::UmaHistogramMediumTimes(
"OptimizationGuide.ModelExecutor.TaskSchedulingLatency." +
optimization_guide::GetStringNameForOptimizationTarget(
optimization_target_),
task_scheduling_latency);
// Attempt to load the model file if it isn't loaded yet, fail if loading is
// unsuccessful or no model is available to load.
if (!loaded_model_ && !LoadModelFile()) {
reply_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(std::move(ui_callback_on_complete), absl::nullopt));
return;
}
if (last_execution_time_) {
// The max of this histogram is 3m since only the distribution and count
// of smaller values is important.
base::UmaHistogramMediumTimes(
"OptimizationGuide.ModelExecutor.TimeSincePreviousRun." +
GetStringNameForOptimizationTarget(optimization_target_),
base::TimeTicks::Now() - *last_execution_time_);
}
last_execution_time_ = base::TimeTicks::Now();
DCHECK(loaded_model_);
absl::optional<OutputType> output = Execute(loaded_model_.get(), args...);
DCHECK(ui_callback_on_complete);
reply_task_runner_->PostTask(
FROM_HERE, base::BindOnce(std::move(ui_callback_on_complete), output));
// If the model file should only be loaded for execution, then unload it
// from memory. This can be done in a PostTask since it may take a while
// for big models and isn't very important.
if (features::LoadModelFileForEachExecution()) {
background_task_runner_->PostTask(
FROM_HERE, base::BindOnce(&ModelExecutor::ResetLoadedModel,
GetBackgroundWeakPtr()));
}
}
// IMPORTANT: These WeakPointers must only be dereferenced on the background
// thread.
base::WeakPtr<ModelExecutor> GetBackgroundWeakPtr() {
return background_weak_ptr_factory_.GetWeakPtr();
}
ModelExecutor(const ModelExecutor&) = delete;
ModelExecutor& operator=(const ModelExecutor&) = delete;
protected:
using ModelExecutionTask =
tflite::task::core::BaseTaskApi<OutputType, InputTypes...>;
// Executes the model using |execution_task| on |args|.
virtual absl::optional<OutputType> Execute(ModelExecutionTask* execution_task,
InputTypes... args) = 0;
// Builds a model execution task using |model_file|.
virtual std::unique_ptr<ModelExecutionTask> BuildModelExecutionTask(
base::MemoryMappedFile* model_file) = 0;
private:
void ResetLoadedModel() {
DCHECK(background_task_runner_->RunsTasksInCurrentSequence());
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
loaded_model_.reset();
model_fb_.reset();
}
// A true return value indicates the model was loaded successfully, false
// otherwise.
bool LoadModelFile() {
DCHECK(background_task_runner_->RunsTasksInCurrentSequence());
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
ScopedModelExecutorLoadingResultRecorder scoped_model_loading_recorder(
optimization_target_, ModelExecutorLoadingState::kModelFileInvalid);
ResetLoadedModel();
base::UmaHistogramBoolean(
"OptimizationGuide.ModelExecutor.ModelAvailableToLoad." +
GetStringNameForOptimizationTarget(optimization_target_),
!!model_file_path_);
if (!model_file_path_)
return false;
std::unique_ptr<base::MemoryMappedFile> model_fb =
std::make_unique<base::MemoryMappedFile>();
if (!model_fb->Initialize(*model_file_path_))
return false;
model_fb_ = std::move(model_fb);
loaded_model_ = BuildModelExecutionTask(model_fb_.get());
if (loaded_model_) {
scoped_model_loading_recorder.set_model_loading_state(
ModelExecutorLoadingState::kModelFileValidAndMemoryMapped);
}
return !!loaded_model_;
}
proto::OptimizationTarget optimization_target_ =
proto::OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN;
scoped_refptr<base::SequencedTaskRunner> background_task_runner_;
scoped_refptr<base::SequencedTaskRunner> reply_task_runner_;
// The time that the model was last executed. Logged in metrics for the second
// and following runs.
absl::optional<base::TimeTicks> last_execution_time_
GUARDED_BY_CONTEXT(sequence_checker_);
// The model file path to be loaded. May be nullopt if no model has been
// downloaded yet.
absl::optional<base::FilePath> model_file_path_
GUARDED_BY_CONTEXT(sequence_checker_);
// Note on lifetimes: |loaded_model_| and |model_fb_| both share the same
// lifetime, being set in |LoadModelFile()| and being destroyed in
// |ResetModelFile()|.
std::unique_ptr<ModelExecutionTask> loaded_model_
GUARDED_BY_CONTEXT(sequence_checker_);
// This will only be non-null when |model_file_path_| is set, and while the
// model is loaded which is managed by a feature flag.
std::unique_ptr<base::MemoryMappedFile> model_fb_
GUARDED_BY_CONTEXT(sequence_checker_);
SEQUENCE_CHECKER(sequence_checker_);
base::WeakPtrFactory<ModelExecutor> background_weak_ptr_factory_{this};
};
// This class owns and handles the execution of models on the UI thread. Derived
// classes must provide an implementation of |ModelExecutor|
// (see above) which is then owned by |this|. The passed executor will be called
// and destroyed on a background thread, which is all handled by this class.
template <class OutputType, class... InputTypes>
class ModelHandler : public OptimizationTargetModelObserver {
public:
ModelHandler(OptimizationGuideDecider* decider,
scoped_refptr<base::SequencedTaskRunner> background_task_runner,
std::unique_ptr<ModelExecutor<OutputType, InputTypes...>>
background_executor,
proto::OptimizationTarget optimization_target,
const absl::optional<proto::Any>& model_metadata)
: decider_(decider),
optimization_target_(optimization_target),
background_executor_(std::move(background_executor)),
background_task_runner_(background_task_runner) {
DCHECK(decider_);
DCHECK(background_executor_);
DCHECK_NE(optimization_target_,
proto::OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN);
decider_->AddObserverForOptimizationTargetModel(optimization_target_,
model_metadata, this);
background_executor_->InitializeAndMoveToBackgroundThread(
optimization_target_, background_task_runner_,
base::SequencedTaskRunnerHandle::Get());
}
~ModelHandler() override {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
decider_->RemoveObserverForOptimizationTargetModel(optimization_target_,
this);
// |background_executor_|'s WeakPtrs are used on the background thread, so
// that is also where the class must be destroyed.
background_task_runner_->DeleteSoon(FROM_HERE,
std::move(background_executor_));
}
ModelHandler(const ModelHandler&) = delete;
ModelHandler& operator=(const ModelHandler&) = delete;
// Executes the model using |input| and invokes |callback| on the UI thread
// when completed.
// TODO(crbug/1173328): Add a way to surface errors.
using ExecutionCallback =
base::OnceCallback<void(const absl::optional<OutputType>&)>;
void ExecuteModelWithInput(ExecutionCallback callback, InputTypes... input) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
base::TimeTicks now = base::TimeTicks::Now();
ExecutionCallback on_complete_callback =
base::BindOnce(&ModelHandler::OnExecutionCompleted, std::move(callback),
optimization_target_, now);
background_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(
&ModelExecutor<OutputType, InputTypes...>::SendForExecution,
background_executor_->GetBackgroundWeakPtr(),
std::move(on_complete_callback), now, input...));
}
// OptimizationTargetModelObserver:
void OnModelFileUpdated(proto::OptimizationTarget optimization_target,
const absl::optional<proto::Any>& model_metadata,
const base::FilePath& file_path) override {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (optimization_target_ != optimization_target)
return;
supported_features_for_loaded_model_ = model_metadata;
model_available_ = true;
background_task_runner_->PostTask(
FROM_HERE,
base::BindOnce(
&ModelExecutor<OutputType, InputTypes...>::UpdateModelFile,
background_executor_->GetBackgroundWeakPtr(), file_path));
}
// Returns whether a model is available to be executed.
bool ModelAvailable() const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return model_available_;
}
// Returns the supported features for the loaded model, if the server provided
// any.
absl::optional<proto::Any> supported_features_for_loaded_model() const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return supported_features_for_loaded_model_;
}
// Validates that |supported_features_for_loaded_model_| is of the same type
// and is parseable as |T|. Will return metadata if all checks pass.
template <
class T,
class = typename std::enable_if<
std::is_convertible<T*, google::protobuf::MessageLite*>{}>::type>
absl::optional<T> ParsedSupportedFeaturesForLoadedModel() const {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!supported_features_for_loaded_model_)
return absl::nullopt;
return ParsedAnyMetadata<T>(*supported_features_for_loaded_model_);
}
private:
// This is called by |background_executor_|. This method does not have to be
// static, but because it is stateless we've made it static so that we don't
// have to have this class support WeakPointers.
static void OnExecutionCompleted(
ExecutionCallback callback,
proto::OptimizationTarget optimization_target,
base::TimeTicks model_execute_start_time,
const absl::optional<OutputType>& output) {
if (!output) {
std::move(callback).Run(output);
return;
}
base::TimeDelta execution_time =
base::TimeTicks::Now() - model_execute_start_time;
base::UmaHistogramMediumTimes(
"OptimizationGuide.ModelExecutor.TaskExecutionLatency." +
optimization_guide::GetStringNameForOptimizationTarget(
optimization_target),
execution_time);
std::move(callback).Run(output);
}
// Not owned. Guaranteed to outlive |this|.
OptimizationGuideDecider* decider_ GUARDED_BY_CONTEXT(sequence_checker_);
const proto::OptimizationTarget optimization_target_;
// The owned background executor.
std::unique_ptr<ModelExecutor<OutputType, InputTypes...>>
background_executor_;
// The background task runner. Note that whenever a task is posted here, the
// task takes a reference to the TaskRunner (in a cyclic dependency) so
// |base::Unretained| is not safe anywhere in this class or the
// |background_executor_|.
scoped_refptr<base::SequencedTaskRunner> background_task_runner_;
// Set in |OnModelFileUpdated|.
absl::optional<proto::Any> supported_features_for_loaded_model_
GUARDED_BY_CONTEXT(sequence_checker_);
// Set in |OnModelFileUpdated|.
bool model_available_ GUARDED_BY_CONTEXT(sequence_checker_) = false;
SEQUENCE_CHECKER(sequence_checker_);
};
} // namespace optimization_guide
#endif // COMPONENTS_OPTIMIZATION_GUIDE_CONTENT_BROWSER_MODEL_EXECUTOR_H_