// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/scheduler/model_execution_scheduler_impl.h"

#include <utility>

#include "base/logging.h"
#include "base/time/clock.h"
#include "base/time/time.h"
#include "components/segmentation_platform/internal/database/segment_info_database.h"
#include "components/segmentation_platform/internal/database/signal_storage_config.h"
#include "components/segmentation_platform/internal/execution/execution_request.h"
#include "components/segmentation_platform/internal/execution/model_execution_manager_impl.h"
#include "components/segmentation_platform/internal/metadata/metadata_utils.h"
#include "components/segmentation_platform/internal/platform_options.h"
#include "components/segmentation_platform/internal/stats.h"
#include "components/segmentation_platform/public/model_provider.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
namespace segmentation_platform {
// Constructs a scheduler from its collaborators.
//
// |observers| and |segment_ids| are sink parameters: the caller hands them
// over (by rvalue reference and by value respectively), so they are moved into
// the corresponding members instead of being copied a second time.
//
// The pointer arguments (|segment_database|, |signal_storage_config|,
// |model_execution_manager|, |model_executor|, |clock|) are stored as passed.
// NOTE(review): presumably these are not owned and must outlive the
// scheduler -- confirm against the header and the call site.
ModelExecutionSchedulerImpl::ModelExecutionSchedulerImpl(
    std::vector<Observer*>&& observers,
    SegmentInfoDatabase* segment_database,
    SignalStorageConfig* signal_storage_config,
    ModelExecutionManager* model_execution_manager,
    ModelExecutor* model_executor,
    base::flat_set<proto::SegmentId> segment_ids,
    base::Clock* clock,
    const PlatformOptions& platform_options)
    : observers_(std::move(observers)),
      segment_database_(segment_database),
      signal_storage_config_(signal_storage_config),
      model_execution_manager_(model_execution_manager),
      model_executor_(model_executor),
      all_segment_ids_(std::move(segment_ids)),
      clock_(clock),
      platform_options_(platform_options) {}
// Defaulted destructor: members are torn down by their own destructors and no
// additional cleanup is performed here.
ModelExecutionSchedulerImpl::~ModelExecutionSchedulerImpl() = default;
// Handles a newly received model for the segment described by |segment_info|.
// The segment info is expected to have already passed metadata validation
// (DCHECKed). If the segment is eligible for execution (evaluated with
// expired_only = true) a model execution is requested; otherwise any
// outstanding request for that segment is cancelled.
void ModelExecutionSchedulerImpl::OnNewModelInfoReady(
    const proto::SegmentInfo& segment_info) {
  DCHECK(metadata_utils::ValidateSegmentInfoMetadataAndFeatures(segment_info) ==
         metadata_utils::ValidationResult::kValidationSuccess);

  if (ShouldExecuteSegment(/*expired_only=*/true, segment_info)) {
    RequestModelExecution(segment_info);
    return;
  }

  // Outstanding requests are usually cancelled right before the model is
  // executed. Here we already know that (1) a new model has arrived and
  // (2) the new model is not yet valid for execution, so the current
  // execution is cancelled now and this model will be executed later.
  CancelOutstandingExecutionRequests(segment_info.segment_id());
}
// Asks the segment database for the info of every segment in
// |all_segment_ids_| and passes FilterEligibleSegments as the callback.
// |expired_only| is forwarded unchanged to the filtering step. The callback is
// bound to a weak pointer to this scheduler.
void ModelExecutionSchedulerImpl::RequestModelExecutionForEligibleSegments(
    bool expired_only) {
  auto on_segments_loaded =
      base::BindOnce(&ModelExecutionSchedulerImpl::FilterEligibleSegments,
                     weak_ptr_factory_.GetWeakPtr(), expired_only);
  segment_database_->GetSegmentInfoForSegments(all_segment_ids_,
                                               std::move(on_segments_loaded));
}
// Starts a model execution for the segment described by |segment_info|.
// Any request already outstanding for the same segment is cancelled first, so
// at most one entry per segment exists in |outstanding_requests_|. The
// completion callback (OnModelExecutionCompleted) is stored in that map so it
// can be cancelled later, and the executor is handed the callback obtained
// from the stored entry.
void ModelExecutionSchedulerImpl::RequestModelExecution(
    const proto::SegmentInfo& segment_info) {
  SegmentId segment_id = segment_info.segment_id();
  CancelOutstandingExecutionRequests(segment_id);

  // The previous entry (if any) was erased above, so this insertion creates
  // the entry for |segment_id|.
  auto completion_callback =
      base::BindOnce(&ModelExecutionSchedulerImpl::OnModelExecutionCompleted,
                     weak_ptr_factory_.GetWeakPtr(), segment_info);
  auto inserted = outstanding_requests_.insert(
      std::make_pair(segment_id, std::move(completion_callback)));

  auto execution_request = std::make_unique<ExecutionRequest>();
  execution_request->model_provider =
      model_execution_manager_->GetProvider(segment_id);
  DCHECK(execution_request->model_provider);
  // NOTE(review): the request stores a pointer to the caller's |segment_info|
  // rather than a copy -- confirm the executor does not use it after the
  // caller's object has gone away.
  execution_request->segment_info = &segment_info;
  execution_request->callback = inserted.first->second.callback();
  execution_request->record_metrics_for_default = false;

  model_executor_->ExecuteModel(std::move(execution_request));
}
// Completion callback bound by RequestModelExecution. When |result| reports
// kSuccess, a PredictionResult is created from the scores, the segment's
// output config and the current clock time; for any other status no result
// (nullopt) is produced. The outcome is then passed to the segment database
// together with OnResultSaved as the save callback.
void ModelExecutionSchedulerImpl::OnModelExecutionCompleted(
    const proto::SegmentInfo& segment_info,
    std::unique_ptr<ModelExecutionResult> result) {
  // TODO(shaktisahu): Check ModelExecutionStatus and handle failure cases.
  // Should we save it to DB?
  SegmentId segment_id = segment_info.segment_id();

  // Stays empty unless the execution succeeded.
  absl::optional<proto::PredictionResult> prediction;
  if (result->status == ModelExecutionStatus::kSuccess) {
    prediction = metadata_utils::CreatePredictionResult(
        result->scores, segment_info.model_metadata().output_config(),
        clock_->Now());
  }

  segment_database_->SaveSegmentResult(
      segment_id, std::move(prediction),
      base::BindOnce(&ModelExecutionSchedulerImpl::OnResultSaved,
                     weak_ptr_factory_.GetWeakPtr(), segment_id));
}
void ModelExecutionSchedulerImpl::FilterEligibleSegments(
bool expired_only,
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> all_segments) {
std::vector<const proto::SegmentInfo*> models_to_run;
for (const auto& pair : *all_segments) {
SegmentId segment_id = pair.first;
const proto::SegmentInfo& segment_info = pair.second;
if (!ShouldExecuteSegment(expired_only, segment_info)) {
VLOG(1) << "Segmentation scheduler: Skipped executed segment "
<< proto::SegmentId_Name(segment_id);
continue;
}
models_to_run.emplace_back(&segment_info);
}
for (const proto::SegmentInfo* segment_info : models_to_run)
RequestModelExecution(*segment_info);
}
// Decides whether the model for |segment_info| should be executed now.
// Returns true immediately when the platform options force a refresh.
// Otherwise returns false -- recording the matching skip status and logging
// at VLOG(1) -- when any of the following holds, checked in this order:
//   1. the segment has fresh results;
//   2. |expired_only| is set and the result is neither expired nor
//      unavailable;
//   3. the signal collection requirement for the model metadata is not met.
// Returns true when none of the checks rejects the segment.
bool ModelExecutionSchedulerImpl::ShouldExecuteSegment(
    bool expired_only,
    const proto::SegmentInfo& segment_info) {
  if (platform_options_.force_refresh_results)
    return true;

  const SegmentId segment_id = segment_info.segment_id();
  // Records why execution was skipped for this segment.
  const auto record_skip = [segment_id](ModelExecutionStatus reason) {
    stats::RecordModelExecutionStatus(segment_id,
                                      /*default_provider=*/false, reason);
  };

  // Segments computed recently are left alone.
  if (metadata_utils::HasFreshResults(segment_info, clock_->Now())) {
    VLOG(1) << "Segmentation model not executed since it has fresh results, "
               "segment:"
            << proto::SegmentId_Name(segment_id);
    record_skip(ModelExecutionStatus::kSkippedHasFreshResults);
    return false;
  }

  // When only expired results are wanted, skip segments whose result is
  // still valid.
  if (expired_only) {
    const bool expired_or_unavailable =
        metadata_utils::HasExpiredOrUnavailableResult(segment_info,
                                                      clock_->Now());
    if (!expired_or_unavailable) {
      VLOG(1)
          << "Segmentation model not executed since results are not expired, "
             "segment:"
          << proto::SegmentId_Name(segment_id);
      record_skip(ModelExecutionStatus::kSkippedResultNotExpired);
      return false;
    }
  }

  // Segments that don't match the signal collection min length are skipped.
  const bool enough_signals =
      signal_storage_config_->MeetsSignalCollectionRequirement(
          segment_info.model_metadata());
  if (!enough_signals) {
    record_skip(ModelExecutionStatus::kSkippedNotEnoughSignals);
    VLOG(1) << "Segmentation model not executed since metadata requirements "
               "not met, segment:"
            << proto::SegmentId_Name(segment_id);
    return false;
  }

  return true;
}
// Cancels the stored callback for |segment_id| and removes its entry from
// |outstanding_requests_|. Does nothing when no request is recorded for that
// segment.
void ModelExecutionSchedulerImpl::CancelOutstandingExecutionRequests(
    SegmentId segment_id) {
  auto entry = outstanding_requests_.find(segment_id);
  if (entry == outstanding_requests_.end())
    return;

  entry->second.Cancel();
  outstanding_requests_.erase(entry);
}
// Callback passed to the segment database when saving a result. The save
// outcome is always recorded. On success every observer is told that the
// model execution for |segment_id| completed; on failure an additional
// kFailedToSaveResultAfterSuccess status is recorded and observers are not
// notified.
void ModelExecutionSchedulerImpl::OnResultSaved(SegmentId segment_id,
                                                bool success) {
  stats::RecordModelExecutionSaveResult(segment_id, success);

  if (success) {
    for (Observer* listener : observers_) {
      listener->OnModelExecutionCompleted(segment_id);
    }
    return;
  }

  // TODO(ssid): Consider removing this enum, this is the only case where the
  // execution status is recorded twice for the same execution request.
  stats::RecordModelExecutionStatus(
      segment_id,
      /*default_provider=*/false,
      ModelExecutionStatus::kFailedToSaveResultAfterSuccess);
}
} // namespace segmentation_platform