// blob: b820c6a64d835de0bd7968e06e902184715cd145 [file] [log] [blame]
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/extra_trees_trainer.h"
#include <set>
#include "base/bind.h"
#include "base/logging.h"
#include "media/learning/impl/voting_ensemble.h"
namespace media {
namespace learning {
// Constructor: defined out of line and defaulted, so every member gets its
// compiler-generated default initialization.
ExtraTreesTrainer::ExtraTreesTrainer() = default;
// Destructor: defined out of line and defaulted, so members are torn down by
// their own destructors.
ExtraTreesTrainer::~ExtraTreesTrainer() = default;
// Begins training an ensemble of |task.rf_number_of_trees| random trees on
// |training_data|.  |model_cb| is forwarded to OnRandomTreeModel(), which runs
// it with the finished model once all of the trees have been collected.
//
// Must not be called while a previous training run is still in progress.
void ExtraTreesTrainer::Train(const LearningTask& task,
                              const TrainingData& training_data,
                              TrainedModelCB model_cb) {
  // Leftover trees or a live converter would mean that an earlier Train() has
  // not completed yet.
  DCHECK_EQ(trees_.size(), 0u);
  DCHECK_EQ(converter_.get(), nullptr);

  task_ = task;
  trees_.reserve(task.rf_number_of_trees);

  // Create the per-tree trainer lazily.  It is deferred until now only so that
  // it can be handed our rng, which matters mostly for tests.
  // TODO(liberato): We should always take the rng in the ctor, rather than
  // via SetRngForTesting.  Then we can do this earlier.
  if (!tree_trainer_)
    tree_trainer_ = std::make_unique<RandomTreeTrainer>(rng());

  // RandomTree has been modified to handle nominals, so one-hot conversion is
  // not normally needed, and it is slow.  However, the RandomTree changes are
  // only approximately equivalent, so the conversion is still available when
  // the task requests it.
  if (!task_.use_one_hot_conversion) {
    // No conversion requested: train on a copy of the data as provided.
    converted_training_data_ = training_data;
  } else {
    // Convert the data, and switch to the task description that matches the
    // converted examples.
    converter_ = std::make_unique<OneHotConverter>(task, training_data);
    converted_training_data_ = converter_->Convert(training_data);
    task_ = converter_->converted_task();
  }

  // Kick off the tree-by-tree loop.  A null model means "nothing to record
  // yet, just start".
  OnRandomTreeModel(std::move(model_cb), nullptr);
}
// Records |model| (if any) as a finished tree, then either requests the next
// tree from |tree_trainer_| or, once |task_.rf_number_of_trees| trees have
// been collected, builds the final model and runs |model_cb| with it.
//
// |model| may be null; Train() passes null to start the process without
// recording a tree.
void ExtraTreesTrainer::OnRandomTreeModel(TrainedModelCB model_cb,
                                          std::unique_ptr<Model> model) {
  // A null model is the "start training" signal, so there is nothing to keep.
  if (model)
    trees_.push_back(std::move(model));

  // Still short of the requested tree count: train another tree and come back
  // here when it is ready.  The callback is bound to a weak pointer to |this|.
  if (trees_.size() != task_.rf_number_of_trees) {
    auto next_tree_cb =
        base::BindOnce(&ExtraTreesTrainer::OnRandomTreeModel, AsWeakPtr(),
                       std::move(model_cb));
    tree_trainer_->Train(task_, converted_training_data_,
                         std::move(next_tree_cb));
    return;
  }

  // That was the last tree.  Combine all of them into a voting ensemble.
  std::unique_ptr<Model> ensemble =
      std::make_unique<VotingEnsemble>(std::move(trees_));

  // If the training data went through a converter, wrap the ensemble in a
  // ConvertingModel that takes ownership of that converter.
  if (converter_) {
    ensemble = std::make_unique<ConvertingModel>(std::move(converter_),
                                                 std::move(ensemble));
  }

  std::move(model_cb).Run(std::move(ensemble));
}
} // namespace learning
} // namespace media