blob: 3815d5c48421d716d7aec8e8bb62c391e5f50269 [file] [log] [blame]
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <memory>
#include <vector>
#include "base/bind.h"
#include "base/test/scoped_task_environment.h"
#include "media/learning/common/learning_task.h"
#include "media/learning/impl/distribution_reporter.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace media {
namespace learning {
// Test fixture shared by the DistributionReporter tests below. Each test
// configures |task_|, builds |reporter_| from it via
// DistributionReporter::Create(), and then checks the result.
class DistributionReporterTest : public testing::Test {
public:
// Declared first so that, per C++ member destruction order, it is destroyed
// last (after |reporter_|). NOTE(review): presumably the reporter may need the
// task environment while it is alive — confirm before reordering members.
base::test::ScopedTaskEnvironment scoped_task_environment_;
// Task description handed to DistributionReporter::Create(); tests set its
// target ordering and |uma_hacky_confusion_matrix| before creating a reporter.
LearningTask task_;
// Reporter under test. Tests expect nullptr when no supported reporting was
// requested by |task_|.
std::unique_ptr<DistributionReporter> reporter_;
};
// Drives the create -> GetPredictionCallback() -> Run() path for a numeric
// task that requests confusion-matrix reporting. Only checks that nothing
// crashes; the reported values are not inspected yet (see TODO below).
TEST_F(DistributionReporterTest, DistributionReporterDoesNotCrash) {
  // Make sure that we request some sort of reporting.
  task_.target_description.ordering = LearningTask::Ordering::kNumeric;
  task_.uma_hacky_confusion_matrix = "test";
  reporter_ = DistributionReporter::Create(task_);
  // Everything below dereferences |reporter_|, so this must be fatal: a
  // non-fatal EXPECT would let a null reporter continue into a crash instead
  // of producing a clean test failure.
  ASSERT_NE(reporter_, nullptr);

  const TargetValue Zero(0);
  const TargetValue One(1);

  TargetDistribution observed;
  // Observe an average of 2 / 3: (0 * 100 + 1 * 200) / 300.
  observed[Zero] = 100;
  observed[One] = 200;
  auto cb = reporter_->GetPredictionCallback(observed);

  TargetDistribution predicted;
  // Predict an average of 5 / 9: (0 * 40 + 1 * 50) / 90.
  predicted[Zero] = 40;
  predicted[One] = 50;
  std::move(cb).Run(predicted);

  // TODO(liberato): When we switch to ukm, use a TestUkmRecorder to make sure
  // that it fills in the right stuff.
  // https://chromium-review.googlesource.com/c/chromium/src/+/1385107 .
}
// With no confusion-matrix UMA name there is nothing to report, so Create()
// should decline to build a reporter even for a numeric task.
TEST_F(DistributionReporterTest, DistributionReporterNeedsUmaName) {
  // Request no reporting (empty name), while keeping an ordering that the
  // reporter otherwise accepts.
  task_.uma_hacky_confusion_matrix = "";
  task_.target_description.ordering = LearningTask::Ordering::kNumeric;

  reporter_ = DistributionReporter::Create(task_);
  EXPECT_EQ(nullptr, reporter_);
}
// The hacky confusion-matrix report is only supported for regression tasks,
// so an unordered target must not get a reporter even when a UMA name is set.
TEST_F(DistributionReporterTest,
       DistributionReporterHackyConfusionMatrixNeedsRegression) {
  task_.uma_hacky_confusion_matrix = "test";
  // Unordered targets are not regression targets.
  task_.target_description.ordering = LearningTask::Ordering::kUnordered;

  reporter_ = DistributionReporter::Create(task_);
  EXPECT_EQ(nullptr, reporter_);
}
} // namespace learning
} // namespace media