blob: 847268adc46fc61e5618c4243a78626c9fe381e9 [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/selection/result_refresh_manager.h"
#include "base/metrics/field_trial_params.h"
#include "components/segmentation_platform/internal/selection/selection_utils.h"
#include "components/segmentation_platform/internal/stats.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/constants.h"
#include "components/segmentation_platform/public/features.h"
namespace segmentation_platform {
namespace {
// Fallback delay, in milliseconds, applied before models are executed at
// startup when the field trial param is not set.
constexpr int kModelInitializationTimeoutMs = 5000;

// Returns the startup model-execution delay in milliseconds. The value comes
// from the `kModelInitializationDelay` param of
// `features::kSegmentationPlatformModelInitializationDelay`, falling back to
// `kModelInitializationTimeoutMs`.
int GetModelInitializationTimeoutMs() {
  const int delay_ms = base::GetFieldTrialParamByFeatureAsInt(
      features::kSegmentationPlatformModelInitializationDelay,
      /*param_name=*/kModelInitializationDelay,
      /*default_value=*/kModelInitializationTimeoutMs);
  return delay_ms;
}
// Returns true when `result` is non-null and its prediction result carries an
// output config, meaning the model result supports multi-output handling.
bool SupportMultiOutput(SegmentResultProvider::SegmentResult* result) {
  if (!result) {
    return false;
  }
  return result->result.has_output_config();
}
// Notifies the training data collector of a PERIODIC decision time for every
// segment in `config`, flagging it as triggered by a decision result update.
// Does nothing when no execution service or training data collector exists.
void CollectTrainingDataIfNeeded(const Config* config,
                                 ExecutionService* execution_service) {
  // The execution service and training data collector might be null in testing.
  if (!execution_service) {
    return;
  }
  auto* collector = execution_service->training_data_collector();
  if (!collector) {
    return;
  }
  for (const auto& segment_entry : config->segments) {
    collector->OnDecisionTime(segment_entry.first, nullptr,
                              proto::TrainingOutputs::TriggerConfig::PERIODIC,
                              std::nullopt,
                              /*decision_result_update_trigger=*/true);
  }
}
} // namespace
ResultRefreshManager::ResultRefreshManager(
const ConfigHolder* config_holder,
CachedResultWriter* cached_result_writer,
const PlatformOptions& platform_options)
: config_holder_(config_holder),
cached_result_writer_(cached_result_writer),
platform_options_(platform_options) {}
ResultRefreshManager::~ResultRefreshManager() = default;
// Takes ownership of the result providers (looked up by segmentation key) and
// records the execution service. It also sets the initial delay state: when
// the platform disables the model execution delay, the delay is treated as
// already elapsed, so refreshes run immediately.
void ResultRefreshManager::Initialize(
    std::map<std::string, std::unique_ptr<SegmentResultProvider>>
        result_providers,
    ExecutionService* execution_service) {
  execution_service_ = execution_service;
  result_providers_ = std::move(result_providers);
  if (platform_options_.disable_model_execution_delay) {
    delay_state_ = DelayState::DELAY_EXECUTED;
  } else {
    delay_state_ = DelayState::DELAY_NOT_HIT;
  }
}
void ResultRefreshManager::RefreshModelResults(bool is_startup) {
if (delay_state_ == DelayState::DELAY_NOT_HIT && is_startup) {
// Set a delay timeout to execute all the models after the delay
// `kModelInitializationTimeoutMs` is hit. This is to get finch seed to load
// before model execution.
base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
FROM_HERE,
base::BindOnce(&ResultRefreshManager::RefreshModelResultsInternal,
weak_ptr_factory_.GetWeakPtr()),
base::Milliseconds(GetModelInitializationTimeoutMs()));
return;
}
if (delay_state_ == DelayState::DELAY_EXECUTED) {
RefreshModelResultsInternal();
}
}
void ResultRefreshManager::RefreshModelResultsInternal() {
delay_state_ = DelayState::DELAY_EXECUTED;
for (const auto& config : config_holder_->configs()) {
GetCachedResultOrRunModel(config.get());
}
}
void ResultRefreshManager::GetCachedResultOrRunModel(const Config* config) {
if (!config->auto_execute_and_cache ||
metadata_utils::ConfigUsesLegacyOutput(config)) {
return;
}
auto* segment_result_provider =
result_providers_[config->segmentation_key].get();
auto result_options =
std::make_unique<SegmentResultProvider::GetResultOptions>();
// Not required, checking only for testing.
if (config->segments.empty()) {
return;
}
// Note that, this assumes that a client has only one model.
result_options->segment_id = config->segments.begin()->first;
result_options->ignore_db_scores = false;
result_options->save_results_to_db = true;
result_options->callback = base::BindOnce(
&ResultRefreshManager::OnGetCachedResultOrRunModel,
weak_ptr_factory_.GetWeakPtr(), segment_result_provider, config);
segment_result_provider->GetSegmentResult(std::move(result_options));
}
// Called when the model for `segment_info`'s segment is updated. It refreshes
// the result of the config that owns the segment.
// Nothing happens when:
// - no config owns the segment;
// - the owning config has an empty segmentation key;
// - the startup delay has not yet been hit.
void ResultRefreshManager::OnModelUpdated(proto::SegmentInfo* segment_info) {
  const Config* config =
      config_holder_->GetConfigForSegmentId(segment_info->segment_id());
  // The updated segment may not belong to any registered config, in which case
  // the lookup yields null and must not be dereferenced.
  if (!config || config->segmentation_key.empty() ||
      delay_state_ == DelayState::DELAY_NOT_HIT) {
    return;
  }
  GetCachedResultOrRunModel(config);
}
// Handles the result fetched by `GetCachedResultOrRunModel()`.
// - A successful result without an output config is recorded as
//   `kMultiOutputNotSupported`.
// - Otherwise the success or failure reason is recorded.
// - A successful result is then converted to a client result and written to
//   prefs if the cached one has expired.
// - Training data is collected when the prefs were actually updated.
// `segment_result_provider` is not read here.
void ResultRefreshManager::OnGetCachedResultOrRunModel(
    SegmentResultProvider* segment_result_provider,
    const Config* config,
    std::unique_ptr<SegmentResultProvider::SegmentResult> result) {
  const SegmentResultProvider::ResultState state = result->state;
  const bool succeeded =
      selection_utils::ResultStateToPredictionStatus(state) ==
      PredictionStatus::kSucceeded;

  // A successful result that is not in multi-output form is reported as a
  // failure and not cached.
  if (succeeded && !SupportMultiOutput(result.get())) {
    stats::RecordSegmentSelectionFailure(
        *config,
        stats::SegmentationSelectionFailureReason::kMultiOutputNotSupported);
    return;
  }

  // The outcome is recorded for both successful and failed results.
  stats::RecordSegmentSelectionFailure(*config,
                                       stats::GetSuccessOrFailureReason(state));
  if (!succeeded) {
    return;
  }

  stats::RecordClassificationResultComputed(*config, result->result);
  proto::ClientResult client_result =
      metadata_utils::CreateClientResultFromPredResult(
          std::move(result->result), base::Time::Now());
  const bool prefs_updated = cached_result_writer_->UpdatePrefsIfExpired(
      config, std::move(client_result), platform_options_);
  if (prefs_updated) {
    CollectTrainingDataIfNeeded(config, execution_service_);
  }
}
} // namespace segmentation_platform