// blob: 195de4ba7743ef0ab676f730576f41d4f1c0a628 [file] [log] [blame]
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MEDIA_LEARNING_COMMON_LEARNING_TASK_CONTROLLER_H_
#define MEDIA_LEARNING_COMMON_LEARNING_TASK_CONTROLLER_H_
#include "base/callback.h"
#include "base/component_export.h"
#include "base/macros.h"
#include "base/optional.h"
#include "base/unguessable_token.h"
#include "media/learning/common/labelled_example.h"
#include "media/learning/common/learning_task.h"
#include "services/metrics/public/cpp/ukm_source_id.h"
namespace media {
namespace learning {
// Wrapper struct for completing an observation via LearningTaskController.
// It bundles the observed TargetValue with a weight and a UKM SourceId.  Most
// callers will just send in a TargetValue, so the converting constructor is
// deliberately implicit and supplies defaults for the other two fields.  Only
// a few callers care about the SourceId.
struct ObservationCompletion {
  // Leaves |target_value| default-initialized.  |weight| and |source_id| take
  // the values given by their in-class initializers, which are the same
  // defaults that the converting constructor below uses.
  ObservationCompletion() = default;

  // Intentionally not explicit: lets a bare TargetValue be passed wherever an
  // ObservationCompletion is expected.
  /* implicit */ ObservationCompletion(const TargetValue& target,
                                       WeightType w = 1.,
                                       ukm::SourceId id = ukm::kInvalidSourceId)
      : target_value(target), weight(w), source_id(id) {}

  TargetValue target_value;

  // Initialized in-class so that a default-constructed instance never carries
  // an indeterminate weight (operator== below reads it).
  WeightType weight = 1.;

  // Optional, and ignored from the renderer.  Initialized in-class for the
  // same reason as |weight|.
  ukm::SourceId source_id = ukm::kInvalidSourceId;

  // Member-wise equality.  Mostly for gmock matchers.
  bool operator==(const ObservationCompletion& rhs) const {
    return target_value == rhs.target_value && weight == rhs.weight &&
           source_id == rhs.source_id;
  }
};
// Per-task client interface.  This is meant to be the main API used by client
// code that produces FeatureVectors / asks for predictions for one learning
// task.
//
// An observation is split into a "begin" step, which supplies the
// FeatureVector without any target value, and a later "complete" (or "cancel")
// step.  The split exists so that framework-provided features
// (FeatureProvider) can be snapshotted at the right moment; waiting until the
// TargetValue has been observed is generally too late for that.
class COMPONENT_EXPORT(LEARNING_COMMON) LearningTaskController {
 public:
  LearningTaskController() = default;

  // Not copyable or assignable.
  LearningTaskController(const LearningTaskController&) = delete;
  LearningTaskController& operator=(const LearningTaskController&) = delete;

  virtual ~LearningTaskController() = default;

  // Opens a new observation identified by |id|.  Call it at the point where
  // one would attempt to predict the TargetValue, so that the framework can
  // snapshot any framework-provided feature values at prediction time.
  //
  // Afterwards, do one of the following:
  //   - CompleteObservation() with the same |id| and an ObservationCompletion,
  //     to turn |features| into a training example, or
  //   - CancelObservation() with |id|, to discard it.
  //
  // Destroying the controller while observations are still outstanding is
  // also fine: each one is completed with |default_target| if one was
  // provided, and cancelled otherwise.
  //
  // TODO(liberato): This should optionally take a callback to receive a
  // prediction for the FeatureVector.
  // TODO(liberato): See if this ends up generating smaller code with pass-by-
  // value or with |FeatureVector&&|, once we have callers that can actually
  // benefit from it.
  virtual void BeginObservation(
      base::UnguessableToken id,
      const FeatureVector& features,
      const base::Optional<TargetValue>& default_target = base::nullopt) = 0;

  // Finishes the observation |id| by sending |completion| for it.
  virtual void CompleteObservation(base::UnguessableToken id,
                                   const ObservationCompletion& completion) = 0;

  // Tells the controller that no completion will be sent for observation
  // |id|.
  virtual void CancelObservation(base::UnguessableToken id) = 0;

  // Returns the LearningTask that |this| is associated with.
  virtual const LearningTask& GetLearningTask() = 0;
};
} // namespace learning
} // namespace media
#endif // MEDIA_LEARNING_COMMON_LEARNING_TASK_CONTROLLER_H_