blob: 7320d3efb00f218681ba7e90763821e322d5703d [file] [log] [blame]
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/scheduler/execution_service.h"
#include "base/memory/raw_ptr.h"
#include "base/task/sequenced_task_runner.h"
#include "components/prefs/pref_service.h"
#include "components/segmentation_platform/internal/database/cached_result_provider.h"
#include "components/segmentation_platform/internal/database/storage_service.h"
#include "components/segmentation_platform/internal/execution/execution_request.h"
#include "components/segmentation_platform/internal/execution/model_executor_impl.h"
#include "components/segmentation_platform/internal/execution/processing/feature_aggregator_impl.h"
#include "components/segmentation_platform/internal/execution/processing/feature_list_query_processor.h"
#include "components/segmentation_platform/internal/scheduler/model_execution_scheduler_impl.h"
#include "components/segmentation_platform/internal/segmentation_ukm_helper.h"
#include "components/segmentation_platform/internal/signals/signal_handler.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/input_delegate.h"
#include "components/segmentation_platform/public/model_provider.h"
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"
namespace segmentation_platform {
// Constructor and destructor are defaulted out of line. Collaborators are
// created or injected later by Initialize() or InitForTesting().
ExecutionService::ExecutionService() = default;
ExecutionService::~ExecutionService() = default;
// Test-only initialization. It installs caller-provided collaborators instead
// of building them the way Initialize() does.
// Ownership of |feature_processor|, |executor| and |scheduler| moves into this
// service. |model_manager| is stored by pointer; presumably the caller keeps
// ownership (TODO confirm against the header).
// NOTE(review): unlike Initialize(), this sets neither |storage_service_| nor
// |training_data_collector_|. RunDailyTasks() dereferences the latter, so it
// should not be called after this path unless the collector is provided some
// other way. Confirm.
void ExecutionService::InitForTesting(
std::unique_ptr<processing::FeatureListQueryProcessor> feature_processor,
std::unique_ptr<ModelExecutor> executor,
std::unique_ptr<ModelExecutionScheduler> scheduler,
ModelManager* model_manager) {
// Each assignment releases whatever the member previously held, so keep the
// original order.
feature_list_query_processor_ = std::move(feature_processor);
model_executor_ = std::move(executor);
model_execution_scheduler_ = std::move(scheduler);
model_manager_ = model_manager;
}
// Builds the production execution pipeline:
//   FeatureListQueryProcessor -> ModelExecutorImpl -> ModelExecutionSchedulerImpl
// It also builds a TrainingDataCollector that shares the same feature
// processor.
// Statement order matters. Later constructors receive raw pointers (`.get()`)
// to members created earlier in this function.
// |observers| is moved into the scheduler and |input_delegate_holder| is moved
// into the feature processor. The remaining pointer arguments are passed
// through unowned; presumably they must outlive this service (TODO confirm).
// NOTE(review): |task_runner| and |model_provider_factory| are not used by
// this function.
void ExecutionService::Initialize(
StorageService* storage_service,
SignalHandler* signal_handler,
base::Clock* clock,
scoped_refptr<base::SequencedTaskRunner> task_runner,
const base::flat_set<SegmentId>& legacy_output_segment_ids,
ModelProviderFactory* model_provider_factory,
std::vector<raw_ptr<ModelExecutionScheduler::Observer,
VectorExperimental>>&& observers,
const PlatformOptions& platform_options,
std::unique_ptr<processing::InputDelegateHolder> input_delegate_holder,
PrefService* profile_prefs,
CachedResultProvider* cached_result_provider) {
storage_service_ = storage_service;
// Feature processor first: both the training data collector and the model
// executor below borrow it.
feature_list_query_processor_ =
std::make_unique<processing::FeatureListQueryProcessor>(
storage_service, std::move(input_delegate_holder),
std::make_unique<processing::FeatureAggregatorImpl>());
training_data_collector_ = TrainingDataCollector::Create(
platform_options, feature_list_query_processor_.get(),
signal_handler->deprecated_histogram_signal_handler(),
signal_handler->user_action_signal_handler(), storage_service,
profile_prefs, clock, cached_result_provider);
model_executor_ = std::make_unique<ModelExecutorImpl>(
clock, storage_service->segment_info_database(),
feature_list_query_processor_.get());
// The model manager comes from the storage service. It must be set before the
// scheduler, which receives it.
model_manager_ = storage_service->model_manager();
model_execution_scheduler_ = std::make_unique<ModelExecutionSchedulerImpl>(
std::move(observers), storage_service->segment_info_database(),
storage_service->signal_storage_config(), model_manager_,
model_executor_.get(), legacy_output_segment_ids, clock,
platform_options);
}
// Legacy entry point: forwards newly available model info for a segment
// directly to the model execution scheduler.
void ExecutionService::OnNewModelInfoReadyLegacy(
const proto::SegmentInfo& segment_info) {
// TODO(crbug.com/40258591): Change path flow as
// SPSI->RRM->EE::RequestModelExecution and migrate
// MES::CancelOutstandingExecutionRequests() to EE.
model_execution_scheduler_->OnNewModelInfoReady(segment_info);
}
// Returns the ModelProvider registered for |segment_id| from |model_source|,
// delegating to the model manager. |model_manager_| is set by Initialize() or
// InitForTesting(). Calling this before either one dereferences an unset
// pointer.
ModelProvider* ExecutionService::GetModelProvider(SegmentId segment_id,
ModelSource model_source) {
return model_manager_->GetModelProvider(segment_id, model_source);
}
// Hands |request| to the model executor and transfers ownership of it.
// Preconditions, DCHECK-enforced in debug builds only: the request is
// non-null, targets a known segment and a known model source, and carries a
// callback.
void ExecutionService::RequestModelExecution(
    std::unique_ptr<ExecutionRequest> request) {
  // Check the pointer itself first. A null request then fails with a clear
  // DCHECK message instead of crashing on the member accesses below.
  DCHECK(request);
  DCHECK_NE(request->segment_id, SegmentId::OPTIMIZATION_TARGET_UNKNOWN);
  DCHECK_NE(request->model_source, proto::ModelSource::UNKNOWN_MODEL_SOURCE);
  DCHECK(!request->callback.is_null());
  model_executor_->ExecuteModel(std::move(request));
}
// Injects |result| as if the model for |segment_id| had just finished
// executing. It builds a ModelExecutionResult and hands it to the scheduler's
// completion handler.
// |result.first| becomes the single output score and |result.second| becomes
// the execution status.
void ExecutionService::OverwriteModelExecutionResult(
    proto::SegmentId segment_id,
    const std::pair<float, ModelExecutionStatus>& result) {
  // TODO(ritikagup): Change the use of this according to MultiOutputModel.
  auto execution_result = std::make_unique<ModelExecutionResult>(
      ModelProvider::Request(), ModelProvider::Response(1, result.first));
  // Propagate the caller's status rather than always reporting success. The
  // score alone cannot express a failed execution. For kSuccess this matches
  // the previous behavior.
  execution_result->status = result.second;
  proto::SegmentInfo segment_info;
  segment_info.set_segment_id(segment_id);
  model_execution_scheduler_->OnModelExecutionCompleted(
      segment_info, std::move(execution_result));
}
// Asks the scheduler to execute models for eligible segments. It passes
// expired_only=true, presumably limiting execution to segments whose stored
// results have expired. Confirm against the scheduler.
void ExecutionService::RefreshModelResults() {
model_execution_scheduler_->RequestModelExecutionForEligibleSegments(
/*expired_only=*/true);
}
// Daily maintenance. It always refreshes expired model results, then:
//  - on non-startup runs, reports the continuous training data collected so
//    far;
//  - at startup, notifies the training data collector that the service is
//    initialized.
void ExecutionService::RunDailyTasks(bool is_startup) {
  RefreshModelResults();
  if (!is_startup) {
    training_data_collector_->ReportCollectedContinuousTrainingData();
    return;
  }
  // This will trigger data collection after initialization finishes.
  training_data_collector_->OnServiceInitialized();
}
} // namespace segmentation_platform