blob: ed8d1679e3c434097d30df8bd1c6e8aa2be56ce1 [file] [log] [blame]
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/extra_trees_trainer.h"
#include "base/bind.h"
#include "base/memory/ref_counted.h"
#include "base/test/task_environment.h"
#include "media/learning/impl/fisher_iris_dataset.h"
#include "media/learning/impl/test_random_number_generator.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace media {
namespace learning {
// Test fixture for ExtraTreesTrainer. It is parameterized on the ordering
// that SetupFeatures() applies to every feature description it adds to
// |task_|.
class ExtraTreesTest : public testing::TestWithParam<LearningTask::Ordering> {
 public:
  ExtraTreesTest() : rng_(0), ordering_(GetParam()) {
    // Hand the trainer a test RNG (seeded with 0) so that results are
    // reproducible across runs.
    trainer_.SetRandomNumberGeneratorForTesting(&rng_);
  }

  // Set up |task_| to have |n| features with the given ordering.
  void SetupFeatures(size_t n) {
    for (size_t i = 0; i < n; i++) {
      LearningTask::ValueDescription desc;
      desc.ordering = ordering_;
      task_.feature_descriptions.push_back(desc);
    }
  }

  // Trains a model for |task| on |data| with |trainer_|. Runs the task
  // environment until idle so the completion callback can deliver the model,
  // then returns it. The result may be null if the callback did not run (or
  // delivered null) before RunUntilIdle() returned; callers should check.
  std::unique_ptr<Model> Train(const LearningTask& task,
                               const TrainingData& data) {
    std::unique_ptr<Model> model;
    // Use the |task| argument (not |task_|) so that callers can train on a
    // task other than the fixture's.
    trainer_.Train(
        task, data,
        base::BindOnce(
            [](std::unique_ptr<Model>* model_out,
               std::unique_ptr<Model> model) { *model_out = std::move(model); },
            &model));
    task_environment_.RunUntilIdle();
    return model;
  }

