blob: f4e69a788b3b414dffe815aeaba51ae22ba98986 [file] [log] [blame]
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_SEGMENTATION_PLATFORM_SERVICE_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_SEGMENTATION_PLATFORM_SERVICE_H_
#include <optional>
#include <string>
#include <utility>

#include "base/functional/callback.h"
#include "base/memory/scoped_refptr.h"
#include "base/metrics/histogram_base.h"
#include "base/supports_user_data.h"
#include "base/types/id_type.h"
#include "build/build_config.h"
#include "components/keyed_service/core/keyed_service.h"
#include "components/segmentation_platform/public/database_client.h"
#include "components/segmentation_platform/public/input_context.h"
#include "components/segmentation_platform/public/prediction_options.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#include "components/segmentation_platform/public/result.h"
#include "components/segmentation_platform/public/trigger.h"
#if BUILDFLAG(IS_ANDROID)
#include "base/android/jni_android.h"
#endif // BUILDFLAG(IS_ANDROID)
class PrefRegistrySimple;
namespace segmentation_platform {
// Forward declarations: these types are only used by reference, by pointer, or
// in function declarations in this header, so their full definitions are not
// needed here.
struct SegmentSelectionResult;
class ServiceProxy;

// Type-safe 32-bit ID (see base::IdType32). The tag type exists only to make
// this ID type distinct; judging by its name it identifies on-demand segment
// selection callbacks (NOTE(review): confirm against users of CallbackId).
using CallbackId = base::IdType32<class OnDemandSegmentSelectionCallbackTag>;
// Structure used to store data for training output collection.
// The name should be a UMA histogram name or a user action name. Currently this
// output is appended to the training data, but the histogram name is not
// recorded. So each model can only take one type of metric: using two different
// metrics for the same model would make it unclear what the value means during
// training.
struct TrainingLabels {
  TrainingLabels();
  TrainingLabels(const TrainingLabels& other);
  // Declared explicitly so the copy operations are consistent with the
  // user-declared copy constructor and destructor. Relying on the implicitly
  // generated copy assignment in that situation is deprecated
  // (-Wdeprecated-copy). Defaulted, so its semantics are unchanged.
  TrainingLabels& operator=(const TrainingLabels& other) = default;
  ~TrainingLabels();

  // Name and sample of the output metric to be collected as training data.
  // Optional, so a TrainingLabels instance may carry no output metric.
  std::optional<std::pair<std::string, base::HistogramBase::Sample>>
      output_metric;
};
// The core class of the segmentation platform that integrates all the required
// pieces on the client side. This is an abstract interface: concrete
// implementations provide the pure virtual methods declared below.
class SegmentationPlatformService : public KeyedService,
                                    public base::SupportsUserData {
 public:
#if BUILDFLAG(IS_ANDROID)
  // Returns a Java object of the type SegmentationPlatformService for the given
  // SegmentationPlatformService.
  static base::android::ScopedJavaLocalRef<jobject> GetJavaObject(
      SegmentationPlatformService* segmentation_platform_service);
#endif  // BUILDFLAG(IS_ANDROID)

  // Callback that receives a single boolean status (presumably `true` on
  // success; see the documentation of each API that takes it).
  using SuccessCallback = base::OnceCallback<void(bool)>;

  SegmentationPlatformService() = default;
  ~SegmentationPlatformService() override = default;

  // Disallow copy/assign.
  SegmentationPlatformService(const SegmentationPlatformService&) = delete;
  SegmentationPlatformService& operator=(const SegmentationPlatformService&) =
      delete;

  // Registers preferences used by this class in the provided |registry|. This
  // should be called for the Profile registry.
  static void RegisterProfilePrefs(PrefRegistrySimple* registry);

  // Registers preferences used by this class in the provided |registry|. This
  // should be called for the local state registry.
  static void RegisterLocalStatePrefs(PrefRegistrySimple* registry);

  using SegmentSelectionCallback =
      base::OnceCallback<void(const SegmentSelectionResult&)>;

  // Called to get the selected segment asynchronously. If no segment is
  // selected, `callback` is run with an empty result.
  virtual void GetSelectedSegment(const std::string& segmentation_key,
                                  SegmentSelectionCallback callback) = 0;

  // Called to get the classification results for a given client. The
  // classification config must be defined in the associated model metadata.
  // Depending on the options and client config, it either runs the associated
  // model or uses unexpired cached results.
  virtual void GetClassificationResult(
      const std::string& segmentation_key,
      const PredictionOptions& prediction_options,
      scoped_refptr<InputContext> input_context,
      ClassificationResultCallback callback) = 0;

  // Gets the result of the model execution, annotated with the output config
  // needed to interpret it. Depending on the options and client config, it
  // either runs the associated model or uses unexpired cached results. This
  // API is experimental and does not cleanly support transitions from a
  // heuristic to ML models. It is not usable for most ML models, since ML
  // models require normalization of the output values to make them usable.
  virtual void GetAnnotatedNumericResult(
      const std::string& segmentation_key,
      const PredictionOptions& prediction_options,
      scoped_refptr<InputContext> input_context,
      AnnotatedNumericResultCallback callback) = 0;

  // Called to get the selected segment synchronously. If no segment is
  // selected, returns an empty result.
  virtual SegmentSelectionResult GetCachedSegmentResult(
      const std::string& segmentation_key) = 0;

  // Called to trigger training data collection for a given request ID. Request
  // IDs are given when |GetClassificationResult| is called. `callback` is run
  // with a boolean status.
  virtual void CollectTrainingData(proto::SegmentId segment_id,
                                   TrainingRequestId request_id,
                                   const TrainingLabels& param,
                                   SuccessCallback callback) = 0;

  // Called to enable or disable metrics collection. Must be explicitly called
  // on startup.
  virtual void EnableMetrics(bool signal_collection_allowed) = 0;

  // Called to get the proxy that is used for debugging purposes.
  // Not pure virtual; the base implementation is defined out of line.
  virtual ServiceProxy* GetServiceProxy();

  // Gets access to the segmentation databases using the client.
  // WARNING: This returns nullptr while `IsPlatformInitialized()` is false,
  // i.e. until platform initialization has finished. Observe ServiceProxy to
  // get notified when the platform is initialized.
  // TODO(ssid): Remove the initialization requirement by handling waiting for
  // init internally.
  // TODO(ssid): Add a Java version of this API.
  virtual DatabaseClient* GetDatabaseClient();

  // Returns true when the platform has finished initializing and can execute
  // models. The `GetSelectedSegment()` calls work without full platform
  // initialization since they load results from previous sessions.
  virtual bool IsPlatformInitialized() = 0;
};
} // namespace segmentation_platform
#endif // COMPONENTS_SEGMENTATION_PLATFORM_PUBLIC_SEGMENTATION_PLATFORM_SERVICE_H_