blob: 78b22e65c93cea78d56d453ecc0b07dd2f34ab9a [file] [log] [blame]
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef MEDIA_LEARNING_IMPL_DISTRIBUTION_REPORTER_H_
#define MEDIA_LEARNING_IMPL_DISTRIBUTION_REPORTER_H_
#include "base/callback.h"
#include "base/component_export.h"
#include "base/macros.h"
#include "base/memory/weak_ptr.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/impl/model.h"
#include "media/learning/impl/target_distribution.h"
namespace media {
namespace learning {
// Helper class to report on predicted distrubutions vs target distributions.
// Use DistributionReporter::Create() to create one that's appropriate for a
// specific learning task.
class COMPONENT_EXPORT(LEARNING_IMPL) DistributionReporter {
public:
// Create a DistributionReporter that's suitable for |task|.
static std::unique_ptr<DistributionReporter> Create(const LearningTask& task);
virtual ~DistributionReporter();
// Returns a prediction CB that will be compared to |observed|. |observed| is
// the total number of counts that we observed.
virtual Model::PredictionCB GetPredictionCallback(
TargetDistribution observed);
protected:
DistributionReporter(const LearningTask& task);
const LearningTask& task() const { return task_; }
// Implemented by subclasses to report a prediction.
virtual void OnPrediction(TargetDistribution observed,
TargetDistribution predicted) = 0;
private:
LearningTask task_;
base::WeakPtrFactory<DistributionReporter> weak_factory_;
DISALLOW_COPY_AND_ASSIGN(DistributionReporter);
};
} // namespace learning
} // namespace media
#endif // MEDIA_LEARNING_IMPL_DISTRIBUTION_REPORTER_H_