blob: b936f7470f5b0991c42eb65c82793fd0fcbce2a9 [file] [log] [blame]
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/optimization_guide/core/model_validator.h"
#include <memory>
#include <vector>
#include "third_party/tflite/src/tensorflow/lite/c/common.h"
#include "third_party/tflite_support/src/tensorflow_lite_support/cc/task/core/task_utils.h"
namespace optimization_guide {
ModelValidatorHandler::ModelValidatorHandler(
OptimizationGuideModelProvider* model_provider,
scoped_refptr<base::SequencedTaskRunner> background_task_runner)
: ModelHandler<float, const std::vector<float>&>(
model_provider,
background_task_runner,
std::make_unique<ModelValidatorExecutor>(),
proto::OPTIMIZATION_TARGET_MODEL_VALIDATION,
/*model_metadata=*/absl::nullopt) {}
ModelValidatorHandler::~ModelValidatorHandler() = default;
void ModelValidatorHandler::OnModelExecutionComplete(
const absl::optional<float>& output) {
// Delete |this| since the model load completed successfully or failed.
delete this;
}
void ModelValidatorHandler::OnModelUpdated(
optimization_guide::proto::OptimizationTarget optimization_target,
const optimization_guide::ModelInfo& model_info) {
// First invoke parent to update internal status.
optimization_guide::ModelHandler<
float, const std::vector<float>&>::OnModelUpdated(optimization_target,
model_info);
// The parent class should always set the model availability to true after
// having received an updated model.
DCHECK(ModelAvailable());
// Try executing the model which will wait until the model is loaded, execute
// it, and callback when execution is finished.
ExecuteModelWithInput(
base::BindOnce(&ModelValidatorHandler::OnModelExecutionComplete,
weak_ptr_factory_.GetWeakPtr()),
std::vector<float>());
}
ModelValidatorExecutor::ModelValidatorExecutor() = default;
ModelValidatorExecutor::~ModelValidatorExecutor() = default;
absl::Status ModelValidatorExecutor::Preprocess(
const std::vector<TfLiteTensor*>& input_tensors,
const std::vector<float>& input) {
// Return error so that actual model execution does not happen.
return absl::Status(absl::StatusCode::kUnimplemented,
"Model execution not supported");
}
float ModelValidatorExecutor::Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors) {
std::vector<float> data;
tflite::task::core::PopulateVector<float>(output_tensors[0], &data);
return data[0];
}
} // namespace optimization_guide