// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/random_tree_trainer.h"
#include <math.h>
#include "base/bind.h"
#include "base/logging.h"
#include "base/optional.h"
namespace media {
namespace learning {
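
// Returns a callback that, each time it is run, trains a new RandomTreeTrainer
// on the provided TrainingData for |task| and gives the resulting Model to the
// TrainedModelCB.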
// static
TrainingAlgorithmCB RandomTreeTrainer::GetTrainingAlgorithmCB(
const LearningTask& task) {
return base::BindRepeating(
[](LearningTask task, TrainingData training_data,
TrainedModelCB model_cb) {
std::move(model_cb).Run(RandomTreeTrainer().Train(task, training_data));
},
task);
}

RandomTreeTrainer::Split::Split() = default;

RandomTreeTrainer::Split::Split(int index) : split_index(index) {}

RandomTreeTrainer::Split::Split(Split&& rhs) = default;

RandomTreeTrainer::Split::~Split() = default;

RandomTreeTrainer::Split& RandomTreeTrainer::Split::operator=(Split&& rhs) =
    default;

RandomTreeTrainer::Split::BranchInfo::BranchInfo(
    scoped_refptr<TrainingDataStorage> storage)
    : training_data(std::move(storage)) {}

RandomTreeTrainer::Split::BranchInfo::BranchInfo(BranchInfo&& rhs) = default;

RandomTreeTrainer::Split::BranchInfo::~BranchInfo() = default;
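
// Interior node of the tree.  Routes a FeatureVector to one of its children
// based on the value of the feature at |split_index_|.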
struct InteriorNode : public Model {
InteriorNode(const LearningTask& task,
int split_index,
FeatureValue split_point)
: split_index_(split_index),
rt_unknown_value_handling_(task.rt_unknown_value_handling),
ordering_(task.feature_descriptions[split_index].ordering),
split_point_(split_point) {}
// Model
TargetDistribution PredictDistribution(
const FeatureVector& features) override {
// Figure out what feature value we should use for the split.
FeatureValue f;
switch (ordering_) {
case LearningTask::Ordering::kUnordered:
// Use the nominal value directly.
f = features[split_index_];
break;
case LearningTask::Ordering::kNumeric:
// Use 0 for "<=" and 1 for ">".
f = FeatureValue(features[split_index_] > split_point_);
break;
}
auto iter = children_.find(f);
    // If we've never seen this feature value, then average all our branches.
    // This attempts to mimic one-hot encoding, where we'd take the "zero"
    // branch for an unseen value; exactly which one-hot feature that
    // corresponds to depends on the tree structure.
if (iter == children_.end()) {
switch (rt_unknown_value_handling_) {
case LearningTask::RTUnknownValueHandling::kEmptyDistribution:
return TargetDistribution();
case LearningTask::RTUnknownValueHandling::kUseAllSplits:
return PredictDistributionWithMissingValues(features);
}
}
return iter->second->PredictDistribution(features);
}
TargetDistribution PredictDistributionWithMissingValues(
const FeatureVector& features) {
TargetDistribution total;
for (auto& child_pair : children_) {
TargetDistribution predicted =
child_pair.second->PredictDistribution(features);
// TODO(liberato): Normalize? Weight?
total += predicted;
}
return total;
}
  // Add |child| as the node for feature value |v|.
void AddChild(FeatureValue v, std::unique_ptr<Model> child) {
DCHECK_EQ(children_.count(v), 0u);
children_.emplace(v, std::move(child));
}
private:
// Feature value that we split on.
int split_index_ = -1;
base::flat_map<FeatureValue, std::unique_ptr<Model>> children_;
// How we handle unknown values.
LearningTask::RTUnknownValueHandling rt_unknown_value_handling_;
// How is our feature value ordered?
LearningTask::Ordering ordering_;
// For kNumeric features, this is the split point.
FeatureValue split_point_;
};
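
// Leaf node of the tree.  Predicts the combined target distribution of the
// training examples that reached it, regardless of the input features.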
struct LeafNode : public Model {
  explicit LeafNode(const TrainingData& training_data) {
for (WeightedExample example : training_data)
distribution_ += example;
}
  // Model
TargetDistribution PredictDistribution(const FeatureVector&) override {
return distribution_;
}
private:
TargetDistribution distribution_;
};

RandomTreeTrainer::RandomTreeTrainer(RandomNumberGenerator* rng)
    : HasRandomNumberGenerator(rng) {}

RandomTreeTrainer::~RandomTreeTrainer() = default;

std::unique_ptr<Model> RandomTreeTrainer::Train(
const LearningTask& task,
const TrainingData& training_data) {
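  // With no data, return a leaf that predicts an empty distribution.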
if (training_data.empty())
return std::make_unique<LeafNode>(training_data);
return Build(task, training_data, FeatureSet());
}
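
// Recursively build a (sub)tree over |training_data|.  Features in |used_set|
// have already been consumed by nominal splits above us, so we won't consider
// them again.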
std::unique_ptr<Model> RandomTreeTrainer::Build(
const LearningTask& task,
const TrainingData& training_data,
const FeatureSet& used_set) {
DCHECK_GT(training_data.total_weight(), 0u);
// TODO(liberato): Does it help if we refuse to split without an info gain?
Split best_potential_split;
// Select the feature subset to consider at this leaf.
FeatureSet feature_candidates;
for (size_t i = 0; i < training_data.begin()->example()->features.size();
i++) {
if (used_set.find(i) != used_set.end())
continue;
feature_candidates.insert(i);
}
// TODO(liberato): Let our caller override this.
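  // Consider at most sqrt(#candidates) features, and never more than three.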
const size_t features_per_split =
std::min(static_cast<int>(sqrt(feature_candidates.size())), 3);
while (feature_candidates.size() > features_per_split) {
// Remove a random feature.
size_t which = rng()->Generate(feature_candidates.size());
    auto iter = feature_candidates.begin();
    std::advance(iter, which);
    feature_candidates.erase(iter);
}
// Find the best split among the candidates that we have.
for (int i : feature_candidates) {
Split potential_split = ConstructSplit(task, training_data, i);
if (potential_split.nats_remaining < best_potential_split.nats_remaining) {
best_potential_split = std::move(potential_split);
}
}
  // Note that we can end up with a split that has no branches (no features
  // were left, or no feature improved the nats), or only one branch (we had
  // features, but every example had the same value).  Either way, we should
  // end up with a leaf.
if (best_potential_split.branch_infos.size() < 2) {
// Stop when there is no more tree.
return std::make_unique<LeafNode>(training_data);
}
  // Build an interior node.
std::unique_ptr<InteriorNode> node = std::make_unique<InteriorNode>(
task, best_potential_split.split_index, best_potential_split.split_point);
  // Don't let the subtree use this feature if this is a nominal split, since
  // there's nothing left to split.  For numeric splits, we might want to
  // split on it further.  Note that if there is only one branch for this
  // split, then we returned a leaf anyway.
FeatureSet new_used_set(used_set);
if (task.feature_descriptions[best_potential_split.split_index].ordering ==
LearningTask::Ordering::kUnordered) {
new_used_set.insert(best_potential_split.split_index);
}
for (auto& branch_iter : best_potential_split.branch_infos) {
node->AddChild(branch_iter.first,
Build(task, branch_iter.second.training_data, new_used_set));
}
return node;
}
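
// Construct a candidate split of |training_data| on the feature at |index|,
// and compute the expected entropy (in nats) that would remain after taking
// it.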
RandomTreeTrainer::Split RandomTreeTrainer::ConstructSplit(
const LearningTask& task,
const TrainingData& training_data,
int index) {
// We should not be given a training set of size 0, since there's no need to
// check an empty split.
DCHECK_GT(training_data.total_weight(), 0u);
Split split(index);
base::Optional<FeatureValue> split_point;
// For a numeric split, find the split point. Otherwise, we'll split on every
// nominal value that this feature has in |training_data|.
if (task.feature_descriptions[index].ordering ==
LearningTask::Ordering::kNumeric) {
split_point = FindNumericSplitPoint(split.split_index, training_data);
split.split_point = *split_point;
}
  // Find the split's feature values and construct the training set for each.
  // TODO(liberato): It might be faster to iterate over the underlying storage
  // vector and look up each example in the training data directly.
for (WeightedExample weighted_example : training_data) {
const TrainingExample* example = weighted_example.example();
// Get the value of the |index|-th feature for |example|.
FeatureValue v_i = example->features[split.split_index];
// Figure out what value this example would use for splitting. For nominal,
// it's just |v_i|. For numeric, it's whether |v_i| is <= the split point
// or not (0 for <=, 1 for >).
FeatureValue split_feature;
if (split_point)
split_feature = FeatureValue(v_i > *split_point);
else
split_feature = v_i;
    // Add |example| to the branch's training set.  Remember that emplace will
    // do nothing if the key already exists.
auto result = split.branch_infos.emplace(
split_feature, Split::BranchInfo(training_data.storage()));
auto iter = result.first;
Split::BranchInfo& branch_info = iter->second;
branch_info.training_data.push_back(example);
branch_info.target_distribution += weighted_example;
}
  // Compute the expected entropy (in nats) remaining after this split:
  //   sum over branches of P(branch) * H(branch).
split.nats_remaining = 0;
for (auto& info_iter : split.branch_infos) {
Split::BranchInfo& branch_info = info_iter.second;
const double total_counts = branch_info.target_distribution.total_counts();
// |p_branch| is the probability of following this branch.
    const double p_branch = total_counts / training_data.total_weight();
for (auto& iter : branch_info.target_distribution) {
double p = iter.second / total_counts;
      // -p*log(p) is |iter|'s contribution to this branch's entropy, in nats.
      // Weight it by the probability of taking this branch at all.
      split.nats_remaining -= (p * log(p)) * p_branch;
}
}
return split;
}
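
// Pick a split point for the numeric feature at |index|, drawn uniformly at
// random over the range of values that occur in |training_data|.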
FeatureValue RandomTreeTrainer::FindNumericSplitPoint(
size_t index,
const TrainingData& training_data) {
// We should not be given a training set of size 0, since there's no need to
// check an empty split.
DCHECK_GT(training_data.total_weight(), 0u);
  // We should either (a) choose the single best split point given all our
  // training data (i.e., consider the splits midway between adjacent feature
  // values), or (b) choose a split point by drawing uniformly over the range
  // that contains our feature values.  (a) is appropriate for RandomForest,
  // while (b) is appropriate for ExtraTrees.  We do (b).
FeatureValue v_min = (*training_data.begin()).example()->features[index];
FeatureValue v_max = (*training_data.begin()).example()->features[index];
for (WeightedExample weighted_example : training_data) {
const TrainingExample* example = weighted_example.example();
    // Get the value of the |index|-th feature for |example|.
FeatureValue v_i = example->features[index];
if (v_i < v_min)
v_min = v_i;
if (v_i > v_max)
v_max = v_i;
}
FeatureValue v_split;
if (v_max == v_min) {
// Pick |v_split| to return a trivial split, so that this ends up as a
// leaf node anyway.
v_split = v_max;
} else {
// Choose a random split point. Note that we want to end up with two
// buckets, so we don't have a trivial split. By picking [v_min, v_max),
// |v_min| will always be in one bucket and |v_max| will always not be.
    v_split = FeatureValue(rng()->Generate(v_max.value() - v_min.value()) +
                           v_min.value());
}
return v_split;
}

}  // namespace learning
}  // namespace media