blob: f9face03115c21ed0d5ee472734df5dda2bc2df3 [file] [log] [blame]
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/random_tree_trainer.h"
#include "base/bind.h"
#include "base/run_loop.h"
#include "base/test/scoped_task_environment.h"
#include "media/learning/impl/test_random_number_generator.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace media {
namespace learning {
// Fixture for RandomTreeTrainer tests, parameterized on the feature ordering
// (nominal/unordered vs. numeric) so every test runs under both.
class RandomTreeTest : public testing::TestWithParam<LearningTask::Ordering> {
 public:
  RandomTreeTest()
      : rng_(0),  // Deterministic seed so tree construction is reproducible.
        trainer_(&rng_),
        ordering_(GetParam()) {}
  // Set up |task_| to have |n| features with the given ordering.
  void SetupFeatures(size_t n) {
    for (size_t i = 0; i < n; i++) {
      LearningTask::ValueDescription desc;
      desc.ordering = ordering_;
      task_.feature_descriptions.push_back(desc);
    }
  }
  // Train a model on |data| for |task|, and block (run the task environment
  // until idle) so the trained model can be returned synchronously.
  std::unique_ptr<Model> Train(const LearningTask& task,
                               const TrainingData& data) {
    std::unique_ptr<Model> model;
    trainer_.Train(
        // Use the caller-provided |task|; previously this passed |task_|,
        // silently ignoring the argument.
        task, data,
        base::BindOnce(
            [](std::unique_ptr<Model>* model_out,
               std::unique_ptr<Model> model) { *model_out = std::move(model); },
            &model));
    scoped_task_environment_.RunUntilIdle();
    return model;
  }
  base::test::ScopedTaskEnvironment scoped_task_environment_;
  TestRandomNumberGenerator rng_;
  RandomTreeTrainer trainer_;
  LearningTask task_;
  // Feature ordering.
  LearningTask::Ordering ordering_;
};
TEST_P(RandomTreeTest, EmptyTrainingDataWorks) {
  // Training on an empty data set should still produce a (trivial) model.
  TrainingData empty;
  std::unique_ptr<Model> model = Train(task_, empty);
  // ASSERT rather than EXPECT: |model| is dereferenced on the next line, so
  // continuing past a null model would crash the test binary instead of
  // reporting a clean failure.
  ASSERT_NE(model.get(), nullptr);
  // With no training data, predictions are an empty histogram.
  EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetHistogram());
}
TEST_P(RandomTreeTest, UniformTrainingDataWorks) {
  SetupFeatures(2);
  LabelledExample example({FeatureValue(123), FeatureValue(456)},
                          TargetValue(789));
  TrainingData training_data;
  const size_t n_examples = 10;
  for (size_t i = 0; i < n_examples; i++)
    training_data.push_back(example);
  std::unique_ptr<Model> model = Train(task_, training_data);
  // Guard against a null model before dereferencing it below; the original
  // code dereferenced |model| with no null check at all.
  ASSERT_NE(model.get(), nullptr);
  // The tree should produce a distribution for one value (our target), which
  // has one count.
  TargetHistogram distribution = model->PredictDistribution(example.features);
  EXPECT_EQ(distribution.size(), 1u);
  EXPECT_EQ(distribution[example.target_value], 1.0);
}
TEST_P(RandomTreeTest, SimpleSeparableTrainingData) {
  SetupFeatures(1);
  TrainingData training_data;
  LabelledExample example_1({FeatureValue(123)}, TargetValue(1));
  LabelledExample example_2({FeatureValue(456)}, TargetValue(2));
  training_data.push_back(example_1);
  training_data.push_back(example_2);
  std::unique_ptr<Model> model = Train(task_, training_data);
  // Check for a null model BEFORE the first dereference (the original
  // EXPECT_NE ran only after |model| had already been dereferenced), and use
  // ASSERT so a failure stops the test instead of crashing on the next line.
  ASSERT_NE(model.get(), nullptr);
  // Each value should have a distribution with one target value with one count.
  TargetHistogram distribution = model->PredictDistribution(example_1.features);
  EXPECT_EQ(distribution.size(), 1u);
  EXPECT_EQ(distribution[example_1.target_value], 1u);
  distribution = model->PredictDistribution(example_2.features);
  EXPECT_EQ(distribution.size(), 1u);
  EXPECT_EQ(distribution[example_2.target_value], 1u);
}
TEST_P(RandomTreeTest, ComplexSeparableTrainingData) {
  // Building a random tree with numeric splits isn't terribly likely to work,
  // so just skip it. Entirely randomized splits are just too random. The
  // RandomForest unittests will test them as part of an ensemble.
  if (ordering_ == LearningTask::Ordering::kNumeric)
    return;
  SetupFeatures(4);
  // Build a four-feature training set that's completely separable, but one
  // needs all four features to do it.  The target encodes the four feature
  // bits, so each distinct feature combination maps to a distinct target.
  TrainingData training_data;
  for (int f1 = 0; f1 < 2; f1++) {
    for (int f2 = 0; f2 < 2; f2++) {
      for (int f3 = 0; f3 < 2; f3++) {
        for (int f4 = 0; f4 < 2; f4++) {
          LabelledExample example(
              {FeatureValue(f1), FeatureValue(f2), FeatureValue(f3),
               FeatureValue(f4)},
              TargetValue(f1 * 1 + f2 * 2 + f3 * 4 + f4 * 8));
          // Add two copies of each example.
          training_data.push_back(example);
          training_data.push_back(example);
        }
      }
    }
  }
  std::unique_ptr<Model> model = Train(task_, training_data);
  // ASSERT, not EXPECT: |model| is dereferenced in the loop below, so
  // continuing with a null model would crash instead of failing cleanly.
  ASSERT_NE(model.get(), nullptr);
  // Each example should have a distribution that selects the right value.
  for (const LabelledExample& example : training_data) {
    TargetHistogram distribution = model->PredictDistribution(example.features);
    TargetValue singular_max;
    EXPECT_TRUE(distribution.FindSingularMax(&singular_max));
    EXPECT_EQ(singular_max, example.target_value);
  }
}
TEST_P(RandomTreeTest, UnseparableTrainingData) {
  SetupFeatures(1);
  TrainingData training_data;
  // Two examples with the same feature value but different targets: the tree
  // cannot separate them and must return a mixed distribution.
  LabelledExample example_1({FeatureValue(123)}, TargetValue(1));
  LabelledExample example_2({FeatureValue(123)}, TargetValue(2));
  training_data.push_back(example_1);
  training_data.push_back(example_2);
  std::unique_ptr<Model> model = Train(task_, training_data);
  // ASSERT rather than EXPECT: EXPECT would let the test continue and crash
  // on the dereference below if |model| were null.
  ASSERT_NE(model.get(), nullptr);
  // Each value should have a distribution with two targets with equal counts.
  TargetHistogram distribution = model->PredictDistribution(example_1.features);
  EXPECT_EQ(distribution.size(), 2u);
  EXPECT_EQ(distribution[example_1.target_value], 0.5);
  EXPECT_EQ(distribution[example_2.target_value], 0.5);
  distribution = model->PredictDistribution(example_2.features);
  EXPECT_EQ(distribution.size(), 2u);
  EXPECT_EQ(distribution[example_1.target_value], 0.5);
  EXPECT_EQ(distribution[example_2.target_value], 0.5);
}
TEST_P(RandomTreeTest, UnknownFeatureValueHandling) {
  // Verify how a feature value that never appeared during training is handled.
  SetupFeatures(1);
  TrainingData training_data;
  LabelledExample low_example({FeatureValue(123)}, TargetValue(1));
  LabelledExample high_example({FeatureValue(456)}, TargetValue(2));
  training_data.push_back(low_example);
  training_data.push_back(high_example);
  std::unique_ptr<Model> model = Train(task_, training_data);
  // Predict for a value (789) that was not in the training set.
  TargetHistogram distribution =
      model->PredictDistribution(FeatureVector({FeatureValue(789)}));
  // In both orderings, exactly one target should receive the whole count.
  EXPECT_EQ(distribution.size(), 1u);
  if (ordering_ == LearningTask::Ordering::kUnordered) {
    // With nominal features, the out-of-vocabulary value may end up in either
    // training example's leaf, so we can't predict which target gets the
    // count -- only that one of the two does.
    EXPECT_EQ(distribution[low_example.target_value] +
                  distribution[high_example.target_value],
              1u);
  } else {
    // 789 is numerically above |high_example|'s feature value, so it should
    // land in that example's bucket.
    EXPECT_EQ(distribution[high_example.target_value], 1u);
  }
}
TEST_P(RandomTreeTest, NumericFeaturesSplitMultipleTimes) {
  // Verify that a single feature can be split at more than one threshold in
  // the tree.  This should also hold for nominal features, though it's less
  // interesting there.
  SetupFeatures(1);
  const int feature_mult = 10;
  // Four examples, each with a distinct feature value and distinct target.
  TrainingData training_data;
  for (size_t i = 0; i < 4; i++) {
    training_data.push_back(
        LabelledExample({FeatureValue(i * feature_mult)}, TargetValue(i)));
  }
  std::unique_ptr<Model> model = Train(task_, training_data);
  for (size_t i = 0; i < 4; i++) {
    // Predict for the |i|-th training feature value.
    TargetHistogram distribution = model->PredictDistribution(
        FeatureVector({FeatureValue(i * feature_mult)}));
    // Exactly one count, on the matching target.  If the feature weren't
    // split enough times, some value would get too many or too few counts.
    EXPECT_EQ(distribution.total_counts(), 1u);
    EXPECT_EQ(distribution[TargetValue(i)], 1u);
  }
}
// Run every test above under both feature orderings: nominal (unordered) and
// numeric.
INSTANTIATE_TEST_SUITE_P(RandomTreeTest,
                         RandomTreeTest,
                         testing::ValuesIn({LearningTask::Ordering::kUnordered,
                                            LearningTask::Ordering::kNumeric}));
} // namespace learning
} // namespace media