blob: 39d7e30fe6d3de644f76567a579989a461d9aaf7 [file] [log] [blame]
// Copyright 2018 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "media/learning/impl/learning_session_impl.h"
#include <set>
#include <utility>
#include "base/bind.h"
#include "base/logging.h"
#include "media/learning/impl/distribution_reporter.h"
#include "media/learning/impl/learning_task_controller_impl.h"
namespace media {
namespace learning {
// Allow multiple clients to own an LTC that points to the same underlying LTC.
// Since we don't own the LTC, we also keep track of in-flight observations and
// explicitly cancel them on destruction, since dropping an LTC implies that.
class WeakLearningTaskController : public LearningTaskController {
public:
WeakLearningTaskController(
base::WeakPtr<LearningSessionImpl> weak_session,
base::SequenceBound<LearningTaskController>* controller)
: weak_session_(std::move(weak_session)), controller_(controller) {}
~WeakLearningTaskController() override {
if (!weak_session_)
return;
// Cancel any outstanding observations.
for (auto& id : outstanding_ids_) {
controller_->Post(FROM_HERE, &LearningTaskController::CancelObservation,
id);
}
}
void BeginObservation(base::UnguessableToken id,
const FeatureVector& features) override {
if (!weak_session_)
return;
outstanding_ids_.insert(id);
controller_->Post(FROM_HERE, &LearningTaskController::BeginObservation, id,
features);
}
void CompleteObservation(base::UnguessableToken id,
const ObservationCompletion& completion) override {
if (!weak_session_)
return;
outstanding_ids_.erase(id);
controller_->Post(FROM_HERE, &LearningTaskController::CompleteObservation,
id, completion);
}
void CancelObservation(base::UnguessableToken id) override {
if (!weak_session_)
return;
outstanding_ids_.erase(id);
controller_->Post(FROM_HERE, &LearningTaskController::CancelObservation,
id);
}
base::WeakPtr<LearningSessionImpl> weak_session_;
base::SequenceBound<LearningTaskController>* controller_;
// Set of ids that have been started but not completed / cancelled yet.
std::set<base::UnguessableToken> outstanding_ids_;
};
LearningSessionImpl::LearningSessionImpl(
scoped_refptr<base::SequencedTaskRunner> task_runner)
: task_runner_(std::move(task_runner)),
controller_factory_(base::BindRepeating(
[](scoped_refptr<base::SequencedTaskRunner> task_runner,
const LearningTask& task,
SequenceBoundFeatureProvider feature_provider)
-> base::SequenceBound<LearningTaskController> {
return base::SequenceBound<LearningTaskControllerImpl>(
task_runner, task, DistributionReporter::Create(task),
std::move(feature_provider));
})),
weak_factory_(this) {}
LearningSessionImpl::~LearningSessionImpl() = default;
void LearningSessionImpl::SetTaskControllerFactoryCBForTesting(
CreateTaskControllerCB cb) {
controller_factory_ = std::move(cb);
}
std::unique_ptr<LearningTaskController> LearningSessionImpl::GetController(
const std::string& task_name) {
auto iter = task_map_.find(task_name);
if (iter == task_map_.end())
return nullptr;
// If there were any way to replace / destroy a controller other than when we
// destroy |this|, then this wouldn't be such a good idea.
return std::make_unique<WeakLearningTaskController>(
weak_factory_.GetWeakPtr(), &iter->second);
}
void LearningSessionImpl::RegisterTask(
const LearningTask& task,
SequenceBoundFeatureProvider feature_provider) {
DCHECK(task_map_.count(task.name) == 0);
task_map_.emplace(
task.name,
controller_factory_.Run(task_runner_, task, std::move(feature_provider)));
}
} // namespace learning
} // namespace media