| // Copyright 2022 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "components/segmentation_platform/internal/scheduler/execution_service.h" |
| |
| #include "base/memory/raw_ptr.h" |
| #include "base/task/sequenced_task_runner.h" |
| #include "components/prefs/pref_service.h" |
| #include "components/segmentation_platform/internal/database/cached_result_provider.h" |
| #include "components/segmentation_platform/internal/database/storage_service.h" |
| #include "components/segmentation_platform/internal/execution/execution_request.h" |
| #include "components/segmentation_platform/internal/execution/model_executor_impl.h" |
| #include "components/segmentation_platform/internal/execution/processing/feature_aggregator_impl.h" |
| #include "components/segmentation_platform/internal/execution/processing/feature_list_query_processor.h" |
| #include "components/segmentation_platform/internal/scheduler/model_execution_scheduler_impl.h" |
| #include "components/segmentation_platform/internal/segmentation_ukm_helper.h" |
| #include "components/segmentation_platform/internal/signals/signal_handler.h" |
| #include "components/segmentation_platform/public/config.h" |
| #include "components/segmentation_platform/public/input_delegate.h" |
| #include "components/segmentation_platform/public/model_provider.h" |
| #include "components/segmentation_platform/public/proto/model_metadata.pb.h" |
| |
| namespace segmentation_platform { |
| |
// The constructor performs no setup of its own; the collaborators
// (feature processor, executor, scheduler, model manager, ...) are assigned
// later by Initialize() or InitForTesting().
ExecutionService::ExecutionService() = default;
ExecutionService::~ExecutionService() = default;
| |
| void ExecutionService::InitForTesting( |
| std::unique_ptr<processing::FeatureListQueryProcessor> feature_processor, |
| std::unique_ptr<ModelExecutor> executor, |
| std::unique_ptr<ModelExecutionScheduler> scheduler, |
| ModelManager* model_manager) { |
| feature_list_query_processor_ = std::move(feature_processor); |
| model_executor_ = std::move(executor); |
| model_execution_scheduler_ = std::move(scheduler); |
| model_manager_ = model_manager; |
| } |
| |
| void ExecutionService::Initialize( |
| StorageService* storage_service, |
| SignalHandler* signal_handler, |
| base::Clock* clock, |
| scoped_refptr<base::SequencedTaskRunner> task_runner, |
| const base::flat_set<SegmentId>& legacy_output_segment_ids, |
| ModelProviderFactory* model_provider_factory, |
| std::vector<raw_ptr<ModelExecutionScheduler::Observer, |
| VectorExperimental>>&& observers, |
| const PlatformOptions& platform_options, |
| std::unique_ptr<processing::InputDelegateHolder> input_delegate_holder, |
| PrefService* profile_prefs, |
| CachedResultProvider* cached_result_provider) { |
| storage_service_ = storage_service; |
| |
| feature_list_query_processor_ = |
| std::make_unique<processing::FeatureListQueryProcessor>( |
| storage_service, std::move(input_delegate_holder), |
| std::make_unique<processing::FeatureAggregatorImpl>()); |
| |
| training_data_collector_ = TrainingDataCollector::Create( |
| platform_options, feature_list_query_processor_.get(), |
| signal_handler->deprecated_histogram_signal_handler(), |
| signal_handler->user_action_signal_handler(), storage_service, |
| profile_prefs, clock, cached_result_provider); |
| |
| model_executor_ = std::make_unique<ModelExecutorImpl>( |
| clock, storage_service->segment_info_database(), |
| feature_list_query_processor_.get()); |
| |
| model_manager_ = storage_service->model_manager(); |
| |
| model_execution_scheduler_ = std::make_unique<ModelExecutionSchedulerImpl>( |
| std::move(observers), storage_service->segment_info_database(), |
| storage_service->signal_storage_config(), model_manager_, |
| model_executor_.get(), legacy_output_segment_ids, clock, |
| platform_options); |
| } |
| |
| void ExecutionService::OnNewModelInfoReadyLegacy( |
| const proto::SegmentInfo& segment_info) { |
| // TODO(crbug.com/40258591): Change path flow as |
| // SPSI->RRM->EE::RequestModelExecution and migrate |
| // MES::CancelOutstandingExecutionRequests() to EE. |
| model_execution_scheduler_->OnNewModelInfoReady(segment_info); |
| } |
| |
| ModelProvider* ExecutionService::GetModelProvider(SegmentId segment_id, |
| ModelSource model_source) { |
| return model_manager_->GetModelProvider(segment_id, model_source); |
| } |
| |
| void ExecutionService::RequestModelExecution( |
| std::unique_ptr<ExecutionRequest> request) { |
| DCHECK_NE(request->segment_id, SegmentId::OPTIMIZATION_TARGET_UNKNOWN); |
| DCHECK_NE(request->model_source, proto::ModelSource::UNKNOWN_MODEL_SOURCE); |
| DCHECK(!request->callback.is_null()); |
| model_executor_->ExecuteModel(std::move(request)); |
| } |
| |
| void ExecutionService::OverwriteModelExecutionResult( |
| proto::SegmentId segment_id, |
| const std::pair<float, ModelExecutionStatus>& result) { |
| // TODO(ritikagup): Change the use of this according to MultiOutputModel. |
| auto execution_result = std::make_unique<ModelExecutionResult>( |
| ModelProvider::Request(), ModelProvider::Response(1, result.first)); |
| proto::SegmentInfo segment_info; |
| segment_info.set_segment_id(segment_id); |
| model_execution_scheduler_->OnModelExecutionCompleted( |
| segment_info, std::move(execution_result)); |
| } |
| |
| void ExecutionService::RefreshModelResults() { |
| model_execution_scheduler_->RequestModelExecutionForEligibleSegments( |
| /*expired_only=*/true); |
| } |
| |
| void ExecutionService::RunDailyTasks(bool is_startup) { |
| RefreshModelResults(); |
| |
| if (is_startup) { |
| // This will trigger data collection after initialization finishes. |
| training_data_collector_->OnServiceInitialized(); |
| } else { |
| training_data_collector_->ReportCollectedContinuousTrainingData(); |
| } |
| } |
| |
| } // namespace segmentation_platform |