blob: d3273d4c62303ec2a4c4f2784b7517723b0aca54 [file] [log] [blame]
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/optimization_guide/model_validator_keyed_service.h"
#include "base/command_line.h"
#include "base/files/file_util.h"
#include "base/logging.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
#include "build/build_config.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/browser/signin/identity_manager_factory.h"
#include "components/optimization_guide/core/model_execution/on_device_model_component.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/optimization_guide_switches.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/machine_learning_tflite_buildflags.h"
#include "components/optimization_guide/proto/string_value.pb.h"
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
#include "components/optimization_guide/core/model_validator.h"
#endif // BUILD_WITH_TFLITE_LIB
namespace {

// Reads the file at `path` and parses its contents as a serialized
// ComposeRequest. Performs blocking file I/O, so it must run on a sequence
// that allows blocking. Returns nullptr if the file cannot be read or its
// contents do not parse as a ComposeRequest.
std::unique_ptr<optimization_guide::proto::ComposeRequest>
ParseComposeRequestFromFile(base::FilePath path) {
  std::string serialized_request;
  if (!base::ReadFileToString(path, &serialized_request)) {
    return nullptr;
  }
  auto request = std::make_unique<optimization_guide::proto::ComposeRequest>();
  if (!request->ParseFromString(serialized_request)) {
    return nullptr;
  }
  return request;
}

// Writes the `output` field of `response` to `path`, replacing any existing
// contents. Performs blocking file I/O.
//
// A failed write (bad path, missing permissions, full disk, ...) is a runtime
// condition rather than a programming error. It is therefore logged instead of
// DCHECK'd: a DCHECK would crash debug builds and silently drop the failure in
// release builds.
void WriteResponseToFile(base::FilePath path,
                         optimization_guide::proto::ComposeResponse response) {
  if (!base::WriteFile(path, response.output())) {
    LOG(ERROR) << "Failed to write on-device validation response to " << path;
  }
}

}  // namespace
namespace optimization_guide {
// Starts each validation flow requested through optimization-guide switches:
//  - TFLite model validation (only in builds with BUILD_WITH_TFLITE_LIB),
//  - server-side model execution validation, which needs a signed-in primary
//    account; if there is none yet, the service observes the identity manager
//    and starts validation from OnPrimaryAccountChanged(),
//  - on-device model execution validation, driven by a request read from the
//    file named by GetOnDeviceValidationRequestOverride().
// The flows are independent. A missing identity manager or primary account
// only affects the server-side flow and must not prevent the on-device flow
// from being scheduled, so that branch does not return early.
ModelValidatorKeyedService::ModelValidatorKeyedService(Profile* profile)
    : profile_(profile) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(ShouldStartModelValidator());
  auto* opt_guide_service =
      OptimizationGuideKeyedServiceFactory::GetForProfile(profile);
  if (!opt_guide_service) {
    return;
  }
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
  if (switches::ShouldValidateModel()) {
    // Create the validator object which will get destroyed when the model
    // load is complete.
    new ModelValidatorHandler(
        opt_guide_service,
        base::ThreadPool::CreateSequencedTaskRunner(
            {base::MayBlock(), base::TaskPriority::BEST_EFFORT}));
  }
#endif  // BUILD_WITH_TFLITE_LIB
  if (switches::ShouldValidateModelExecution()) {
    auto* identity_manager = IdentityManagerFactory::GetForProfile(profile_);
    if (identity_manager) {
      if (identity_manager->HasPrimaryAccount(signin::ConsentLevel::kSignin)) {
        // Already signed in: start validation asynchronously on this sequence
        // rather than from inside the constructor.
        base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
            FROM_HERE,
            base::BindOnce(
                &ModelValidatorKeyedService::StartModelExecutionValidation,
                weak_ptr_factory_.GetWeakPtr()));
      } else {
        // Not signed in yet: OnPrimaryAccountChanged() starts validation once
        // a primary account is available.
        identity_manager_observation_.Observe(identity_manager);
      }
    }
  }
  // Reading and parsing the request file blocks, so it runs on the thread
  // pool. The parsed request is replied back to this sequence.
  if (const auto ondevice_override_file =
          switches::GetOnDeviceValidationRequestOverride()) {
    base::ThreadPool::PostTaskAndReplyWithResult(
        FROM_HERE, {base::MayBlock()},
        base::BindOnce(&ParseComposeRequestFromFile, *ondevice_override_file),
        base::BindOnce(
            &ModelValidatorKeyedService::StartOnDeviceModelExecutionValidation,
            weak_ptr_factory_.GetWeakPtr()));
  }
}

