blob: faebcc266f90f5beb5d1d9eb07dc8b064d4c700f [file] [log] [blame]
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SELECTION_SEGMENT_RESULT_PROVIDER_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SELECTION_SEGMENT_RESULT_PROVIDER_H_
#include <memory>
#include <string>

#include "base/callback.h"
#include "base/memory/scoped_refptr.h"
#include "components/segmentation_platform/internal/database/segment_info_database.h"
#include "components/segmentation_platform/public/input_context.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
// Forward declaration: this header only refers to base::Clock through a
// pointer parameter, so its full definition is not required here.
namespace base {
class Clock;
}  // namespace base
namespace segmentation_platform {
// Forward declarations: these types only appear as pointer parameters of
// SegmentResultProvider::Create(), so their full definitions are not needed in
// this header.
class DefaultModelManager;
class ExecutionService;
class SignalStorageConfig;
// Retrieves the result of a single segmentation model.
//
// Supports 3 use cases:
// 1. Fetching cached and valid results from the segment database.
// 2. Falling back to the default model when cached results are missing: the
//    default model is executed and its result is provided.
// 3. Executing the TFLite model and providing its result when
//    `GetResultOptions::ignore_db_scores` is set.
//
// This is an abstract interface; instances are obtained through Create().
class SegmentResultProvider {
 public:
  // Outcome of a GetSegmentResult() request, reported through
  // `SegmentResult::state`.
  // NOTE(review): the values are explicitly numbered, which suggests they may
  // be recorded to metrics or persisted. If so, never renumber or reuse them;
  // confirm before editing.
  enum class ResultState {
    kUnknown = 0,
    kSuccessFromDatabase = 1,
    kSegmentNotAvailable = 2,
    kSignalsNotCollected = 3,
    kDatabaseScoreNotReady = 4,
    kDefaultModelSignalNotCollected = 5,
    kDefaultModelMetadataMissing = 6,
    kDefaultModelExecutionFailed = 7,
    kDefaultModelScoreUsed = 8,
    kTfliteModelExecutionFailed = 9,
    kTfliteModelScoreUsed = 10,
  };

  // Result handed to a `SegmentResultCallback`. It is passed around by
  // std::unique_ptr and cannot be copied.
  struct SegmentResult {
    // Creates a result with only `state` supplied; `rank` is presumably left
    // empty. Confirm in the implementation.
    explicit SegmentResult(ResultState state);
    // Creates a result with the given `state` and `rank`.
    SegmentResult(ResultState state, int rank);
    ~SegmentResult();

    // Copying is disallowed. The parameters are `const SegmentResult&`, the
    // conventional signature for copy operations, so copies from const
    // objects are rejected by the deleted overload as well.
    SegmentResult(const SegmentResult&) = delete;
    SegmentResult& operator=(const SegmentResult&) = delete;

    // Outcome of the request; see ResultState.
    ResultState state = ResultState::kUnknown;
    // Rank for the segment, when one was produced.
    absl::optional<int> rank;
  };

  // Receives the result of GetSegmentResult(); takes ownership of it.
  using SegmentResultCallback =
      base::OnceCallback<void(std::unique_ptr<SegmentResult>)>;

  // Options for `GetSegmentResult()`.
  struct GetResultOptions {
    GetResultOptions();
    ~GetResultOptions();

    // The segment ID to fetch the result for.
    SegmentId segment_id = SegmentId::OPTIMIZATION_TARGET_UNKNOWN;

    // Key needed for computing the segment from the discrete mapping.
    std::string segmentation_key;

    // Ignores model results stored in the database and executes models to
    // fetch results instead.
    // When false, the result may come from:
    //   * A score cached in the database.
    //   * Execution of the default model when the score is missing.
    // When true, the result may come from:
    //   * Execution of the TFLite model.
    //   * TODO(ssid): Support fallback to default when model is missing.
    bool ignore_db_scores = false;

    // Callback that receives the segment result.
    SegmentResultCallback callback;

    // Current context of the browser, needed by the feature processor for
    // some of the models.
    scoped_refptr<InputContext> input_context;
  };

  // Creates a provider instance.
  // NOTE(review): the raw pointer arguments are presumably non-owning and must
  // outlive the returned provider. The exact meaning of
  // `force_refresh_results` is defined by the implementation. Confirm both
  // there.
  static std::unique_ptr<SegmentResultProvider> Create(
      SegmentInfoDatabase* segment_info_database,
      SignalStorageConfig* signal_storage_config,
      DefaultModelManager* default_model_manager,
      ExecutionService* execution_service,
      base::Clock* clock,
      bool force_refresh_results);

  SegmentResultProvider() = default;
  virtual ~SegmentResultProvider() = default;

  // Returns the latest available result for `options->segment_id` by running
  // `options->callback`. Takes ownership of `options`.
  virtual void GetSegmentResult(std::unique_ptr<GetResultOptions> options) = 0;
};
} // namespace segmentation_platform
#endif // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_SELECTION_SEGMENT_RESULT_PROVIDER_H_