| // Copyright 2018 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "chrome/browser/chromeos/power/ml/smart_dim/ml_service_client.h" |
| |
| #include "base/bind.h" |
| #include "base/memory/weak_ptr.h" |
| #include "base/metrics/histogram_macros.h" |
| #include "base/threading/thread.h" |
| #include "base/time/time.h" |
| #include "chrome/browser/chromeos/power/ml/smart_dim/model_impl.h" |
| #include "chromeos/services/machine_learning/public/cpp/service_connection.h" |
| #include "chromeos/services/machine_learning/public/mojom/graph_executor.mojom.h" |
| #include "chromeos/services/machine_learning/public/mojom/model.mojom.h" |
| #include "chromeos/services/machine_learning/public/mojom/tensor.mojom.h" |
| #include "mojo/public/cpp/bindings/map.h" |
| |
| using ::chromeos::machine_learning::mojom::CreateGraphExecutorResult; |
| using ::chromeos::machine_learning::mojom::ExecuteResult; |
| using ::chromeos::machine_learning::mojom::FloatList; |
| using ::chromeos::machine_learning::mojom::Int64List; |
| using ::chromeos::machine_learning::mojom::LoadModelResult; |
| using ::chromeos::machine_learning::mojom::ModelId; |
| using ::chromeos::machine_learning::mojom::ModelSpec; |
| using ::chromeos::machine_learning::mojom::ModelSpecPtr; |
| using ::chromeos::machine_learning::mojom::Tensor; |
| using ::chromeos::machine_learning::mojom::TensorPtr; |
| using ::chromeos::machine_learning::mojom::ValueList; |
| |
| namespace chromeos { |
| namespace power { |
| namespace ml { |
| |
| namespace { |
| |
| // TODO(crbug.com/893425): This should exist in only one location, so it should |
| // be merged with its duplicate in model_impl.cc to a common location. |
| void LogPowerMLSmartDimModelResult(SmartDimModelResult result) { |
| UMA_HISTOGRAM_ENUMERATION("PowerML.SmartDimModel.Result", result); |
| } |
| |
| // Real impl of MlServiceClient. |
| class MlServiceClientImpl : public MlServiceClient { |
| public: |
| MlServiceClientImpl(); |
| ~MlServiceClientImpl() override {} |
| |
| // MlServiceClient: |
| void DoInference( |
| const std::vector<float>& features, |
| base::RepeatingCallback<UserActivityEvent::ModelPrediction(float)> |
| get_prediction_callback, |
| SmartDimModel::DimDecisionCallback decision_callback) override; |
| |
| private: |
| // Various callbacks that get invoked by the Mojo framework. |
| void LoadModelCallback( |
| ::chromeos::machine_learning::mojom::LoadModelResult result); |
| void CreateGraphExecutorCallback( |
| ::chromeos::machine_learning::mojom::CreateGraphExecutorResult result); |
| |
| // Callback executed by ML Service when an Execute call is complete. |
| // |
| // The |get_prediction_callback| and the |decision_callback| are bound |
| // to the ExecuteCallback during while calling the Execute() function |
| // on the Mojo API. |
| void ExecuteCallback( |
| base::RepeatingCallback<UserActivityEvent::ModelPrediction(float)> |
| get_prediction_callback, |
| SmartDimModel::DimDecisionCallback decision_callback, |
| ::chromeos::machine_learning::mojom::ExecuteResult result, |
| base::Optional< |
| std::vector<::chromeos::machine_learning::mojom::TensorPtr>> outputs); |
| // Initializes the various handles to the ML service if they're not already |
| // available. |
| void InitMlServiceHandlesIfNeeded(); |
| |
| void OnConnectionError(); |
| |
| // Pointers used to execute functions in the ML service server end. |
| ::chromeos::machine_learning::mojom::ModelPtr model_; |
| ::chromeos::machine_learning::mojom::GraphExecutorPtr executor_; |
| |
| base::WeakPtrFactory<MlServiceClientImpl> weak_factory_; |
| |
| DISALLOW_COPY_AND_ASSIGN(MlServiceClientImpl); |
| }; |
| |
| MlServiceClientImpl::MlServiceClientImpl() |
| : MlServiceClient(), weak_factory_(this) {} |
| |
| void MlServiceClientImpl::LoadModelCallback(LoadModelResult result) { |
| if (result != LoadModelResult::OK) { |
| // TODO(crbug.com/893425): Log to UMA. |
| LOG(ERROR) << "Failed to load Smart Dim model."; |
| } |
| } |
| |
| void MlServiceClientImpl::CreateGraphExecutorCallback( |
| CreateGraphExecutorResult result) { |
| if (result != CreateGraphExecutorResult::OK) { |
| // TODO(crbug.com/893425): Log to UMA. |
| LOG(ERROR) << "Failed to create Smart Dim Graph Executor."; |
| } |
| } |
| |
| void MlServiceClientImpl::ExecuteCallback( |
| base::Callback<UserActivityEvent::ModelPrediction(float)> |
| get_prediction_callback, |
| SmartDimModel::DimDecisionCallback decision_callback, |
| const ExecuteResult result, |
| const base::Optional<std::vector<TensorPtr>> outputs) { |
| UserActivityEvent::ModelPrediction prediction; |
| |
| if (result != ExecuteResult::OK) { |
| LOG(ERROR) << "Smart Dim inference execution failed."; |
| prediction.set_response(UserActivityEvent::ModelPrediction::MODEL_ERROR); |
| LogPowerMLSmartDimModelResult(SmartDimModelResult::kOtherError); |
| } else { |
| float inactivity_score = |
| (outputs.value())[0]->data->get_float_list()->value[0]; |
| prediction = get_prediction_callback.Run(inactivity_score); |
| LogPowerMLSmartDimModelResult(SmartDimModelResult::kSuccess); |
| } |
| |
| std::move(decision_callback).Run(prediction); |
| } |
| |
| void MlServiceClientImpl::InitMlServiceHandlesIfNeeded() { |
| if (!model_) { |
| // Load the model. |
| ModelSpecPtr spec = ModelSpec::New(ModelId::SMART_DIM); |
| chromeos::machine_learning::ServiceConnection::GetInstance()->LoadModel( |
| std::move(spec), mojo::MakeRequest(&model_), |
| base::BindOnce(&MlServiceClientImpl::LoadModelCallback, |
| weak_factory_.GetWeakPtr())); |
| } |
| |
| if (!executor_) { |
| // Get the graph executor. |
| model_->CreateGraphExecutor( |
| mojo::MakeRequest(&executor_), |
| base::BindOnce(&MlServiceClientImpl::CreateGraphExecutorCallback, |
| weak_factory_.GetWeakPtr())); |
| executor_.set_connection_error_handler(base::BindOnce( |
| &MlServiceClientImpl::OnConnectionError, weak_factory_.GetWeakPtr())); |
| } |
| } |
| |
| void MlServiceClientImpl::OnConnectionError() { |
| // TODO(crbug.com/893425): Log to UMA. |
| LOG(WARNING) << "Mojo connection for ML service closed."; |
| executor_.reset(); |
| model_.reset(); |
| } |
| |
| void MlServiceClientImpl::DoInference( |
| const std::vector<float>& features, |
| base::Callback<UserActivityEvent::ModelPrediction(float)> |
| get_prediction_callback, |
| SmartDimModel::DimDecisionCallback decision_callback) { |
| InitMlServiceHandlesIfNeeded(); |
| |
| // Prepare the input tensor. |
| std::map<std::string, TensorPtr> inputs; |
| auto tensor = Tensor::New(); |
| tensor->shape = Int64List::New(); |
| tensor->shape->value = std::vector<int64_t>({1, features.size()}); |
| tensor->data = ValueList::New(); |
| tensor->data->set_float_list(FloatList::New()); |
| tensor->data->get_float_list()->value = |
| std::vector<double>(std::begin(features), std::end(features)); |
| inputs.emplace(std::string("input"), std::move(tensor)); |
| |
| std::vector<std::string> outputs({std::string("output")}); |
| |
| executor_->Execute( |
| mojo::MapToFlatMap(std::move(inputs)), std::move(outputs), |
| base::BindOnce(&MlServiceClientImpl::ExecuteCallback, |
| weak_factory_.GetWeakPtr(), get_prediction_callback, |
| std::move(decision_callback))); |
| } |
| |
| } // namespace |
| |
| std::unique_ptr<MlServiceClient> CreateMlServiceClient() { |
| return std::make_unique<MlServiceClientImpl>(); |
| } |
| |
| } // namespace ml |
| } // namespace power |
| } // namespace chromeos |