blob: bfd09bf6d9223fb9a065dac38095f4a3f4385ace [file] [log] [blame]
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MEDIA_LEARNING_IMPL_DISTRIBUTION_REPORTER_H_
#define MEDIA_LEARNING_IMPL_DISTRIBUTION_REPORTER_H_
#include <set>
#include "base/callback.h"
#include "base/component_export.h"
#include "base/macros.h"
#include "base/memory/weak_ptr.h"
#include "base/optional.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/impl/model.h"
#include "media/learning/impl/target_histogram.h"
#include "services/metrics/public/cpp/ukm_source_id.h"
namespace media {
namespace learning {
// Helper class to report on predicted distrubutions vs target distributions.
// Use DistributionReporter::Create() to create one that's appropriate for a
// specific learning task.
class COMPONENT_EXPORT(LEARNING_IMPL) DistributionReporter {
public:
// Extra information provided to the reporter for each prediction.
struct PredictionInfo {
// What value was observed?
TargetValue observed;
// UKM source id to use when logging this result.
// This will be filled in by the LearningTaskController. For example, the
// MojoLearningTaskControllerService will be created in the browser by the
// MediaMetricsProvider, which gets the SourceId via callback from the
// RenderFrameHostDelegate on construction.
//
// TODO(liberato): Right now, this is not filled in anywhere. When the
// mojo service is created (MediaMetricsProvider), record the source id and
// memorize it in any MojoLearningTaskControllerService that's created by
// the MediaMetricsProvider, either directly or in a wrapper for the
// mojo controller.
ukm::SourceId source_id = ukm::kInvalidSourceId;
// Total weight of the training data used to create this model.
double total_training_weight = 0.;
// Total number of examples (unweighted) in the training set.
size_t total_training_examples = 0u;
// TODO(liberato): Move the feature subset here.
};
// Create a DistributionReporter that's suitable for |task|.
static std::unique_ptr<DistributionReporter> Create(const LearningTask& task);
virtual ~DistributionReporter();
// Returns a prediction CB that will be compared to |prediction_info.observed|
// TODO(liberato): This is too complicated. Skip the callback and just call
// us with the predicted value.
virtual Model::PredictionCB GetPredictionCallback(
const PredictionInfo& prediction_info);
// Set the subset of features that is being used to train the model. This is
// used for feature importance measuremnts.
//
// For example, sending in the set [0, 3, 7] would indicate that the model was
// trained with task().feature_descriptions[0, 3, 7] only.
//
// Note that UMA reporting only supports single feature subsets.
void SetFeatureSubset(const std::set<int>& feature_indices);
protected:
DistributionReporter(const LearningTask& task);
const LearningTask& task() const { return task_; }
// Implemented by subclasses to report a prediction.
virtual void OnPrediction(const PredictionInfo& prediction_info,
TargetHistogram predicted) = 0;
const base::Optional<std::set<int>>& feature_indices() const {
return feature_indices_;
}
private:
LearningTask task_;
// If provided, then these are the features that are used to train the model.
// Otherwise, we assume that all features are used.
base::Optional<std::set<int>> feature_indices_;
base::WeakPtrFactory<DistributionReporter> weak_factory_{this};
DISALLOW_COPY_AND_ASSIGN(DistributionReporter);
};
} // namespace learning
} // namespace media
#endif // MEDIA_LEARNING_IMPL_DISTRIBUTION_REPORTER_H_