blob: 77a1aeaf8391b04b73767a936d38f52449a5749c [file] [log] [blame]
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/selection/request_handler.h"
#include "base/memory/raw_ref.h"
#include "base/memory/scoped_refptr.h"
#include "base/task/single_thread_task_runner.h"
#include "base/time/clock.h"
#include "components/segmentation_platform/internal/post_processor/post_processor.h"
#include "components/segmentation_platform/internal/selection/segment_result_provider.h"
#include "components/segmentation_platform/internal/stats.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/input_context.h"
#include "components/segmentation_platform/public/prediction_options.h"
#include "components/segmentation_platform/public/proto/prediction_result.pb.h"
#include "components/segmentation_platform/public/result.h"
#include "components/segmentation_platform/public/trigger.h"
namespace segmentation_platform {
namespace {
// Translates the internal provider outcome into the public `PredictionStatus`.
// Any state that produced a score (from the database, the default model or the
// TFLite model) counts as success, missing signals mean "not ready yet", and
// every other state is reported as a failure.
PredictionStatus ResultStateToPredictionStatus(
    SegmentResultProvider::ResultState result_state) {
  using ResultState = SegmentResultProvider::ResultState;
  const bool produced_score =
      result_state == ResultState::kSuccessFromDatabase ||
      result_state == ResultState::kDefaultModelScoreUsed ||
      result_state == ResultState::kTfliteModelScoreUsed;
  if (produced_score) {
    return PredictionStatus::kSucceeded;
  }
  if (result_state == ResultState::kSignalsNotCollected) {
    return PredictionStatus::kNotReady;
  }
  return PredictionStatus::kFailed;
}
// Default `RequestHandler` implementation. It serves on-demand requests for a
// config that is expected to contain exactly one segment (see the DCHECK in
// `GetModelResult()`). Model evaluation is delegated to a
// `SegmentResultProvider`, and decisions are optionally reported to the
// execution service's training data collector.
class RequestHandlerImpl : public RequestHandler {
public:
// `config` is held by reference (`raw_ref`) and must outlive this object.
// `execution_service` is not owned and may be null (e.g. in tests).
RequestHandlerImpl(const Config& config,
std::unique_ptr<SegmentResultProvider> result_provider,
ExecutionService* execution_service);
~RequestHandlerImpl() override;
// Disallow copy/assign.
RequestHandlerImpl(const RequestHandlerImpl&) = delete;
RequestHandlerImpl& operator=(const RequestHandlerImpl&) = delete;
// Client API. See `SegmentationPlatformService::GetClassificationResult`.
void GetClassificationResult(const PredictionOptions& options,
scoped_refptr<InputContext> input_context,
ClassificationResultCallback callback) override;
// Client API. Presumably mirrors
// `SegmentationPlatformService::GetAnnotatedNumericResult` (TODO: confirm).
void GetAnnotatedNumericResult(
const PredictionOptions& options,
scoped_refptr<InputContext> input_context,
AnnotatedNumericResultCallback callback) override;
private:
// Builds a `GetResultOptions` request for the config's single segment and
// forwards it to `result_provider_`. `callback` receives the provider's
// result, which may be null.
void GetModelResult(const PredictionOptions& options,
scoped_refptr<InputContext> input_context,
SegmentResultProvider::SegmentResultCallback callback);
// Converts `result` (possibly null) into a `ClassificationResult`, records
// stats, and posts `classification_callback` on the current default task
// runner.
void OnGetModelResultForClassification(
scoped_refptr<InputContext> input_context,
ClassificationResultCallback classification_callback,
std::unique_ptr<SegmentResultProvider::SegmentResult> result);
// Converts `result` (possibly null) into an `AnnotatedNumericResult` and
// posts `callback` on the current default task runner.
void OnGetAnnotatedNumericResult(
scoped_refptr<InputContext> input_context,
AnnotatedNumericResultCallback callback,
std::unique_ptr<SegmentResultProvider::SegmentResult> result);
// Reports an on-demand decision for the config's segment to the training
// data collector. Returns a default-constructed `TrainingRequestId` when no
// execution service or collector is available.
TrainingRequestId CollectTrainingData(
scoped_refptr<InputContext> input_context);
// The config for providing client config params.
const raw_ref<const Config> config_;
// The result provider responsible for getting the result, either by running
// the model or getting results from the cache as necessary.
std::unique_ptr<SegmentResultProvider> result_provider_;
// Pointer to the execution service. Not owned; may be null.
const raw_ptr<ExecutionService> execution_service_{};
// Hands out weak pointers for the result callbacks, so pending callbacks are
// dropped if this handler is destroyed first. Kept as the last member so it
// is destroyed (invalidating its weak pointers) before the other members.
base::WeakPtrFactory<RequestHandlerImpl> weak_ptr_factory_{this};
};
RequestHandlerImpl::RequestHandlerImpl(
const Config& config,
std::unique_ptr<SegmentResultProvider> result_provider,
ExecutionService* execution_service)
: config_(config),
result_provider_(std::move(result_provider)),
execution_service_(execution_service) {}
RequestHandlerImpl::~RequestHandlerImpl() = default;
// Runs the client's single model on demand and reports the classification
// outcome through `callback`. Only on-demand execution is supported here.
void RequestHandlerImpl::GetClassificationResult(
    const PredictionOptions& options,
    scoped_refptr<InputContext> input_context,
    ClassificationResultCallback callback) {
  DCHECK(options.on_demand_execution);
  // Bind the completion handler first (it keeps its own reference to the
  // input context), then hand `input_context` itself to the model request.
  auto on_model_result =
      base::BindOnce(&RequestHandlerImpl::OnGetModelResultForClassification,
                     weak_ptr_factory_.GetWeakPtr(), input_context,
                     std::move(callback));
  GetModelResult(options, std::move(input_context),
                 std::move(on_model_result));
}
// Runs the client's single model on demand and reports the raw, annotated
// scores through `callback`. Only on-demand execution is supported here.
void RequestHandlerImpl::GetAnnotatedNumericResult(
    const PredictionOptions& options,
    scoped_refptr<InputContext> input_context,
    AnnotatedNumericResultCallback callback) {
  DCHECK(options.on_demand_execution);
  // The completion handler holds its own reference to the input context, so
  // the original reference can be moved into the model request afterwards.
  auto on_model_result =
      base::BindOnce(&RequestHandlerImpl::OnGetAnnotatedNumericResult,
                     weak_ptr_factory_.GetWeakPtr(), input_context,
                     std::move(callback));
  GetModelResult(options, std::move(input_context),
                 std::move(on_model_result));
}
// Asks `result_provider_` to produce a result for the config's segment.
// `callback` is forwarded unchanged and receives the provider's result.
void RequestHandlerImpl::GetModelResult(
    const PredictionOptions& options,
    scoped_refptr<InputContext> input_context,
    SegmentResultProvider::SegmentResultCallback callback) {
  // Clients served by this handler are assumed to own exactly one model, so
  // the first segment in the config is the one evaluated.
  DCHECK_EQ(config_->segments.size(), 1u);
  auto request = std::make_unique<SegmentResultProvider::GetResultOptions>();
  request->segment_id = config_->segments.begin()->first;
  // Skip stored database scores whenever on-demand execution was requested.
  request->ignore_db_scores = options.on_demand_execution;
  request->input_context = std::move(input_context);
  request->callback = std::move(callback);
  result_provider_->GetSegmentResult(std::move(request));
}
// Completion handler for `GetClassificationResult()`.
//
// Records selection/classification stats, post-processes the raw prediction
// into a `ClassificationResult`, attaches the training request ID when the
// collector issued a non-null one, and posts `classification_callback` on the
// current default task runner (so the client is never called re-entrantly).
// A null `result` is reported as `PredictionStatus::kFailed`.
void RequestHandlerImpl::OnGetModelResultForClassification(
    scoped_refptr<InputContext> input_context,
    ClassificationResultCallback classification_callback,
    std::unique_ptr<SegmentResultProvider::SegmentResult> result) {
  PredictionStatus status = PredictionStatus::kFailed;
  proto::PredictionResult pred_result;
  absl::optional<TrainingRequestId> request_id;
  if (result) {
    // `GetSuccessOrFailureReason()` maps the state to a reason that can
    // presumably also denote success, which is why this is recorded on the
    // success path too.
    stats::RecordSegmentSelectionFailure(
        *config_, stats::GetSuccessOrFailureReason(result->state));
    status = ResultStateToPredictionStatus(result->state);
    // `result` is owned here and not used afterwards, so move the proto out
    // instead of copying it.
    pred_result = std::move(result->result);
    stats::RecordClassificationResultComputed(*config_, pred_result);
    request_id = CollectTrainingData(input_context);
  } else {
    stats::RecordSegmentSelectionFailure(
        *config_, stats::SegmentationSelectionFailureReason::
                      kOnDemandModelExecutionFailed);
  }
  ClassificationResult classification_result =
      PostProcessor().GetPostProcessedClassificationResult(pred_result, status);
  if (request_id && !request_id->is_null()) {
    classification_result.request_id = *request_id;
  }
  // Move the result (including its label list) into the task instead of
  // copying it.
  base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, base::BindOnce(std::move(classification_callback),
                                std::move(classification_result)));
}
// Completion handler for `GetAnnotatedNumericResult()`.
//
// Post-processes the raw prediction into an `AnnotatedNumericResult`,
// attaches the training request ID when the collector issued a non-null one,
// and posts `callback` on the current default task runner. A null
// `segment_result` is reported as `PredictionStatus::kFailed`.
void RequestHandlerImpl::OnGetAnnotatedNumericResult(
    scoped_refptr<InputContext> input_context,
    AnnotatedNumericResultCallback callback,
    std::unique_ptr<SegmentResultProvider::SegmentResult> segment_result) {
  AnnotatedNumericResult result(PredictionStatus::kFailed);
  if (segment_result) {
    const PredictionStatus status =
        ResultStateToPredictionStatus(segment_result->state);
    result = PostProcessor().GetAnnotatedNumericResult(segment_result->result,
                                                       status);
    // The ID returned by the collector used to be computed and then dropped.
    // Attach it exactly as the classification path does. A null ID (no
    // collector available) leaves the result's default value in place.
    const TrainingRequestId request_id = CollectTrainingData(input_context);
    if (!request_id.is_null()) {
      result.request_id = request_id;
    }
  }
  base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, base::BindOnce(std::move(callback), std::move(result)));
}
// Tells the training data collector that an on-demand decision was made for
// the config's segment and returns the ID it assigns to that decision.
TrainingRequestId RequestHandlerImpl::CollectTrainingData(
    scoped_refptr<InputContext> input_context) {
  // Tests may run without an execution service or without a collector, so
  // only report the decision when both are present.
  const bool has_collector =
      execution_service_ && execution_service_->training_data_collector();
  if (has_collector) {
    return execution_service_->training_data_collector()->OnDecisionTime(
        config_->segments.begin()->first, input_context,
        proto::TrainingOutputs::TriggerConfig::ONDEMAND);
  }
  // Nothing was collected: hand back a default-constructed ID.
  return TrainingRequestId();
}
} // namespace
// static
// Factory for the default handler. The concrete `RequestHandlerImpl` is private
// to this file, so callers only ever see the `RequestHandler` interface.
std::unique_ptr<RequestHandler> RequestHandler::Create(
    const Config& config,
    std::unique_ptr<SegmentResultProvider> result_provider,
    ExecutionService* execution_service) {
  std::unique_ptr<RequestHandler> handler =
      std::make_unique<RequestHandlerImpl>(config, std::move(result_provider),
                                           execution_service);
  return handler;
}
} // namespace segmentation_platform