  base::test::TaskEnvironment task_environment_;
  // Deterministic RNG handed to |trainer_| in the constructor.
  TestRandomNumberGenerator rng_;
  ExtraTreesTrainer trainer_;
  // Task description shared by the tests; features are added by
  // SetupFeatures().
  LearningTask task_;
  // Feature ordering, taken from the test parameter.
  LearningTask::Ordering ordering_;
};
// Training on an empty data set must still produce a model. Predicting on an
// empty feature vector must then give an empty TargetHistogram.
TEST_P(ExtraTreesTest, EmptyTrainingDataWorks) {
  TrainingData empty;
  auto model = Train(task_, empty);
  // ASSERT (not EXPECT): |model| is dereferenced below, so the test must stop
  // here rather than crash if no model was produced.
  ASSERT_NE(model.get(), nullptr);
  EXPECT_EQ(model->PredictDistribution(FeatureVector()), TargetHistogram());
}
// Trains on the Fisher iris data set (4 features) and checks that the
// weighted accuracy on the training set itself exceeds 95%.
TEST_P(ExtraTreesTest, FisherIrisDataset) {
  SetupFeatures(4);
  FisherIrisDataset iris;
  TrainingData training_data = iris.GetTrainingData();
  auto model = Train(task_, training_data);
  // |model| is dereferenced for every example below.
  ASSERT_NE(model.get(), nullptr);

  // Verify predictions on the training set, just for sanity. Each example
  // counts toward |num_correct| with its weight, but only if the predicted
  // distribution has a single maximum equal to the true target.
  size_t num_correct = 0;
  for (const LabelledExample& example : training_data) {
    TargetHistogram distribution = model->PredictDistribution(example.features);
    TargetValue predicted_value;
    if (distribution.FindSingularMax(&predicted_value) &&
        predicted_value == example.target_value) {
      num_correct += example.weight;
    }
  }

  // Expect very high accuracy. We should get ~100%.
  const double train_accuracy =
      static_cast<double>(num_correct) / training_data.total_weight();
  EXPECT_GT(train_accuracy, 0.95);
}
// Checks that example weights are honored. The data is unseparable: the same
// feature maps to two targets. The single heavily-weighted example should
// still win over several unit examples.
TEST_P(ExtraTreesTest, WeightedTrainingSetIsSupported) {
  SetupFeatures(1);
  LabelledExample example_1({FeatureValue(123)}, TargetValue(1));
  LabelledExample example_2({FeatureValue(123)}, TargetValue(2));
  const size_t weight = 100;
  TrainingData training_data;
  example_1.weight = weight;
  training_data.push_back(example_1);
  // Push many |example_2|'s, which would win without the weights.
  for (int i = 0; i < 4; i++)
    training_data.push_back(example_2);

  // |example_1| was given an explicit weight of |weight|, so the set is
  // expected to report itself as weighted.
  EXPECT_FALSE(training_data.is_unweighted());
  auto model = Train(task_, training_data);
  ASSERT_NE(model.get(), nullptr);

  // The singular max should be example_1.
  TargetHistogram distribution = model->PredictDistribution(example_1.features);
  TargetValue predicted_value;
  // ASSERT: if there is no singular max, |predicted_value| was not filled in
  // and comparing it below would be meaningless.
  ASSERT_TRUE(distribution.FindSingularMax(&predicted_value));
  EXPECT_EQ(predicted_value, example_1.target_value);
}
// Checks numeric (regression) targets. There are two distinct feature vectors.
// Each appears twice with different targets, and one copy of each carries a
// weight of 100. The predicted average for each feature vector should
// therefore lie within 5% of its heavily-weighted target.
TEST_P(ExtraTreesTest, RegressionWorks) {
  SetupFeatures(2);
  LabelledExample example_1({FeatureValue(1), FeatureValue(123)},
                            TargetValue(1));
  LabelledExample example_1_a({FeatureValue(1), FeatureValue(123)},
                              TargetValue(5));
  LabelledExample example_2({FeatureValue(1), FeatureValue(456)},
                            TargetValue(20));
  LabelledExample example_2_a({FeatureValue(1), FeatureValue(456)},
                              TargetValue(25));
  TrainingData training_data;
  example_1.weight = 100;
  training_data.push_back(example_1);
  training_data.push_back(example_1_a);
  example_2.weight = 100;
  training_data.push_back(example_2);
  training_data.push_back(example_2_a);

  // Numeric target ordering makes this a regression task.
  task_.target_description.ordering = LearningTask::Ordering::kNumeric;
  auto model = Train(task_, training_data);
  ASSERT_NE(model.get(), nullptr);

  // Make sure that the results are in the right range.
  TargetHistogram distribution = model->PredictDistribution(example_1.features);
  EXPECT_GT(distribution.Average(), example_1.target_value.value() * 0.95);
  EXPECT_LT(distribution.Average(), example_1.target_value.value() * 1.05);
  distribution = model->PredictDistribution(example_2.features);
  EXPECT_GT(distribution.Average(), example_2.target_value.value() * 0.95);
  EXPECT_LT(distribution.Average(), example_2.target_value.value() * 1.05);
}
TEST_P(ExtraTreesTest, RegressionVsBinaryClassification) {
  // Create a binary classification task and a regression task that are roughly
  // the same, and verify that their results agree.
  //
  // For each set of features, choose a regression target |frac| in [0, 1].
  // For the corresponding binary classification problem, add
  // round(100 * |frac|) true instances and the remainder of 100 as false
  // instances. The predicted averages should be roughly the same.
  SetupFeatures(3);
  TrainingData c_data, r_data;
  std::set<LabelledExample> r_examples;
  for (size_t i = 0; i < 4 * 4 * 4; i++) {
    FeatureValue f1(i & 3);
    FeatureValue f2((i >> 2) & 3);
    FeatureValue f3((i >> 4) & 3);
    // This must be a double. As an int it would truncate to 0 for every
    // combination except (3, 3, 3), so only targets 0 and 1 would ever be
    // tested.
    const double frac = (1.0 * (f1.value() + f2.value() + f3.value())) / 9;
    LabelledExample e({f1, f2, f3}, TargetValue(0));
    // TODO(liberato): Consider adding noise, and verifying that the model
    // predictions are roughly the same as each other, rather than the same as
    // the currently noise-free target.

    // Push some number of false and some number of true instances that is in
    // the right ratio for |frac|, rounding to the nearest count.
    const int total_examples = 100;
    const int positive_examples =
        static_cast<int>(total_examples * frac + 0.5);
    e.weight = total_examples - positive_examples;
    if (e.weight > 0)
      c_data.push_back(e);
    e.target_value = TargetValue(1.0);
    e.weight = positive_examples;
    if (e.weight > 0)
      c_data.push_back(e);

    // For the regression data, add an example with |frac| directly. Also save
    // it so that we can look up the right answer below.
    LabelledExample r_example({f1, f2, f3}, TargetValue(frac));
    r_examples.insert(r_example);
    r_data.push_back(r_example);
  }

  // Train a model on the binary classification task and the regression task.
  auto c_model = Train(task_, c_data);
  ASSERT_NE(c_model.get(), nullptr);
  task_.target_description.ordering = LearningTask::Ordering::kNumeric;
  auto r_model = Train(task_, r_data);
  ASSERT_NE(r_model.get(), nullptr);

  // Verify that, for all feature combinations, the models roughly agree. Since
  // the data is separable, it probably should be exact.
  for (const auto& r_example : r_examples) {
    const FeatureVector& fv = r_example.features;
    TargetHistogram c_dist = c_model->PredictDistribution(fv);
    EXPECT_LE(c_dist.Average(), r_example.target_value.value() * 1.05);
    EXPECT_GE(c_dist.Average(), r_example.target_value.value() * 0.95);
    TargetHistogram r_dist = r_model->PredictDistribution(fv);
    EXPECT_LE(r_dist.Average(), r_example.target_value.value() * 1.05);
    EXPECT_GE(r_dist.Average(), r_example.target_value.value() * 0.95);
  }
}
// Run every ExtraTreesTest case twice: once with all features marked
// kUnordered, and once with all features marked kNumeric.
INSTANTIATE_TEST_SUITE_P(ExtraTreesTest,
                         ExtraTreesTest,
                         testing::Values(LearningTask::Ordering::kUnordered,
                                         LearningTask::Ordering::kNumeric));
} // namespace learning
} // namespace media