blob: 176cb0003e8f43b406c78ee3eaeca76887e14556 [file] [log] [blame]
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/learning_session_impl.h"
#include <utility>
#include "base/bind.h"
#include "base/logging.h"
#include "media/learning/impl/distribution_reporter.h"
#include "media/learning/impl/learning_task_controller_impl.h"
namespace media {
namespace learning {
LearningSessionImpl::LearningSessionImpl()
: controller_factory_(
base::BindRepeating([](const LearningTask& task,
SequenceBoundFeatureProvider feature_provider)
-> std::unique_ptr<LearningTaskController> {
return std::make_unique<LearningTaskControllerImpl>(
task, DistributionReporter::Create(task),
std::move(feature_provider));
})) {}
// Defaulted: members such as |task_map_| and |controller_factory_| are torn
// down by their own destructors; no extra cleanup is needed here.
LearningSessionImpl::~LearningSessionImpl() = default;
// Test hook: replaces the factory that RegisterTask() uses to build task
// controllers. Only tasks registered after this call are affected; controllers
// already stored in |task_map_| are left as they are.
void LearningSessionImpl::SetTaskControllerFactoryCBForTesting(
CreateTaskControllerCB cb) {
controller_factory_ = std::move(cb);
}
// Feeds |example| to the controller registered under |task_name|. Examples
// for names that were never registered are ignored.
void LearningSessionImpl::AddExample(const std::string& task_name,
                                     const LabelledExample& example) {
  auto it = task_map_.find(task_name);
  if (it == task_map_.end())
    return;

  // TODO(liberato): We shouldn't be adding examples. We should provide the
  // LearningTaskController instead, although ownership gets a bit weird.
  auto& controller = it->second;
  // Begin an observation from the example's features, then immediately run
  // the returned callback with the example's target value and weight.
  controller->BeginObservation(example.features)
      .Run(example.target_value, example.weight);
}
void LearningSessionImpl::RegisterTask(
const LearningTask& task,
SequenceBoundFeatureProvider feature_provider) {
DCHECK(task_map_.count(task.name) == 0);
task_map_.emplace(task.name,
controller_factory_.Run(task, std::move(feature_provider)));
}
} // namespace learning
} // namespace media