blob: 99e8a39126952c04f66c11ba5b916db47b134505 [file] [log] [blame]
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <memory>
#include <utility>
#include "base/bind.h"
#include "base/macros.h"
#include "base/memory/ptr_util.h"
#include "base/test/scoped_task_environment.h"
#include "base/threading/thread.h"
#include "media/learning/mojo/mojo_learning_task_controller_service.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace media {
namespace learning {
class MojoLearningTaskControllerServiceTest : public ::testing::Test {
public:
class FakeLearningTaskController : public LearningTaskController {
public:
void BeginObservation(base::UnguessableToken id,
const FeatureVector& features) override {
begin_args_.id_ = id;
begin_args_.features_ = features;
}
void CompleteObservation(base::UnguessableToken id,
const ObservationCompletion& completion) override {
complete_args_.id_ = id;
complete_args_.completion_ = completion;
}
void CancelObservation(base::UnguessableToken id) override {
cancel_args_.id_ = id;
}
struct {
base::UnguessableToken id_;
FeatureVector features_;
} begin_args_;
struct {
base::UnguessableToken id_;
ObservationCompletion completion_;
} complete_args_;
struct {
base::UnguessableToken id_;
} cancel_args_;
};
public:
MojoLearningTaskControllerServiceTest() = default;
~MojoLearningTaskControllerServiceTest() override = default;
void SetUp() override {
std::unique_ptr<FakeLearningTaskController> controller =
std::make_unique<FakeLearningTaskController>();
controller_raw_ = controller.get();
// Add two features.
task_.feature_descriptions.push_back({});
task_.feature_descriptions.push_back({});
// Tell |learning_controller_| to forward to the fake learner impl.
service_ = std::make_unique<MojoLearningTaskControllerService>(
task_, std::move(controller));
}
LearningTask task_;
// Mojo stuff.
base::test::ScopedTaskEnvironment scoped_task_environment_;
FakeLearningTaskController* controller_raw_ = nullptr;
// The learner under test.
std::unique_ptr<MojoLearningTaskControllerService> service_;
};
// A well-formed Begin followed by Complete should reach the controller with
// the same id, features and target value.
TEST_F(MojoLearningTaskControllerServiceTest, BeginComplete) {
  const base::UnguessableToken observation_id =
      base::UnguessableToken::Create();
  FeatureVector two_features = {FeatureValue(123), FeatureValue(456)};
  ObservationCompletion completion(TargetValue(1234));
  const auto& fake = *controller_raw_;

  service_->BeginObservation(observation_id, two_features);
  EXPECT_EQ(fake.begin_args_.id_, observation_id);
  EXPECT_EQ(fake.begin_args_.features_, two_features);

  service_->CompleteObservation(observation_id, completion);
  EXPECT_EQ(fake.complete_args_.id_, observation_id);
  EXPECT_EQ(fake.complete_args_.completion_.target_value,
            completion.target_value);
}
// A well-formed Begin followed by Cancel should reach the controller with the
// same id for both calls.
TEST_F(MojoLearningTaskControllerServiceTest, BeginCancel) {
  const base::UnguessableToken observation_id =
      base::UnguessableToken::Create();
  FeatureVector two_features = {FeatureValue(123), FeatureValue(456)};
  const auto& fake = *controller_raw_;

  service_->BeginObservation(observation_id, two_features);
  EXPECT_EQ(fake.begin_args_.id_, observation_id);
  EXPECT_EQ(fake.begin_args_.features_, two_features);

  service_->CancelObservation(observation_id);
  EXPECT_EQ(fake.cancel_args_.id_, observation_id);
}
// BeginObservation with fewer features than the task describes must not be
// forwarded: the fake's begin record stays default-constructed.
TEST_F(MojoLearningTaskControllerServiceTest, TooFewFeaturesIsIgnored) {
  const base::UnguessableToken observation_id =
      base::UnguessableToken::Create();
  FeatureVector one_feature = {FeatureValue(123)};

  service_->BeginObservation(observation_id, one_feature);

  const auto& recorded = controller_raw_->begin_args_;
  EXPECT_NE(recorded.id_, observation_id);
  EXPECT_TRUE(recorded.features_.empty());
}
// BeginObservation with more features than the task describes must not be
// forwarded: the fake's begin record stays default-constructed.
TEST_F(MojoLearningTaskControllerServiceTest, TooManyFeaturesIsIgnored) {
  const base::UnguessableToken observation_id =
      base::UnguessableToken::Create();
  FeatureVector three_features = {FeatureValue(123), FeatureValue(456),
                                  FeatureValue(789)};

  service_->BeginObservation(observation_id, three_features);

  const auto& recorded = controller_raw_->begin_args_;
  EXPECT_NE(recorded.id_, observation_id);
  EXPECT_TRUE(recorded.features_.empty());
}
// Completing an id that was never begun must not reach the controller.
TEST_F(MojoLearningTaskControllerServiceTest, CompleteWithoutBeginFails) {
  ObservationCompletion completion(TargetValue(1234));
  const base::UnguessableToken unknown_id = base::UnguessableToken::Create();

  service_->CompleteObservation(unknown_id, completion);

  EXPECT_NE(controller_raw_->complete_args_.id_, unknown_id);
}
// Cancelling an id that was never begun must not reach the controller.
TEST_F(MojoLearningTaskControllerServiceTest, CancelWithoutBeginFails) {
  const base::UnguessableToken unknown_id = base::UnguessableToken::Create();

  service_->CancelObservation(unknown_id);

  EXPECT_NE(controller_raw_->cancel_args_.id_, unknown_id);
}
} // namespace learning
} // namespace media