blob: 57debf6a6f6d54e3f4cd5b0677f18cbd99eb1941 [file] [log] [blame]
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SELECTION_SEGMENT_RESULT_PROVIDER_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SELECTION_SEGMENT_RESULT_PROVIDER_H_
#include "base/functional/callback.h"
#include "base/memory/scoped_refptr.h"
#include "components/segmentation_platform/internal/database/segment_info_database.h"
#include "components/segmentation_platform/internal/execution/execution_request.h"
#include "components/segmentation_platform/public/input_context.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
namespace base {
class Clock;
}
namespace segmentation_platform {
class DefaultModelManager;
class ExecutionService;
class SignalStorageConfig;
// Used for retrieving the result of a particular model.
// The steps to get result for the model are as follows:
// 1. Returns score from database as result if present and valid. Do this step
// only if `ignore_db_score = false` else jump to step 2.
// 2. If there is no valid score present in database or `ignore_db_scores =
// true`, run the server model. If a valid score is computed return it as
// result and save it to database.
// 3. If there is no valid score returned, execute and get score from default
// model and return that as result.
// When the database will have support for storing default models, before
// running default models score would be checked for them database. This isn't
// implemented yet.
class SegmentResultProvider {
public:
SegmentResultProvider() = default;
virtual ~SegmentResultProvider() = default;
enum class ResultState {
kUnknown = 0,
kSuccessFromDatabase = 1,
kSegmentNotAvailable = 2,
kSignalsNotCollected = 3,
kDatabaseScoreNotReady = 4,
kDefaultModelSignalNotCollected = 5,
kDefaultModelMetadataMissing = 6,
kDefaultModelExecutionFailed = 7,
kDefaultModelScoreUsed = 8,
kTfliteModelExecutionFailed = 9,
kTfliteModelScoreUsed = 10,
};
struct SegmentResult {
explicit SegmentResult(ResultState state);
SegmentResult(ResultState state,
const proto::PredictionResult& result,
float rank);
~SegmentResult();
SegmentResult(const SegmentResult&) = delete;
SegmentResult& operator=(const SegmentResult&) = delete;
ResultState state = ResultState::kUnknown;
// Contains the raw scores along with output config.
proto::PredictionResult result;
// TODO(shaktisahu): Delete this after full migration.
absl::optional<float> rank;
};
using SegmentResultCallback =
base::OnceCallback<void(std::unique_ptr<SegmentResult>)>;
// Creates the instance.
static std::unique_ptr<SegmentResultProvider> Create(
SegmentInfoDatabase* segment_info_database,
SignalStorageConfig* signal_storage_config,
DefaultModelManager* default_model_manager,
ExecutionService* execution_service,
base::Clock* clock,
bool force_refresh_results);
// Options for `GetSegmentResult()`.
struct GetResultOptions {
GetResultOptions();
~GetResultOptions();
// The segment ID to fetch result for.
SegmentId segment_id = SegmentId::OPTIMIZATION_TARGET_UNKNOWN;
// The key is needed for computing segment rank from discrete mapping.
std::string discrete_mapping_key;
// Ignores model results stored in database and executes them to fetch
// results. When set to false, the result could be from following:
// * Score cached in the database
// * When score is missing, executes Database model if exists.
// * If database model fails to execute, executes default model.
// When set to true, the result could be from following:
// * Execution of TFLite model.
// * Fallback to default when model is missing.
bool ignore_db_scores = false;
// If `ignore_db_scores` is true and TFLite model is available, then write
// the results to database. Used when user wants to rerun the database
// model.
// TODO(ssid): `callback` is sometimes not called if this field is set to
// true. Fix execution scheduler to run callback always even if save to
// database is true.
// TODO(ssid): Consider moving this option out as a different SaveRequest
// method in this class.
bool save_results_to_db = false;
// Callback to return the segment result.
SegmentResultCallback callback;
// Current context of the browser that is needed by feature processor for
// some of the models.
scoped_refptr<InputContext> input_context;
};
// Returns latest available score for the segment.
virtual void GetSegmentResult(std::unique_ptr<GetResultOptions> options) = 0;
};
} // namespace segmentation_platform
#endif // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SELECTION_SEGMENT_RESULT_PROVIDER_H_