blob: 06211164fe01a904a8c3f1827982559b62afcf2d [file] [log] [blame]
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/optimization_guide/model_validator_keyed_service.h"
#include "base/command_line.h"
#include "base/files/file_util.h"
#include "base/logging.h"
#include "base/metrics/histogram_macros_local.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
#include "build/build_config.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/browser/signin/identity_manager_factory.h"
#include "components/optimization_guide/core/model_execution/feature_keys.h"
#include "components/optimization_guide/core/model_execution/on_device_model_component.h"
#include "components/optimization_guide/core/model_execution/on_device_model_execution_proto_descriptors.h"
#include "components/optimization_guide/core/optimization_guide_constants.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_switches.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/optimization_guide/machine_learning_tflite_buildflags.h"
#include "components/optimization_guide/proto/features/compose.pb.h"
#include "components/optimization_guide/proto/model_execution.pb.h"
#include "components/optimization_guide/proto/model_validation.pb.h"
#include "components/optimization_guide/proto/string_value.pb.h"
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
#include "components/optimization_guide/core/inference/model_validator.h"
#endif // BUILD_WITH_TFLITE_LIB
namespace {
// Reads the file at |path| synchronously and decodes its bytes as a
// serialized ModelValidationInput proto.
// Returns the parsed proto, or nullptr if the file could not be read or its
// contents are not a valid ModelValidationInput.
std::unique_ptr<optimization_guide::proto::ModelValidationInput>
ParseRequestFromFile(base::FilePath path) {
  std::string file_contents;
  const bool read_ok = base::ReadFileToString(path, &file_contents);
  if (!read_ok) {
    return nullptr;
  }
  auto parsed_input =
      std::make_unique<optimization_guide::proto::ModelValidationInput>();
  // Hand ownership back only when the bytes decode cleanly.
  return parsed_input->ParseFromString(file_contents) ? std::move(parsed_input)
                                                      : nullptr;
}
// Serializes |validation_output| and writes the bytes to |path|.
// Performs blocking file I/O, so it must run on a task runner that allows
// blocking.
// Failures are reported rather than dropped silently:
// - a serialization failure is logged and nothing is written;
// - a write failure is logged, and also trips a DCHECK in debug builds.
void WriteResponseToFile(
    base::FilePath path,
    optimization_guide::proto::ModelValidationOutput validation_output) {
  std::string serialized_output;
  if (!validation_output.SerializeToString(&serialized_output)) {
    LOG(ERROR) << "Failed to serialize model validation output";
    return;
  }
  const bool write_file_success = base::WriteFile(path, serialized_output);
  // Log in all builds so release-mode validation runs surface the failure
  // too; the DCHECK only fires in debug builds.
  LOG_IF(ERROR, !write_file_success)
      << "Failed to write model validation output to " << path;
  DCHECK(write_file_success);
}
} // namespace
namespace optimization_guide {
// Starts whichever validations the optimization_guide switches request:
// - TFLite model validation, in TFLite builds only;
// - kTest model execution validation, posted once a primary account is
//   signed in (otherwise deferred until OnPrimaryAccountChanged);
// - on-device model execution validation, driven by a request override file
//   that is parsed on the thread pool.
// Nothing is started if the profile has no OptimizationGuideKeyedService.
ModelValidatorKeyedService::ModelValidatorKeyedService(Profile* profile)
    : profile_(profile) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(ShouldStartModelValidator());
  auto* opt_guide_service =
      OptimizationGuideKeyedServiceFactory::GetForProfile(profile);
  if (!opt_guide_service) {
    return;
  }
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
  if (switches::ShouldValidateModel()) {
    // Create the validator object which will get destroyed when the model
    // load is complete.
    new ModelValidatorHandler(
        opt_guide_service,
        base::ThreadPool::CreateSequencedTaskRunner(
            {base::MayBlock(), base::TaskPriority::BEST_EFFORT}));
  }
#endif  // BUILD_WITH_TFLITE_LIB
  // The sign-in checks below must not return from the constructor; doing so
  // would silently skip the independent on-device validation setup that
  // follows.
  if (switches::ShouldValidateModelExecution()) {
    auto* identity_manager = IdentityManagerFactory::GetForProfile(profile_);
    if (identity_manager) {
      if (identity_manager->HasPrimaryAccount(signin::ConsentLevel::kSignin)) {
        base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
            FROM_HERE,
            base::BindOnce(
                &ModelValidatorKeyedService::StartModelExecutionValidation,
                weak_ptr_factory_.GetWeakPtr()));
      } else {
        // Wait for sign-in; OnPrimaryAccountChanged() starts validation.
        identity_manager_observation_.Observe(identity_manager);
      }
    }
  }
  // Query the override once instead of twice.
  if (auto ondevice_override_file =
          switches::GetOnDeviceValidationRequestOverride()) {
    // Parse the request file off-sequence (blocking I/O); the reply runs back
    // on this sequence.
    base::ThreadPool::PostTaskAndReplyWithResult(
        FROM_HERE, {base::MayBlock()},
        base::BindOnce(&ParseRequestFromFile, ondevice_override_file.value()),
        base::BindOnce(
            &ModelValidatorKeyedService::StartOnDeviceModelExecutionValidation,
            weak_ptr_factory_.GetWeakPtr()));
  }
}
// Defaulted: owned state (the identity-manager observation, weak pointers and
// any on-device validation session) is released by the members' destructors.
ModelValidatorKeyedService::~ModelValidatorKeyedService() = default;
// IdentityManager::Observer override.
// Once a primary account with sign-in consent exists, stops observing and
// posts StartModelExecutionValidation() to the current sequence.
// |event_details| is not consulted; the identity manager is queried directly.
void ModelValidatorKeyedService::OnPrimaryAccountChanged(
    const signin::PrimaryAccountChangeEvent& event_details) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // Only relevant when model execution validation was requested.
  if (!switches::ShouldValidateModelExecution()) {
    return;
  }
  auto* manager = IdentityManagerFactory::GetForProfile(profile_);
  const bool has_signed_in_account =
      manager && manager->HasPrimaryAccount(signin::ConsentLevel::kSignin);
  if (!has_signed_in_account) {
    return;
  }
  // Sign-in observed: one-shot, so drop the observation before starting.
  identity_manager_observation_.Reset();
  auto start_validation = base::BindOnce(
      &ModelValidatorKeyedService::StartModelExecutionValidation,
      weak_ptr_factory_.GetWeakPtr());
  base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, std::move(start_validation));
}
void ModelValidatorKeyedService::StartModelExecutionValidation() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
auto* opt_guide_service =
OptimizationGuideKeyedServiceFactory::GetForProfile(profile_);
if (!opt_guide_service) {
return;
}
base::CommandLine* command_line = base::CommandLine::ForCurrentProcess();
std::string model_execution_input =
command_line->GetSwitchValueASCII(switches::kModelExecutionValidate);
if (model_execution_input.empty()) {
return;
}
proto::StringValue request;
request.set_value(model_execution_input);
opt_guide_service->ExecuteModel(
ModelBasedCapabilityKey::kTest, request,
/*execution_timeout=*/std::nullopt,
base::BindOnce(&ModelValidatorKeyedService::OnModelExecuteResponse,
weak_ptr_factory_.GetWeakPtr()));
}
// Reply for the request-file parse posted from the constructor.
// |input| is the parsed override request, or nullptr if reading or parsing
// failed (in which case nothing happens).
// Validation is started after the feature-configured startup delay.
void ModelValidatorKeyedService::StartOnDeviceModelExecutionValidation(
    std::unique_ptr<optimization_guide::proto::ModelValidationInput> input) {
  // PostTaskAndReplyWithResult runs the reply on the posting sequence, so
  // this is the same sequence as every other method; check it like they do.
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!input) {
    return;
  }
  base::SequencedTaskRunner::GetCurrentDefault()->PostDelayedTask(
      FROM_HERE,
      base::BindOnce(
          &ModelValidatorKeyedService::PerformOnDeviceModelExecutionValidation,
          weak_ptr_factory_.GetWeakPtr(), std::move(input)),
      features::GetOnDeviceModelExecutionValidationStartupDelay());
}
// Starts an on-device-only session for the first request in |input|, adds the
// request metadata as context, and schedules ExecuteModel() after a fixed
// delay.
// Fails fast (LOG(FATAL)) in three cases:
// - the device is not eligible for the feature;
// - the session cannot be started;
// - the request metadata cannot be decoded.
// This is a validation harness, so a silent skip would hide the problem.
void ModelValidatorKeyedService::PerformOnDeviceModelExecutionValidation(
    std::unique_ptr<optimization_guide::proto::ModelValidationInput> input) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  // NOTE(review): the reason for this fixed delay between AddContext() and
  // ExecuteModel() is not visible here — presumably it gives the session time
  // to process the context. TODO confirm.
  constexpr base::TimeDelta kExecuteModelDelay = base::Seconds(30);
  auto* opt_guide_service =
      OptimizationGuideKeyedServiceFactory::GetForProfile(profile_);
  if (!opt_guide_service) {
    return;
  }
  if (!input || input->requests_size() == 0) {
    return;
  }
  // TODO: b/345495541 - Add support for conducting inference within a loop.
  // For now, we are just using the first request in the ModelValidationInput.
  // Borrow the request instead of copying it; |input| outlives this function
  // body. A single owned copy is made below for the delayed task.
  const optimization_guide::proto::ExecuteRequest& request =
      input->requests(0);
  auto capability_key = ToModelBasedCapabilityKey(request.feature());
  auto eligibility =
      opt_guide_service->GetOnDeviceModelEligibility(capability_key);
  if (eligibility != OnDeviceModelEligibilityReason::kSuccess) {
    LOG(FATAL) << "Failed to create on-device session for validation with "
               << "OnDeviceModelEligibilityReason: "
               << static_cast<int>(eligibility);
  }
  using optimization_guide::SessionConfigParams;
  on_device_validation_session_ = opt_guide_service->StartSession(
      capability_key,
      SessionConfigParams{
          .execution_mode = SessionConfigParams::ExecutionMode::kOnDeviceOnly,
      });
  if (!on_device_validation_session_) {
    LOG(FATAL) << "Failed to start on-device session for validation";
  }
  auto metadata = GetProtoFromAny(request.request_metadata());
  if (!metadata) {
    LOG(FATAL) << "Failed to decode request_metadata of the on-device "
               << "validation request";
  }
  on_device_validation_session_->AddContext(*metadata);
  // The delayed task must own its request, since |input| is gone by then.
  auto request_copy =
      std::make_unique<optimization_guide::proto::ExecuteRequest>(request);
  base::SequencedTaskRunner::GetCurrentDefault()->PostDelayedTask(
      FROM_HERE,
      base::BindOnce(&ModelValidatorKeyedService::ExecuteModel,
                     weak_ptr_factory_.GetWeakPtr(), std::move(request_copy)),
      kExecuteModelDelay);
}
// Runs the on-device validation session on the metadata carried by
// |request|. Streamed results are delivered to
// OnDeviceModelExecuteResponse(); the bound callback takes ownership of
// |request| so every invocation can refer back to it.
void ModelValidatorKeyedService::ExecuteModel(
    std::unique_ptr<optimization_guide::proto::ExecuteRequest> request) {
  DCHECK(on_device_validation_session_);
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(request);
  // Decode the Any-wrapped metadata before |request| is moved into the
  // callback below.
  auto decoded_metadata = GetProtoFromAny(request->request_metadata());
  auto on_stream_result = base::BindRepeating(
      &ModelValidatorKeyedService::OnDeviceModelExecuteResponse,
      weak_ptr_factory_.GetWeakPtr(), std::move(request));
  on_device_validation_session_->ExecuteModel(*decoded_metadata,
                                              std::move(on_stream_result));
}
// Streaming callback for the on-device validation session.
// Partial chunks are ignored. A complete response, or an error, is recorded
// into a ModelValidationOutput together with the originating |request|. The
// output is written to the file named by
// switches::GetOnDeviceValidationWriteToFile(), when that switch is set.
void ModelValidatorKeyedService::OnDeviceModelExecuteResponse(
    const std::unique_ptr<optimization_guide::proto::ExecuteRequest>& request,
    OptimizationGuideModelStreamingExecutionResult result) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  DCHECK(request);
  if (result.response.has_value() && !result.response->is_complete) {
    // Ignore partial responses.
    return;
  }
  // A final result with no execution info, or one not provided by the
  // on-device model, is counted as an error.
  if (!result.execution_info || !result.provided_by_on_device) {
    LOCAL_HISTOGRAM_BOOLEAN(
        "OptimizationGuide.ModelValidation.OnDevice.DidError", true);
  }
  proto::ModelValidationOutput output;
  optimization_guide::proto::ModelCall* model_call = output.add_model_calls();
  model_call->mutable_request()->CopyFrom(*request);
  optimization_guide::proto::ModelExecutionInfo* model_execution_info =
      model_call->mutable_model_execution_info();
  if (result.response.has_value()) {
    model_call->mutable_response()->CopyFrom(result.response.value().response);
  } else {
    model_execution_info->set_model_execution_error_enum(
        static_cast<uint32_t>(result.response.error().error()));
  }
  // TODO(crbug.com/372535824): store on-device execution log.
  auto out_file = switches::GetOnDeviceValidationWriteToFile();
  if (!out_file) {
    return;
  }
  // |output| is not used after this point; move it into the task instead of
  // copying the whole proto.
  base::ThreadPool::PostTask(
      FROM_HERE, {base::MayBlock()},
      base::BindOnce(&WriteResponseToFile, *out_file, std::move(output)));
}
// Completion callback for StartModelExecutionValidation().
// The only check here is that the reply arrives on the owning sequence.
// |result| and |log_entry| are not inspected; both are released when this
// function returns.
// NOTE(review): confirm whether dropping |log_entry| here is what triggers
// any logging the validation relies on.
void ModelValidatorKeyedService::OnModelExecuteResponse(
    OptimizationGuideModelExecutionResult result,
    std::unique_ptr<ModelQualityLogEntry> log_entry) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(
      sequence_checker_);
}
} // namespace optimization_guide