| // Copyright 2020 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "chrome/browser/ash/power/ml/smart_dim/download_worker.h" |
| |
| #include "base/functional/bind.h" |
| #include "base/functional/callback.h" |
| #include "base/task/task_traits.h" |
| #include "chrome/browser/ash/power/ml/smart_dim/metrics.h" |
| #include "chrome/browser/ash/power/ml/smart_dim/ml_agent_util.h" |
| #include "chromeos/services/machine_learning/public/cpp/service_connection.h" |
| #include "components/assist_ranker/proto/example_preprocessor.pb.h" |
| #include "content/public/browser/browser_task_traits.h" |
| #include "content/public/browser/browser_thread.h" |
| #include "ui/base/resource/resource_bundle.h" |
| |
| namespace ash { |
| namespace power { |
| namespace ml { |
| |
| namespace { |
| using chromeos::machine_learning::mojom::FlatBufferModelSpec; |
| } // namespace |
| |
| DownloadWorker::DownloadWorker() : SmartDimWorker(), metrics_model_name_("") {} |
| |
| DownloadWorker::~DownloadWorker() = default; |
| |
| const assist_ranker::ExamplePreprocessorConfig* |
| DownloadWorker::GetPreprocessorConfig() { |
| return preprocessor_config_.get(); |
| } |
| |
| const mojo::Remote<chromeos::machine_learning::mojom::GraphExecutor>& |
| DownloadWorker::GetExecutor() { |
| return executor_; |
| } |
| |
| void DownloadWorker::LoadModelCallback( |
| chromeos::machine_learning::mojom::LoadModelResult result) { |
| if (result != chromeos::machine_learning::mojom::LoadModelResult::OK) { |
| LogLoadComponentEvent(LoadComponentEvent::kLoadModelError); |
| DVLOG(1) << "Failed to load Smart Dim flatbuffer model."; |
| } |
| } |
| |
| void DownloadWorker::CreateGraphExecutorCallback( |
| chromeos::machine_learning::mojom::CreateGraphExecutorResult result) { |
| if (result != |
| chromeos::machine_learning::mojom::CreateGraphExecutorResult::OK) { |
| LogLoadComponentEvent(LoadComponentEvent::kCreateGraphExecutorError); |
| DVLOG(1) << "Failed to create a Smart Dim graph executor."; |
| } else { |
| LogLoadComponentEvent(LoadComponentEvent::kSuccess); |
| } |
| } |
| |
| bool DownloadWorker::IsReady() { |
| return preprocessor_config_ && model_ && executor_ && |
| expected_feature_size_ > 0 && metrics_model_name_ != ""; |
| } |
| |
| void DownloadWorker::InitializeFromComponent( |
| const ComponentFileContents& contents) { |
| DCHECK_CURRENTLY_ON(content::BrowserThread::UI); |
| |
| auto [metadata_json, preprocessor_proto, model_flatbuffer] = contents; |
| |
| preprocessor_config_ = |
| std::make_unique<assist_ranker::ExamplePreprocessorConfig>(); |
| if (!preprocessor_config_->ParseFromString(preprocessor_proto)) { |
| LogLoadComponentEvent(LoadComponentEvent::kLoadPreprocessorError); |
| DVLOG(1) << "Failed to load preprocessor_config."; |
| preprocessor_config_.reset(); |
| return; |
| } |
| |
| // Meta data contains necessary info to construct FlatBufferModelSpec, and |
| // other optional info. |
| data_decoder::DataDecoder::ParseJsonIsolated( |
| std::move(metadata_json), |
| base::BindOnce(&DownloadWorker::OnJsonParsed, base::Unretained(this), |
| std::move(model_flatbuffer))); |
| } |
| |
| void DownloadWorker::SetOnReadyForTest(base::OnceClosure on_ready) { |
| on_ready_for_test_ = std::move(on_ready); |
| } |
| |
| void DownloadWorker::OnJsonParsed( |
| const std::string& model_flatbuffer, |
| const data_decoder::DataDecoder::ValueOrError result) { |
| DCHECK_CURRENTLY_ON(content::BrowserThread::UI); |
| if (!result.has_value() || !result->is_dict() || |
| !ParseMetaInfoFromJsonObject(*result, &metrics_model_name_, |
| &dim_threshold_, &expected_feature_size_, |
| &inputs_, &outputs_)) { |
| LogLoadComponentEvent(LoadComponentEvent::kLoadMetadataError); |
| DVLOG(1) << "Failed to parse meta info from metadata_json."; |
| return; |
| } |
| content::GetUIThreadTaskRunner({base::TaskPriority::BEST_EFFORT}) |
| ->PostTask( |
| FROM_HERE, |
| base::BindOnce(&DownloadWorker::LoadModelAndCreateGraphExecutor, |
| base::Unretained(this), std::move(model_flatbuffer))); |
| } |
| |
| void DownloadWorker::LoadModelAndCreateGraphExecutor( |
| const std::string& model_flatbuffer) { |
| DCHECK_CURRENTLY_ON(content::BrowserThread::UI); |
| DCHECK(!model_.is_bound() && !executor_.is_bound()); |
| |
| chromeos::machine_learning::ServiceConnection::GetInstance() |
| ->GetMachineLearningService() |
| .LoadFlatBufferModel( |
| FlatBufferModelSpec::New(std::move(model_flatbuffer), inputs_, |
| outputs_, metrics_model_name_), |
| model_.BindNewPipeAndPassReceiver(), |
| base::BindOnce(&DownloadWorker::LoadModelCallback, |
| base::Unretained(this))); |
| model_->CreateGraphExecutor( |
| chromeos::machine_learning::mojom::GraphExecutorOptions::New(), |
| executor_.BindNewPipeAndPassReceiver(), |
| base::BindOnce(&DownloadWorker::CreateGraphExecutorCallback, |
| base::Unretained(this))); |
| executor_.set_disconnect_handler(base::BindOnce( |
| &DownloadWorker::OnConnectionError, base::Unretained(this))); |
| if (on_ready_for_test_) { |
| std::move(on_ready_for_test_).Run(); |
| } |
| } |
| |
| } // namespace ml |
| } // namespace power |
| } // namespace ash |