blob: 2796da932f7e1c2d369841a75fda5cd9b6bae1b8 [file] [log] [blame]
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/selection/request_dispatcher.h"
#include <set>
#include <utility>
#include "base/containers/circular_deque.h"
#include "base/functional/callback_forward.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/memory/scoped_refptr.h"
#include "base/task/single_thread_task_runner.h"
#include "base/time/time.h"
#include "components/segmentation_platform/internal/database/config_holder.h"
#include "components/segmentation_platform/internal/post_processor/post_processor.h"
#include "components/segmentation_platform/internal/selection/request_handler.h"
#include "components/segmentation_platform/internal/selection/segment_result_provider.h"
#include "components/segmentation_platform/internal/stats.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/prediction_options.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#include "components/segmentation_platform/public/result.h"
namespace segmentation_platform {
namespace {

// Amount of time to wait for model initialization. During this period requests
// for uninitialized models will be enqueued and processed either when the model
// is ready or when this timeout expires. Time is 200ms to cover 80% of cases
// (According to
// OptimizationGuide.ModelHandler.HandlerCreatedToModelAvailable histogram).
constexpr int kModelInitializationTimeoutMs = 200;

// Converts `raw_result` into the client-facing ClassificationResult by running
// the configured post-processing, and propagates the request id.
void PostProcess(const RawResult& raw_result, ClassificationResult& result) {
  // `raw_result` is const, so the result proto is necessarily copied here; the
  // previous `std::move` on the const member was a copy anyway (clang-tidy
  // performance-move-const-arg) and only obscured that fact.
  result = PostProcessor().GetPostProcessedClassificationResult(
      raw_result.result, raw_result.status);
  result.request_id = raw_result.request_id;
}

// Numeric results need no extra post-processing; copy the raw result through.
void PostProcess(const RawResult& raw_result, AnnotatedNumericResult& result) {
  result = raw_result;
}

}  // namespace
RequestDispatcher::RequestDispatcher(StorageService* storage_service)
    : storage_service_(storage_service) {
  // Individual models must be loaded from disk or fetched from network. Fill a
  // list to keep track of which ones are still pending.
  uninitialized_segmentation_keys_ =
      storage_service_->config_holder()->non_legacy_segmentation_keys();
}

RequestDispatcher::~RequestDispatcher() = default;
void RequestDispatcher::OnPlatformInitialized(
bool success,
ExecutionService* execution_service,
std::map<std::string, std::unique_ptr<SegmentResultProvider>>
result_providers) {
storage_init_status_ = success;
// Only set request handlers if it has not been set for testing already.
if (request_handlers_.empty()) {
for (const auto& config : storage_service_->config_holder()->configs()) {
request_handlers_[config->segmentation_key] = RequestHandler::Create(
*config, std::move(result_providers[config->segmentation_key]),
execution_service);
}
}
// Set a timeout to execute all pending requests even if their models didn't
// initialize after |kModelInitializationTimeoutMs|. This is to avoid waiting
// for long periods of time when models need to be downloaded, and to avoid
// requests waiting forever when there's no model.
base::SingleThreadTaskRunner::GetCurrentDefault()->PostDelayedTask(
FROM_HERE,
base::BindOnce(&RequestDispatcher::OnModelInitializationTimeout,
weak_ptr_factory_.GetWeakPtr()),
base::Milliseconds(kModelInitializationTimeoutMs));
}
void RequestDispatcher::ExecuteAllPendingActions() {
  // Drain one key's queue per iteration. ExecutePendingActionsForKey() erases
  // the map entry it processes, so the map shrinks each time and the loop
  // terminates. The key is copied into a local so it outlives the erase.
  while (!pending_actions_.empty()) {
    const std::string next_key = pending_actions_.begin()->first;
    ExecutePendingActionsForKey(next_key);
  }
}
void RequestDispatcher::ExecutePendingActionsForKey(
    const std::string& segmentation_key) {
  // Flushes every request queued for `segmentation_key` by posting each stored
  // callback to the current task runner, then drops the (now empty) queue.
  auto pending_actions_for_key = pending_actions_.find(segmentation_key);
  if (pending_actions_for_key == pending_actions_.end()) {
    return;
  }
  auto& queue = pending_actions_for_key->second;
  while (!queue.empty()) {
    auto callback = std::move(queue.front());
    queue.pop_front();
    // Post instead of running synchronously so callers never re-enter the
    // dispatcher from inside this loop.
    base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE, std::move(callback));
  }
  // Erase via the iterator: avoids a second map lookup and is safe even if
  // `segmentation_key` aliases the stored map key.
  pending_actions_.erase(pending_actions_for_key);
}
void RequestDispatcher::OnModelUpdated(proto::SegmentId segment_id) {
  // Map the segment back to its segmentation key; segments that do not belong
  // to any known (non-legacy) config are ignored.
  const auto maybe_key =
      storage_service_->config_holder()->GetKeyForSegmentId(segment_id);
  if (!maybe_key) {
    return;
  }
  // The model for this key is now available: stop treating it as pending and
  // flush any requests that queued up while waiting for it.
  uninitialized_segmentation_keys_.erase(*maybe_key);
  ExecutePendingActionsForKey(*maybe_key);
}
// Called |kModelInitializationTimeoutMs| after platform initialization. The
// grace period for model initialization is over: stop treating any model as
// pending and flush every queued request, letting each succeed or fail
// against whatever state the platform reached.
void RequestDispatcher::OnModelInitializationTimeout() {
  uninitialized_segmentation_keys_.clear();
  ExecuteAllPendingActions();
}
// Shared completion path for the typed result getters: records request
// latency, caches fresh on-demand results in prefs, post-processes the raw
// model output into `ResultType`, and finally invokes `callback`.
// `is_cached_result` is true when `raw_result` came from prefs rather than a
// model execution.
template <typename ResultType>
void RequestDispatcher::CallbackWrapper(
    const std::string& segmentation_key,
    base::Time start_time,
    base::OnceCallback<void(const ResultType&)> callback,
    bool is_cached_result,
    const RawResult& raw_result) {
  Config* config =
      storage_service_->config_holder()->GetConfigForSegmentationKey(
          segmentation_key);
  CHECK(config);
  stats::RecordClassificationRequestTotalDuration(
      *config, base::Time::Now() - start_time);
  if (!is_cached_result && raw_result.status == PredictionStatus::kSucceeded) {
    // Verify if this does not accidentally overwrite results for cached
    // segments.
    // TODO(ssid): Remove this check in the future if current system looks good.
    CHECK(!config->auto_execute_and_cache)
        << "Overwriting results without checking TTL "
        << config->segmentation_key;
    // Cache model execution results in prefs in case they are useful to fetch
    // results early startup without database, or to record field trials for the
    // session based on ondemand executions.
    storage_service_->cached_result_writer()->CacheModelExecution(
        config, raw_result.result);
  }
  ResultType result(PredictionStatus::kFailed);
  // `raw_result` is a const reference, so the previous std::move here was a
  // no-op cast that still copied (clang-tidy performance-move-const-arg).
  PostProcess(raw_result, result);
  VLOG(1) << "Computed result for " << segmentation_key << ": "
          << result.ToDebugString();
  std::move(callback).Run(result);
}
// Routes a result request: legacy keys fail fast, non-ondemand requests are
// answered from prefs, and on-demand requests run the model — queueing the
// request until the platform and the key's model are initialized.
// `callback` always runs asynchronously (posted, never invoked inline).
void RequestDispatcher::GetModelResult(
    const std::string& segmentation_key,
    const PredictionOptions& options,
    scoped_refptr<InputContext> input_context,
    WrappedCallback callback) {
  if (storage_service_->config_holder()->IsLegacySegmentationKey(
          segmentation_key)) {
    LOG(ERROR)
        << "Segmentation key: " << segmentation_key
        << " is using a legacy config with the new API which is not "
           "supported. Legacy segments should migrate to the new config.";
    base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE,
        base::BindOnce(std::move(callback), /*is_cached_result=*/false,
                       RawResult(PredictionStatus::kFailed)));
    return;
  }
  Config* config =
      storage_service_->config_holder()->GetConfigForSegmentationKey(
          segmentation_key);
  CHECK(config);
  if (!options.on_demand_execution) {
    // Returns result directly from prefs for non-ondemand models.
    auto pred_result = storage_service_->cached_result_provider()
                           ->GetPredictionResultForClient(segmentation_key);
    RawResult raw_result(PredictionStatus::kFailed);
    if (pred_result) {
      raw_result = PostProcessor().GetRawResult(*pred_result,
                                                PredictionStatus::kSucceeded);
      storage_service_->cached_result_writer()->MarkResultAsUsed(config);
      stats::RecordSegmentSelectionFailure(
          *config, stats::SegmentationSelectionFailureReason::
                       kClassificationResultFromPrefs);
    } else {
      stats::RecordSegmentSelectionFailure(
          *config, stats::SegmentationSelectionFailureReason::
                       kClassificationResultNotAvailableInPrefs);
    }
    base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE,
        base::BindOnce(std::move(callback), /*is_cached_result=*/true,
                       std::move(raw_result)));
    return;
  }
  // TODO(ssid): Support cached results for all APIs.
  DCHECK(options.on_demand_execution);
  // For on-demand results, we need to run the models for which we need DB
  // initialization to be complete. Hence cache the request if platform
  // initialization isn't completed yet.
  if (!storage_init_status_.has_value() ||
      uninitialized_segmentation_keys_.contains(segmentation_key)) {
    // If the platform isn't fully initialized, cache the input arguments to
    // run later.
    pending_actions_[segmentation_key].push_back(
        base::BindOnce(&RequestDispatcher::GetModelResult,
                       weak_ptr_factory_.GetWeakPtr(), segmentation_key,
                       options, std::move(input_context), std::move(callback)));
    return;
  }
  // If the platform initialization failed, invoke callback to return invalid
  // results.
  if (!storage_init_status_.value()) {
    stats::RecordSegmentSelectionFailure(
        *config, stats::SegmentationSelectionFailureReason::kDBInitFailure);
    base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
        FROM_HERE,
        base::BindOnce(std::move(callback), /*is_cached_result=*/false,
                       RawResult(PredictionStatus::kFailed)));
    return;
  }
  auto iter = request_handlers_.find(segmentation_key);
  CHECK(iter != request_handlers_.end());
  auto final_callback =
      base::BindOnce(std::move(callback), /*is_cached_result=*/false);
  // Last use of `input_context`: move it to avoid an extra scoped_refptr
  // ref-count increment/decrement pair.
  iter->second->GetPredictionResult(options, std::move(input_context),
                                    std::move(final_callback));
}
// Requests a classification result for `segmentation_key`. Wraps `callback`
// so the raw model output is post-processed into a ClassificationResult (and
// latency/caching bookkeeping runs) before the caller sees it.
void RequestDispatcher::GetClassificationResult(
    const std::string& segmentation_key,
    const PredictionOptions& options,
    scoped_refptr<InputContext> input_context,
    ClassificationResultCallback callback) {
  auto wrapped_callback =
      base::BindOnce(&RequestDispatcher::CallbackWrapper<ClassificationResult>,
                     weak_ptr_factory_.GetWeakPtr(), segmentation_key,
                     base::Time::Now(), std::move(callback));
  // Last use of `input_context`: move it to skip a scoped_refptr ref-count
  // increment.
  GetModelResult(segmentation_key, options, std::move(input_context),
                 std::move(wrapped_callback));
}
// Requests an annotated numeric result for `segmentation_key`. Wraps
// `callback` so latency/caching bookkeeping runs and the raw result is
// converted before the caller sees it.
void RequestDispatcher::GetAnnotatedNumericResult(
    const std::string& segmentation_key,
    const PredictionOptions& options,
    scoped_refptr<InputContext> input_context,
    AnnotatedNumericResultCallback callback) {
  auto wrapped_callback = base::BindOnce(
      &RequestDispatcher::CallbackWrapper<AnnotatedNumericResult>,
      weak_ptr_factory_.GetWeakPtr(), segmentation_key, base::Time::Now(),
      std::move(callback));
  // Last use of `input_context`: move it to skip a scoped_refptr ref-count
  // increment.
  GetModelResult(segmentation_key, options, std::move(input_context),
                 std::move(wrapped_callback));
}
int RequestDispatcher::GetPendingActionCountForTesting() {
  // Sums the queued request callbacks across all segmentation keys.
  int pending = 0;
  for (const auto& [key, queue] : pending_actions_) {
    pending += queue.size();
  }
  return pending;
}
} // namespace segmentation_platform