ModelValidatorKeyedService::~ModelValidatorKeyedService() = default;
// IdentityManager observer hook. Used only by the server-side model execution
// validation flow: once the profile has a signed-in primary account, stop
// observing and post StartModelExecutionValidation() to the current sequence.
// `event_details` is not inspected; the identity manager is queried directly.
void ModelValidatorKeyedService::OnPrimaryAccountChanged(
    const signin::PrimaryAccountChangeEvent& event_details) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!switches::ShouldValidateModelExecution()) {
    return;
  }
  auto* const manager = IdentityManagerFactory::GetForProfile(profile_);
  const bool has_signed_in_account =
      manager && manager->HasPrimaryAccount(signin::ConsentLevel::kSignin);
  if (!has_signed_in_account) {
    // Keep observing until a primary account shows up.
    return;
  }
  identity_manager_observation_.Reset();
  auto start_validation = base::BindOnce(
      &ModelValidatorKeyedService::StartModelExecutionValidation,
      weak_ptr_factory_.GetWeakPtr());
  base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, std::move(start_validation));
}
void ModelValidatorKeyedService::StartModelExecutionValidation() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
auto* opt_guide_service =
OptimizationGuideKeyedServiceFactory::GetForProfile(profile_);
if (!opt_guide_service) {
return;
}
base::CommandLine* command_line = base::CommandLine::ForCurrentProcess();
std::string model_execution_input =
command_line->GetSwitchValueASCII(switches::kModelExecutionValidate);
if (model_execution_input.empty()) {
return;
}
proto::StringValue request;
request.set_value(model_execution_input);
opt_guide_service->ExecuteModel(
ModelBasedCapabilityKey::kTest, request,
base::BindOnce(&ModelValidatorKeyedService::OnModelExecuteResponse,
weak_ptr_factory_.GetWeakPtr()));
}
// Receives the request parsed from the on-device override file. It is the
// reply of a PostTaskAndReplyWithResult() call, so it runs back on this
// service's sequence. Schedules PerformOnDeviceModelExecutionValidation()
// after the configured startup delay. Does nothing when `request` is null,
// i.e. when the file could not be read or parsed.
void ModelValidatorKeyedService::StartOnDeviceModelExecutionValidation(
    std::unique_ptr<optimization_guide::proto::ComposeRequest> request) {
  // Every other sequence-bound method asserts this; keep it consistent.
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!request) {
    return;
  }
  base::SequencedTaskRunner::GetCurrentDefault()->PostDelayedTask(
      FROM_HERE,
      base::BindOnce(
          &ModelValidatorKeyedService::PerformOnDeviceModelExecutionValidation,
          weak_ptr_factory_.GetWeakPtr(), std::move(request)),
      features::GetOnDeviceModelExecutionValidationStartupDelay());
}
// Starts a kCompose session and executes `request` on it. Streaming results
// are delivered to OnDeviceModelExecuteResponse(). The session is kept in
// `on_device_validation_session_`. `request` is expected to be non-null; its
// only caller checks this before posting.
void ModelValidatorKeyedService::PerformOnDeviceModelExecutionValidation(
    std::unique_ptr<optimization_guide::proto::ComposeRequest> request) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  auto* opt_guide_service =
      OptimizationGuideKeyedServiceFactory::GetForProfile(profile_);
  if (!opt_guide_service) {
    return;
  }
  on_device_validation_session_ =
      opt_guide_service->StartSession(ModelBasedCapabilityKey::kCompose,
                                      /*config_params=*/std::nullopt);
  // StartSession() hands back a pointer-like session that may be null (e.g.
  // when model execution is unavailable; NOTE(review): confirm the exact
  // conditions). Never dereference a null session.
  if (!on_device_validation_session_) {
    return;
  }
  on_device_validation_session_->ExecuteModel(
      *request,
      base::BindRepeating(
          &ModelValidatorKeyedService::OnDeviceModelExecuteResponse,
          weak_ptr_factory_.GetWeakPtr()));
}
// Streaming callback for the on-device validation session. Error results and
// partial (not yet complete) responses are ignored. For the complete response,
// if an output file is configured (switches::GetOnDeviceValidationWriteToFile()),
// the payload is parsed as a ComposeResponse. Its output text is then written
// to that file from a thread-pool task that may block.
void ModelValidatorKeyedService::OnDeviceModelExecuteResponse(
    OptimizationGuideModelStreamingExecutionResult result) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!result.response.has_value() || !result.response->is_complete) {
    return;
  }
  // Without an output file the parsed response would be discarded, so skip
  // the parse entirely.
  auto out_file = switches::GetOnDeviceValidationWriteToFile();
  if (!out_file) {
    return;
  }
  optimization_guide::proto::ComposeResponse compose_response;
  if (!compose_response.ParseFromString(result.response->response.value())) {
    return;
  }
  // Both locals are dead after this point; move them into the task instead of
  // copying the proto and path.
  base::ThreadPool::PostTask(
      FROM_HERE, {base::MayBlock()},
      base::BindOnce(&WriteResponseToFile, std::move(*out_file),
                     std::move(compose_response)));
}
// Completion callback for the kTest request issued by
// StartModelExecutionValidation(). Only asserts the calling sequence: `result`
// is not inspected, and `log_entry` is dropped (destroyed) when this returns.
// NOTE(review): presumably issuing the request is the point of the validation
// and the result is intentionally unused; confirm whether it should be logged.
void ModelValidatorKeyedService::OnModelExecuteResponse(
OptimizationGuideModelExecutionResult result,
std::unique_ptr<ModelQualityLogEntry> log_entry) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
}
} // namespace optimization_guide