blob: 963b7ed863a40967f3d976457d134b971b85dc67 [file] [log] [blame]
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/selection/segment_result_provider.h"
#include <map>
#include "base/logging.h"
#include "base/memory/raw_ptr.h"
#include "base/task/sequenced_task_runner.h"
#include "components/segmentation_platform/internal/database/segment_info_database.h"
#include "components/segmentation_platform/internal/database/signal_storage_config.h"
#include "components/segmentation_platform/internal/execution/default_model_manager.h"
#include "components/segmentation_platform/internal/execution/execution_request.h"
#include "components/segmentation_platform/internal/logging.h"
#include "components/segmentation_platform/internal/metadata/metadata_utils.h"
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
#include "components/segmentation_platform/internal/scheduler/execution_service.h"
#include "components/segmentation_platform/internal/stats.h"
#include "components/segmentation_platform/public/model_provider.h"
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"
namespace segmentation_platform {
namespace {
// Converts the first score stored in `segment_info`'s prediction result into a
// rank by calling metadata_utils::ConvertToDiscreteScore() with
// `discrete_mapping_key` and the segment's model metadata, logs the outcome at
// VLOG level 1 and returns the rank.
//
// Index 0 of the result list is read unconditionally, so the caller must pass
// a segment info whose prediction result holds at least one score.
float ComputeDiscreteMapping(const std::string& discrete_mapping_key,
                             const proto::SegmentInfo& segment_info) {
  const float score = segment_info.prediction_result().result()[0];
  const float discrete_rank = metadata_utils::ConvertToDiscreteScore(
      discrete_mapping_key, score, segment_info.model_metadata());
  VLOG(1) << __func__
          << ": segment=" << SegmentId_Name(segment_info.segment_id())
          << ": result=" << score << ", rank=" << discrete_rank;
  return discrete_rank;
}
// Scans `available_segments` in order and returns a pointer to the segment
// info of the first entry whose `segment_source` equals `needed_source`.
// Returns nullptr when no entry matches. The returned pointer refers to the
// `segment_info` member of the matching list entry.
proto::SegmentInfo* FilterSegmentInfoBySource(
    DefaultModelManager::SegmentInfoList& available_segments,
    DefaultModelManager::SegmentSource needed_source) {
  for (const auto& candidate : available_segments) {
    if (candidate->segment_source == needed_source) {
      return &candidate->segment_info;
    }
  }
  return nullptr;
}
// SegmentResultProvider implementation. For each request it tries, in order:
// the score stored in the database, execution of the database (server) model,
// and execution of the default model. When `ignore_db_scores` is set on the
// request, the stored score is skipped and the server model is executed first.
// The final result is delivered by posting the request's callback to the task
// runner captured at construction.
class SegmentResultProviderImpl : public SegmentResultProvider {
public:
// Stores the given pointers and captures the current default sequenced task
// runner. `execution_service` and `default_model_manager` are null-checked
// when model providers are looked up in OnGetSegmentInfo().
SegmentResultProviderImpl(SegmentInfoDatabase* segment_database,
SignalStorageConfig* signal_storage_config,
DefaultModelManager* default_model_manager,
ExecutionService* execution_service,
base::Clock* clock,
bool force_refresh_results)
: segment_database_(segment_database),
signal_storage_config_(signal_storage_config),
default_model_manager_(default_model_manager),
execution_service_(execution_service),
clock_(clock),
force_refresh_results_(force_refresh_results),
task_runner_(base::SequencedTaskRunner::GetCurrentDefault()) {}
// SegmentResultProvider implementation. Starts an asynchronous lookup for the
// segment named in `options`; the result is delivered through the callback
// stored in `options`.
void GetSegmentResult(std::unique_ptr<GetResultOptions> options) override;
// Not copyable.
SegmentResultProviderImpl(const SegmentResultProviderImpl&) = delete;
SegmentResultProviderImpl& operator=(const SegmentResultProviderImpl&) =
delete;
private:
// State carried through the asynchronous steps of a single GetSegmentResult()
// request.
struct RequestState {
// Model provider to use for each segment source. An entry may be null.
std::unordered_map<DefaultModelManager::SegmentSource,
raw_ptr<ModelProvider, DanglingUntriaged>>
model_providers;
// Segment infos handed to OnGetSegmentInfo() for the requested segment.
DefaultModelManager::SegmentInfoList available_segments;
// The request options, including the callback that receives the result.
std::unique_ptr<GetResultOptions> options;
};
// Continuation of GetSegmentResult(): builds the RequestState and either reads
// the database score or, when `ignore_db_scores` is set, executes the server
// model right away.
void OnGetSegmentInfo(
std::unique_ptr<GetResultOptions> options,
DefaultModelManager::SegmentInfoList available_segments);
// `fallback_source_to_execute` tells us whether to execute server or default
// model next. If database doesn't have score then database model is executed,
// and if it's not present or fails, then default model is executed.
void OnGotDatabaseModelScore(
DefaultModelManager::SegmentSource fallback_source_to_execute,
std::unique_ptr<RequestState> request_state,
std::unique_ptr<SegmentResult> db_result);
// Executes the default model when a default provider exists and reports its
// model as available; otherwise posts a result carrying `existing_state`.
void TryGetScoreFromDefaultModel(
std::unique_ptr<RequestState> request_state,
SegmentResultProvider::ResultState existing_state);
// Callback type used between the internal steps: hands back the request state
// together with the result produced by the step.
using ResultCallbackWithState =
base::OnceCallback<void(std::unique_ptr<RequestState>,
std::unique_ptr<SegmentResult>)>;
// Runs `callback` with the score stored in the database segment info, or with
// a failure state when that info is missing or its result is expired or
// unavailable.
void GetCachedModelScore(std::unique_ptr<RequestState> request_state,
ResultCallbackWithState callback);
// Requests execution of the model for `source` through the execution service.
// Runs `callback` directly with a failure state when the segment info for
// `source` is missing or the signal collection requirement is not met.
void ExecuteModelAndGetScore(std::unique_ptr<RequestState> request_state,
DefaultModelManager::SegmentSource source,
ResultCallbackWithState callback);
// Receives the execution result for `source`, builds the SegmentResult and,
// for non-default models with `save_results_to_db` set, saves the result to
// the database before continuing to RunCallback().
void OnModelExecuted(std::unique_ptr<RequestState> request_state,
DefaultModelManager::SegmentSource source,
ResultCallbackWithState callback,
std::unique_ptr<ModelExecutionResult> result);
// Posts the request's callback with `result` to `task_runner_`.
void PostResultCallback(std::unique_ptr<RequestState> request_state,
std::unique_ptr<SegmentResult> result);
// Records metrics for `segment_id` based on `success`, then runs `callback`
// with `request_state` and `segment_result`.
void RunCallback(SegmentId segment_id,
std::unique_ptr<RequestState> request_state,
std::unique_ptr<SegmentResult> segment_result,
ResultCallbackWithState callback,
bool success);
// Database queried for segment info and used to save execution results.
const raw_ptr<SegmentInfoDatabase, DanglingUntriaged> segment_database_;
// Consulted to check the signal collection requirement before execution.
const raw_ptr<SignalStorageConfig> signal_storage_config_;
// Source of segment infos and default model providers.
const raw_ptr<DefaultModelManager> default_model_manager_;
// Supplies server model providers and runs model execution requests.
const raw_ptr<ExecutionService, DanglingUntriaged> execution_service_;
// Time source used for result expiry checks and prediction timestamps.
const raw_ptr<base::Clock> clock_;
// When true, the signal collection requirement check is skipped.
const bool force_refresh_results_;
// Task runner captured at construction; the client callback is posted here.
scoped_refptr<base::SequencedTaskRunner> task_runner_;
// Provides the weak pointers bound into all asynchronous callbacks.
base::WeakPtrFactory<SegmentResultProviderImpl> weak_ptr_factory_{this};
};
void SegmentResultProviderImpl::GetSegmentResult(
std::unique_ptr<GetResultOptions> options) {
const SegmentId segment_id = options->segment_id;
default_model_manager_->GetAllSegmentInfoFromBothModels(
{segment_id}, segment_database_,
base::BindOnce(&SegmentResultProviderImpl::OnGetSegmentInfo,
weak_ptr_factory_.GetWeakPtr(), std::move(options)));
}
void SegmentResultProviderImpl::OnGetSegmentInfo(
std::unique_ptr<GetResultOptions> options,
DefaultModelManager::SegmentInfoList available_segments) {
const SegmentId segment_id = options->segment_id;
auto request_state = std::make_unique<RequestState>();
request_state->options = std::move(options);
request_state->available_segments.swap(available_segments);
request_state->model_providers[DefaultModelManager::SegmentSource::DATABASE] =
execution_service_ ? execution_service_->GetModelProvider(segment_id)
: nullptr;
// Default manager can be null in tests.
request_state
->model_providers[DefaultModelManager::SegmentSource::DEFAULT_MODEL] =
default_model_manager_
? default_model_manager_->GetDefaultProvider(segment_id)
: nullptr;
// If `ignore_db_scores` is true than the server model will be executed now,
// if that fails to give result, fallback to default model, hence default
// model is the `fallback_source_to_execute` if `ignore_db_score` is true. If
// `ignore_db_scores` is false than the score from database would be read, if
// that fails to read score from database, fallback to running server model,
// hence running server model is the `fallback_source_to_execute` if
// `ignore_db_score` is false.
DefaultModelManager::SegmentSource fallback_source_to_execute =
request_state->options->ignore_db_scores
? DefaultModelManager::SegmentSource::DEFAULT_MODEL
: DefaultModelManager::SegmentSource::DATABASE;
auto db_score_callback = base::BindOnce(
&SegmentResultProviderImpl::OnGotDatabaseModelScore,
weak_ptr_factory_.GetWeakPtr(), fallback_source_to_execute);
if (request_state->options->ignore_db_scores) {
VLOG(1) << __func__ << ": segment="
<< SegmentId_Name(request_state->options->segment_id)
<< " ignoring DB score, executing model.";
ExecuteModelAndGetScore(std::move(request_state),
DefaultModelManager::SegmentSource::DATABASE,
std::move(db_score_callback));
return;
}
GetCachedModelScore(std::move(request_state), std::move(db_score_callback));
}
// Handles the outcome of a database-side step (reading the stored score or
// executing the server model).
//
// - If `db_result` carries a rank, it is posted to the client as the final
//   result.
// - Otherwise, if `fallback_source_to_execute` is DATABASE, the server model is
//   executed and this method is re-entered with DEFAULT_MODEL as the next
//   fallback.
// - Otherwise the default model is tried, forwarding the failure state of
//   `db_result` so it can be reported if no default model is available.
//
// `db_result` is allowed to be null: it is treated as "no rank available" and
// its state is reported as ResultState::kUnknown.
void SegmentResultProviderImpl::OnGotDatabaseModelScore(
    DefaultModelManager::SegmentSource fallback_source_to_execute,
    std::unique_ptr<RequestState> request_state,
    std::unique_ptr<SegmentResult> db_result) {
  if (db_result && db_result->rank.has_value()) {
    PostResultCallback(std::move(request_state), std::move(db_result));
    return;
  }

  // If previously the `fallback_source_to_execute` was server model, that means
  // that the server model will be running this time, and if that fails to
  // provide the result, the fallback to this fallback will be running default
  // model. Hence in this case, `fallback_source_to_execute` is running default
  // model.
  if (fallback_source_to_execute ==
      DefaultModelManager::SegmentSource::DATABASE) {
    auto db_score_callback =
        base::BindOnce(&SegmentResultProviderImpl::OnGotDatabaseModelScore,
                       weak_ptr_factory_.GetWeakPtr(),
                       DefaultModelManager::SegmentSource::DEFAULT_MODEL);
    VLOG(1) << __func__ << ": segment="
            << SegmentId_Name(request_state->options->segment_id)
            << " failed to get score from database, executing server model.";
    ExecuteModelAndGetScore(std::move(request_state),
                            fallback_source_to_execute,
                            std::move(db_score_callback));
    return;
  }

  VLOG(1) << __func__
          << ": segment=" << SegmentId_Name(request_state->options->segment_id)
          << " failed to get database model score, trying default model.";
  // The rank check above tolerates a null `db_result`, so it must not be
  // dereferenced unconditionally here.
  const ResultState failure_state =
      db_result ? db_result->state : ResultState::kUnknown;
  TryGetScoreFromDefaultModel(std::move(request_state), failure_state);
}
void SegmentResultProviderImpl::TryGetScoreFromDefaultModel(
std::unique_ptr<RequestState> request_state,
SegmentResultProvider::ResultState existing_state) {
ModelProvider* default_model =
request_state
->model_providers[DefaultModelManager::SegmentSource::DEFAULT_MODEL];
if (!default_model || !default_model->ModelAvailable()) {
VLOG(1) << __func__ << ": segment="
<< SegmentId_Name(request_state->options->segment_id)
<< " default provider not available";
// Make sure the metrics record state of database model failure when client
// did not provide a default model.
PostResultCallback(std::move(request_state),
std::make_unique<SegmentResult>(existing_state));
return;
}
ExecuteModelAndGetScore(
std::move(request_state),
DefaultModelManager::SegmentSource::DEFAULT_MODEL,
base::BindOnce(&SegmentResultProviderImpl::PostResultCallback,
weak_ptr_factory_.GetWeakPtr()));
}
void SegmentResultProviderImpl::GetCachedModelScore(
std::unique_ptr<RequestState> request_state,
ResultCallbackWithState callback) {
const proto::SegmentInfo* db_segment_info =
FilterSegmentInfoBySource(request_state->available_segments,
DefaultModelManager::SegmentSource::DATABASE);
if (!db_segment_info) {
VLOG(1) << __func__ << ": segment="
<< SegmentId_Name(request_state->options->segment_id)
<< " does not have a segment info.";
std::move(callback).Run(
std::move(request_state),
std::make_unique<SegmentResult>(ResultState::kSegmentNotAvailable));
return;
}
if (metadata_utils::HasExpiredOrUnavailableResult(*db_segment_info,
clock_->Now())) {
VLOG(1) << __func__ << ": segment="
<< SegmentId_Name(request_state->options->segment_id)
<< " has expired or unavailable result.";
std::move(callback).Run(
std::move(request_state),
std::make_unique<SegmentResult>(ResultState::kDatabaseScoreNotReady));
return;
}
VLOG(1) << __func__ << ": Retrieved prediction from database: "
<< segmentation_platform::PredictionResultToDebugString(
db_segment_info->prediction_result())
<< " for segment "
<< proto::SegmentId_Name(request_state->options->segment_id);
float rank = ComputeDiscreteMapping(
request_state->options->discrete_mapping_key, *db_segment_info);
std::move(callback).Run(std::move(request_state),
std::make_unique<SegmentResult>(
ResultState::kSuccessFromDatabase,
db_segment_info->prediction_result(), rank));
}
void SegmentResultProviderImpl::ExecuteModelAndGetScore(
std::unique_ptr<RequestState> request_state,
DefaultModelManager::SegmentSource source,
ResultCallbackWithState callback) {
const proto::SegmentInfo* segment_info =
FilterSegmentInfoBySource(request_state->available_segments, source);
if (!segment_info) {
VLOG(1) << __func__ << ": segment="
<< SegmentId_Name(request_state->options->segment_id)
<< " default segment info not available";
auto state = source == DefaultModelManager::SegmentSource::DATABASE
? ResultState::kSegmentNotAvailable
: ResultState::kDefaultModelMetadataMissing;
std::move(callback).Run(std::move(request_state),
std::make_unique<SegmentResult>(state));
return;
}
DCHECK_EQ(metadata_utils::ValidationResult::kValidationSuccess,
metadata_utils::ValidateMetadata(segment_info->model_metadata()));
if (!force_refresh_results_ &&
!signal_storage_config_->MeetsSignalCollectionRequirement(
segment_info->model_metadata())) {
VLOG(1) << __func__ << ": segment="
<< SegmentId_Name(request_state->options->segment_id)
<< " signal collection not met";
auto state = source == DefaultModelManager::SegmentSource::DATABASE
? ResultState::kSignalsNotCollected
: ResultState::kDefaultModelSignalNotCollected;
std::move(callback).Run(std::move(request_state),
std::make_unique<SegmentResult>(state));
return;
}
ModelProvider* provider = request_state->model_providers[source];
auto request = std::make_unique<ExecutionRequest>();
// The pointer is kept alive by the `request_state`.
request->segment_info = segment_info;
request->record_metrics_for_default =
source == DefaultModelManager::SegmentSource::DEFAULT_MODEL;
request->input_context = request_state->options->input_context;
request->callback =
base::BindOnce(&SegmentResultProviderImpl::OnModelExecuted,
weak_ptr_factory_.GetWeakPtr(), std::move(request_state),
source, std::move(callback));
request->model_provider = provider;
execution_service_->RequestModelExecution(std::move(request));
}
void SegmentResultProviderImpl::OnModelExecuted(
std::unique_ptr<RequestState> request_state,
DefaultModelManager::SegmentSource source,
ResultCallbackWithState callback,
std::unique_ptr<ModelExecutionResult> result) {
auto* segment_info =
FilterSegmentInfoBySource(request_state->available_segments, source);
ResultState state = ResultState::kUnknown;
proto::PredictionResult prediction_result;
std::unique_ptr<SegmentResult> segment_result;
bool success = result->status == ModelExecutionStatus::kSuccess;
bool is_default_model =
source == DefaultModelManager::SegmentSource::DEFAULT_MODEL;
if (success) {
state = is_default_model ? ResultState::kDefaultModelScoreUsed
: ResultState::kTfliteModelScoreUsed;
prediction_result = metadata_utils::CreatePredictionResult(
result->scores, segment_info->model_metadata().output_config(),
clock_->Now());
segment_info->mutable_prediction_result()->CopyFrom(prediction_result);
float rank = ComputeDiscreteMapping(
request_state->options->discrete_mapping_key, *segment_info);
segment_result =
std::make_unique<SegmentResult>(state, prediction_result, rank);
VLOG(1) << __func__ << ": " << (is_default_model ? "Default" : "Server")
<< " model executed successfully. Result: "
<< segmentation_platform::PredictionResultToDebugString(
prediction_result)
<< " for segment "
<< proto::SegmentId_Name(request_state->options->segment_id);
} else {
state = is_default_model ? ResultState::kDefaultModelExecutionFailed
: ResultState::kTfliteModelExecutionFailed;
segment_result = std::make_unique<SegmentResult>(state);
VLOG(1) << __func__ << ": " << (is_default_model ? "Default" : "Server")
<< " model execution failed"
<< " for segment "
<< proto::SegmentId_Name(request_state->options->segment_id);
}
if (!is_default_model && request_state->options->save_results_to_db) {
// Saving results to database.
segment_database_->SaveSegmentResult(
segment_info->segment_id(),
success ? absl::make_optional(prediction_result) : absl::nullopt,
base::BindOnce(&SegmentResultProviderImpl::RunCallback,
weak_ptr_factory_.GetWeakPtr(),
segment_info->segment_id(), std::move(request_state),
std::move(segment_result), std::move(callback)));
return;
}
RunCallback(segment_info->segment_id(), std::move(request_state),
std::move(segment_result), std::move(callback), success);
}
// Delivers the final `result` to the client by posting the callback stored in
// the request options to `task_runner_`. `request_state` is released when this
// method returns.
void SegmentResultProviderImpl::PostResultCallback(
    std::unique_ptr<RequestState> request_state,
    std::unique_ptr<SegmentResult> result) {
  auto client_callback = std::move(request_state->options->callback);
  task_runner_->PostTask(
      FROM_HERE,
      base::BindOnce(std::move(client_callback), std::move(result)));
}
// Records the save-result metric for `segment_id` using `success`, records an
// additional kFailedToSaveResultAfterSuccess execution status when `success`
// is false, and then hands `request_state` and `segment_result` to `callback`.
void SegmentResultProviderImpl::RunCallback(
    SegmentId segment_id,
    std::unique_ptr<RequestState> request_state,
    std::unique_ptr<SegmentResult> segment_result,
    ResultCallbackWithState callback,
    bool success) {
  stats::RecordModelExecutionSaveResult(segment_id, success);
  if (success) {
    std::move(callback).Run(std::move(request_state),
                            std::move(segment_result));
    return;
  }

  // TODO(ssid): Consider removing this enum, this is the only case where the
  // execution status is recorded twice for the same execution request.
  stats::RecordModelExecutionStatus(
      segment_id,
      /*default_provider=*/false,
      ModelExecutionStatus::kFailedToSaveResultAfterSuccess);
  std::move(callback).Run(std::move(request_state), std::move(segment_result));
}
} // namespace
// Builds a result that carries only `state`; the other members keep their
// default values. Used throughout this file for failure outcomes.
SegmentResultProvider::SegmentResult::SegmentResult(ResultState state)
: state(state) {}
// Builds a result that carries `state` together with a copy of
// `prediction_result` and the `rank` computed for it.
SegmentResultProvider::SegmentResult::SegmentResult(
ResultState state,
const proto::PredictionResult& prediction_result,
float rank)
: state(state), result(prediction_result), rank(rank) {}
// Defaulted special members, defined out of line in this file.
SegmentResultProvider::SegmentResult::~SegmentResult() = default;
SegmentResultProvider::GetResultOptions::GetResultOptions() = default;
SegmentResultProvider::GetResultOptions::~GetResultOptions() = default;
// static
// Factory for the file-local SegmentResultProviderImpl. All arguments are
// forwarded unchanged to its constructor.
std::unique_ptr<SegmentResultProvider> SegmentResultProvider::Create(
    SegmentInfoDatabase* segment_database,
    SignalStorageConfig* signal_storage_config,
    DefaultModelManager* default_model_manager,
    ExecutionService* execution_service,
    base::Clock* clock,
    bool force_refresh_results) {
  auto provider = std::make_unique<SegmentResultProviderImpl>(
      segment_database, signal_storage_config, default_model_manager,
      execution_service, clock, force_refresh_results);
  return provider;
}
} // namespace segmentation_platform