// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/random_forest_trainer.h"
#include <set>
#include "base/logging.h"
#include "media/learning/impl/random_forest.h"
#include "media/learning/impl/random_tree_trainer.h"

namespace media {
namespace learning {

RandomForestTrainer::TrainingResult::TrainingResult() = default;
RandomForestTrainer::TrainingResult::~TrainingResult() = default;

RandomForestTrainer::RandomForestTrainer() = default;
RandomForestTrainer::~RandomForestTrainer() = default;
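
// Minimal usage sketch (|task| and |data| are hypothetical, caller-provided
// values; |result->oob_total| may be zero, e.g. for weighted data, so check
// it before dividing):
//
//   RandomForestTrainer trainer;
//   auto result = trainer.Train(task, data);
//   double oob_accuracy =
//       static_cast<double>(result->oob_correct) / result->oob_total;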
std::unique_ptr<RandomForestTrainer::TrainingResult> RandomForestTrainer::Train(
    const LearningTask& task,
    const TrainingData& training_data) {
  int n_trees = task.rf_number_of_trees;

  RandomTreeTrainer tree_trainer(rng());
  std::vector<std::unique_ptr<Model>> trees;
  trees.reserve(n_trees);

  // oob_distributions[example] accumulates the sum of all out-of-bag (OOB)
  // predictions for |example|, over every tree that was not trained on it.
  std::map<const TrainingExample*, TargetDistribution> oob_distributions;

  // We don't support weighted training data, since bagging with weights is
  // very hard to do correctly without spending time that depends on the
  // values of the weights. Since ExactTrees don't need bagging, we skip it.
  // In that case, return an empty result (no model).
  if (!training_data.is_unweighted())
    return std::make_unique<TrainingResult>();

  const size_t n_examples = training_data.size();
  for (int i = 0; i < n_trees; i++) {
    // Collect a bagged training set, remembering which examples were drawn
    // so that the out-of-bag examples can be identified below.
    TrainingData bagged_data(training_data.storage());
    std::set<const TrainingExample*> bagged_set;
    for (size_t e = 0; e < n_examples; e++) {
      WeightedExample weighted_example =
          training_data[rng()->Generate(n_examples)];
      const TrainingExample* example = weighted_example.example();
      bagged_data.push_back(example);
      bagged_set.insert(example);
    }
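
    // Note: drawing |n_examples| samples with replacement leaves each example
    // out of the bag with probability (1 - 1/n)^n, which approaches
    // 1/e ~= 36.8%, so every tree has a sizeable OOB sample to score below.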

    // Train the tree.
    std::unique_ptr<Model> tree = tree_trainer.Train(task, bagged_data);

    // Accumulate this tree's predictions for its out-of-bag examples, i.e.,
    // the examples that were not drawn into |bagged_set|.
    for (WeightedExample weighted_example : training_data) {
      const TrainingExample* example = weighted_example.example();
      if (bagged_set.find(example) != bagged_set.end())
        continue;

      TargetDistribution predicted =
          tree->PredictDistribution(example->features);

      // Add the predicted distribution to this example's total distribution.
      // Remember that the distribution is not normalized, so the counts will
      // scale with the number of examples.
      // TODO(liberato): Should it be normalized before being combined?
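      // (Without normalization, a tree whose leaves hold many training
      // examples effectively casts a larger vote into |oob_distributions|
      // than one whose leaves hold few; normalizing each |predicted| first
      // would weight every tree equally.)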
      TargetDistribution& our_oob_dist = oob_distributions[example];
      our_oob_dist += predicted;
    }
    trees.push_back(std::move(tree));
  }

  std::unique_ptr<RandomForest> forest =
      std::make_unique<RandomForest>(std::move(trees));
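
  // Out-of-bag accuracy approximates held-out accuracy: each example is
  // scored only by trees that never trained on it, so no separate
  // validation split is needed.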
  // Compute OOB accuracy.
  int num_correct = 0;
  for (auto& oob_pair : oob_distributions) {
    const TrainingExample* example = oob_pair.first;
    const TargetDistribution& distribution = oob_pair.second;

    // If there are no guesses, or if it's a tie, then count it as wrong.
    TargetValue max_value;
    if (distribution.FindSingularMax(&max_value) &&
        max_value == example->target_value) {
      num_correct++;
    }
  }

  std::unique_ptr<TrainingResult> result = std::make_unique<TrainingResult>();
  result->model = std::move(forest);
  result->oob_correct = num_correct;
  result->oob_total = oob_distributions.size();
  return result;
}

}  // namespace learning
}  // namespace media