// blob: 99310aa0ed8354a976f61d4cbc52b2ccacbcdb6d [file] [log] [blame]
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <stddef.h>

#include <memory>
#include <set>

#include "base/callback.h"
#include "base/component_export.h"
#include "base/macros.h"
#include "base/memory/weak_ptr.h"
#include "base/optional.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/impl/model.h"
#include "media/learning/impl/target_histogram.h"
// Include guard for this header. It is placed around the declarations below
// so that a repeated #include of this file does not redefine the class.
#ifndef MEDIA_LEARNING_IMPL_DISTRIBUTION_REPORTER_H_
#define MEDIA_LEARNING_IMPL_DISTRIBUTION_REPORTER_H_

namespace media {
namespace learning {

// Helper class to report on predicted distributions vs target distributions.
// Use DistributionReporter::Create() to create one that's appropriate for a
// specific learning task.
class COMPONENT_EXPORT(LEARNING_IMPL) DistributionReporter {
 public:
  // Extra information provided to the reporter for each prediction.
  struct PredictionInfo {
    // What value was observed?
    TargetValue observed;

    // Total weight of the training data used to create this model.
    double total_training_weight = 0.;

    // Total number of examples (unweighted) in the training set.
    size_t total_training_examples = 0u;

    // TODO(liberato): Move the feature subset here.
  };

  // Create a DistributionReporter that's suitable for |task|.
  static std::unique_ptr<DistributionReporter> Create(const LearningTask& task);

  // Virtual because this class is abstract (see OnPrediction()); concrete
  // subclasses are owned and destroyed through the base-class pointer that
  // Create() returns.
  virtual ~DistributionReporter();

  // Reporters are neither copyable nor assignable.
  DistributionReporter(const DistributionReporter&) = delete;
  DistributionReporter& operator=(const DistributionReporter&) = delete;

  // Returns a prediction CB that will be compared to
  // |prediction_info.observed|.
  // TODO(liberato): This is too complicated. Skip the callback and just call
  // us with the predicted value.
  virtual Model::PredictionCB GetPredictionCallback(
      const PredictionInfo& prediction_info);

  // Set the subset of features that is being used to train the model. This is
  // used for feature importance measurements.
  //
  // For example, sending in the set [0, 3, 7] would indicate that the model was
  // trained with task().feature_descriptions[0, 3, 7] only.
  //
  // Note that UMA reporting only supports single feature subsets.
  void SetFeatureSubset(const std::set<int>& feature_indices);

 protected:
  // Only subclasses construct reporters; the class is abstract. |explicit|
  // keeps a LearningTask from ever converting implicitly into a reporter.
  explicit DistributionReporter(const LearningTask& task);

  // The task this reporter was created for.
  const LearningTask& task() const { return task_; }

  // Implemented by subclasses to report a prediction.
  virtual void OnPrediction(const PredictionInfo& prediction_info,
                            TargetHistogram predicted) = 0;

  // Returns |feature_indices_|; an empty Optional means all features are
  // assumed to be used (see the member's comment below).
  const base::Optional<std::set<int>>& feature_indices() const {
    return feature_indices_;
  }

 private:
  LearningTask task_;

  // If provided, then these are the features that are used to train the model.
  // Otherwise, we assume that all features are used.
  base::Optional<std::set<int>> feature_indices_;

  base::WeakPtrFactory<DistributionReporter> weak_factory_;
};

}  // namespace learning
}  // namespace media

#endif  // MEDIA_LEARNING_IMPL_DISTRIBUTION_REPORTER_H_