blob: b1c6cd4b7c28772e8530c7dc2548f9428259ec80 [file] [log] [blame]
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/database/segment_info_database.h"
#include "base/memory/raw_ptr.h"
#include "base/strings/string_number_conversions.h"
#include "base/test/task_environment.h"
#include "components/leveldb_proto/public/proto_database.h"
#include "components/leveldb_proto/testing/fake_db.h"
#include "components/segmentation_platform/public/model_provider.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
using InitStatus = leveldb_proto::Enums::InitStatus;
namespace segmentation_platform {
namespace {
// Segment ids used by the tests below.
constexpr SegmentId kSegmentId =
    SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
constexpr SegmentId kSegmentId2 =
    SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHARE;
// Model source attached to every segment created by the tests below.
constexpr ModelSource kServerModelSource = ModelSource::SERVER_MODEL_SOURCE;
// Returns the key under which `segment_id` is stored in the fake database:
// the decimal form of the numeric segment id. `model_source` does not take
// part in the key at the moment.
// NOTE(review): presumably kept in the signature for a future per-source key
// format — confirm against SegmentInfoDatabase.
std::string ToString(SegmentId segment_id, ModelSource model_source) {
  const int numeric_id = static_cast<int>(segment_id);
  return base::NumberToString(numeric_id);
}
// Builds a SegmentInfo for `segment_id` with the given `model_source`. When
// `result` is set, it is appended as the only prediction result score.
// `result` is a float because prediction result scores are floats; an int
// parameter would silently truncate fractional scores (int arguments are
// still accepted and converted).
proto::SegmentInfo CreateSegment(SegmentId segment_id,
                                 ModelSource model_source,
                                 absl::optional<float> result = absl::nullopt) {
  proto::SegmentInfo info;
  info.set_segment_id(segment_id);
  info.set_model_source(model_source);
  if (result.has_value()) {
    info.mutable_prediction_result()->add_result(result.value());
  }
  return info;
}
} // namespace
class SegmentInfoDatabaseTest : public testing::Test {
public:
SegmentInfoDatabaseTest() = default;
~SegmentInfoDatabaseTest() override = default;
void OnGetAllSegments(
base::RepeatingClosure closure,
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> entries) {
get_all_segment_result_.swap(entries);
std::move(closure).Run();
}
void OnGetSegment(absl::optional<proto::SegmentInfo> result) {
get_segment_result_ = result;
}
protected:
void SetUpDB() {
DCHECK(!db_);
DCHECK(!segment_db_);
auto db = std::make_unique<leveldb_proto::test::FakeDB<proto::SegmentInfo>>(
&db_entries_);
db_ = db.get();
auto segment_info_cache = std::make_unique<SegmentInfoCache>();
segment_info_cache_ = segment_info_cache.get();
segment_db_ = std::make_unique<SegmentInfoDatabase>(
std::move(db), std::move(segment_info_cache));
}
void TearDown() override {
db_entries_.clear();
db_ = nullptr;
segment_db_.reset();
}
void VerifyDb(base::flat_set<std::pair<SegmentId, ModelSource>>
expected_ids_with_model_source) {
EXPECT_EQ(expected_ids_with_model_source.size(), db_entries_.size());
for (auto segment_id_and_model_source : expected_ids_with_model_source) {
EXPECT_TRUE(
db_entries_.find(ToString(segment_id_and_model_source.first,
segment_id_and_model_source.second)) !=
db_entries_.end());
}
}
void WriteResult(SegmentId segment_id,
ModelSource model_source,
absl::optional<float> result) {
proto::PredictionResult prediction_result;
if (result.has_value())
prediction_result.add_result(result.value());
segment_db_->SaveSegmentResult(segment_id,
result.has_value()
? absl::make_optional(prediction_result)
: absl::nullopt,
base::DoNothing());
if (!segment_info_cache_->GetSegmentInfo(segment_id).has_value()) {
db_->GetCallback(true);
}
db_->UpdateCallback(true);
}
void WriteTrainingData(SegmentId segment_id,
ModelSource model_source,
int64_t request_id,
float data) {
proto::TrainingData training_data;
training_data.add_inputs(data);
training_data.set_request_id(request_id);
segment_db_->SaveTrainingData(segment_id, training_data, base::DoNothing());
if (!segment_info_cache_->GetSegmentInfo(segment_id).has_value()) {
db_->GetCallback(true);
}
db_->UpdateCallback(true);
}
void VerifyResult(SegmentId segment_id,
ModelSource model_source,
absl::optional<float> result,
absl::optional<std::vector<ModelProvider::Request>>
training_inputs = absl::nullopt) {
segment_db_->GetSegmentInfo(
segment_id, model_source,
base::BindOnce(&SegmentInfoDatabaseTest::OnGetSegment,
base::Unretained(this)));
if (!segment_info_cache_->GetSegmentInfo(segment_id).has_value()) {
db_->GetCallback(true);
}
EXPECT_EQ(segment_id, get_segment_result_->segment_id());
EXPECT_EQ(result.has_value(), get_segment_result_->has_prediction_result());
if (result.has_value()) {
EXPECT_THAT(get_segment_result_->prediction_result().result(),
testing::ElementsAre(result.value()));
}
if (training_inputs.has_value()) {
for (int i = 0; i < get_segment_result_->training_data_size(); i++) {
for (int j = 0; j < get_segment_result_->training_data(i).inputs_size();
j++) {
EXPECT_EQ(training_inputs.value()[i][j],
get_segment_result_->training_data(i).inputs(j));
}
}
}
}
void ExecuteAndVerifyGetSegmentInfoForSegments(
const base::flat_set<SegmentId>& segment_ids) {
base::RunLoop loop;
segment_db_->GetSegmentInfoForSegments(
segment_ids,
base::BindOnce(&SegmentInfoDatabaseTest::OnGetAllSegments,
base::Unretained(this), loop.QuitClosure()));
for (SegmentId segment_id : segment_ids) {
if (!segment_info_cache_->GetSegmentInfo(segment_id).has_value()) {
db_->LoadCallback(true);
break;
}
}
loop.Run();
EXPECT_EQ(segment_ids.size(), get_all_segment_result().size());
int index = 0;
for (SegmentId segment_id : segment_ids) {
EXPECT_EQ(segment_id, get_all_segment_result()[index++].first);
}
}
const SegmentInfoDatabase::SegmentInfoList& get_all_segment_result() const {
return *get_all_segment_result_;
}
base::test::TaskEnvironment task_environment_;
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> get_all_segment_result_;
absl::optional<proto::SegmentInfo> get_segment_result_;
std::map<std::string, proto::SegmentInfo> db_entries_;
raw_ptr<leveldb_proto::test::FakeDB<proto::SegmentInfo>> db_{nullptr};
std::unique_ptr<SegmentInfoDatabase> segment_db_;
raw_ptr<SegmentInfoCache, DanglingUntriaged> segment_info_cache_;
};
TEST_F(SegmentInfoDatabaseTest, Get) {
  // Initialize DB with one entry.
  db_entries_.insert(
      std::make_pair(ToString(kSegmentId, kServerModelSource),
                     CreateSegment(kSegmentId, kServerModelSource)));
  SetUpDB();
  segment_db_->Initialize(base::DoNothing());
  db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
  db_->LoadCallback(true);
  VerifyDb({std::make_pair(kSegmentId, kServerModelSource)});

  // Get all segments.
  ExecuteAndVerifyGetSegmentInfoForSegments({kSegmentId});

  // Get a single segment.
  segment_db_->GetSegmentInfo(
      kSegmentId, kServerModelSource,
      base::BindOnce(&SegmentInfoDatabaseTest::OnGetSegment,
                     base::Unretained(this)));
  if (!segment_info_cache_->GetSegmentInfo(kSegmentId).has_value()) {
    db_->GetCallback(true);
  }
  // Fatal: the next line dereferences the optional.
  ASSERT_TRUE(get_segment_result_.has_value());
  EXPECT_EQ(kSegmentId, get_segment_result_->segment_id());
}
TEST_F(SegmentInfoDatabaseTest, Update) {
  // Start from a store that already contains the first segment.
  db_entries_.emplace(ToString(kSegmentId, kServerModelSource),
                      CreateSegment(kSegmentId, kServerModelSource));
  SetUpDB();
  segment_db_->Initialize(base::DoNothing());
  db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
  db_->LoadCallback(true);

  // Updating with no SegmentInfo removes the stored segment.
  segment_db_->UpdateSegment(kSegmentId, absl::nullopt, base::DoNothing());
  db_->UpdateCallback(true);
  VerifyDb({});

  // Writing the first segment again brings it back.
  proto::SegmentInfo first = CreateSegment(kSegmentId, kServerModelSource);
  segment_db_->UpdateSegment(kSegmentId, first, base::DoNothing());
  db_->UpdateCallback(true);
  VerifyDb({{kSegmentId, kServerModelSource}});

  // Writing a second segment leaves both in the store.
  proto::SegmentInfo second = CreateSegment(kSegmentId2, kServerModelSource);
  segment_db_->UpdateSegment(kSegmentId2, second, base::DoNothing());
  db_->UpdateCallback(true);
  VerifyDb({std::make_pair(kSegmentId, kServerModelSource),
            std::make_pair(kSegmentId2, kServerModelSource)});

  // Both segments are retrievable individually and together.
  ExecuteAndVerifyGetSegmentInfoForSegments({kSegmentId2});
  ExecuteAndVerifyGetSegmentInfoForSegments({kSegmentId});
  ExecuteAndVerifyGetSegmentInfoForSegments({kSegmentId, kSegmentId2});
}
TEST_F(SegmentInfoDatabaseTest, UpdateMultipleSegments) {
  // Initialize DB with two entries.
  db_entries_.insert(
      std::make_pair(ToString(kSegmentId, kServerModelSource),
                     CreateSegment(kSegmentId, kServerModelSource)));
  db_entries_.insert(
      std::make_pair(ToString(kSegmentId2, kServerModelSource),
                     CreateSegment(kSegmentId2, kServerModelSource)));
  SetUpDB();
  segment_db_->Initialize(base::DoNothing());
  db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
  db_->LoadCallback(true);

  // Delete both segments.
  segment_db_->UpdateMultipleSegments({}, {kSegmentId, kSegmentId2},
                                      base::DoNothing());
  db_->UpdateCallback(true);
  VerifyDb({});

  // Insert multiple segments and verify.
  std::vector<std::pair<SegmentId, proto::SegmentInfo>> segments_to_update;
  segments_to_update.emplace_back(
      kSegmentId, CreateSegment(kSegmentId, kServerModelSource));
  segments_to_update.emplace_back(
      kSegmentId2, CreateSegment(kSegmentId2, kServerModelSource));
  segment_db_->UpdateMultipleSegments(segments_to_update, {},
                                      base::DoNothing());
  db_->UpdateCallback(true);
  VerifyDb({std::make_pair(kSegmentId, kServerModelSource),
            std::make_pair(kSegmentId2, kServerModelSource)});

  // Update one of the existing segments and verify.
  proto::SegmentInfo segment_info =
      CreateSegment(kSegmentId2, kServerModelSource);
  segment_info.mutable_prediction_result()->add_result(0.9f);
  // Make this entry the only element of `segments_to_update`; the pair is
  // constructed in place and `segment_info` is not used afterwards.
  segments_to_update.clear();
  segments_to_update.emplace_back(kSegmentId2, std::move(segment_info));

  // Call and Verify.
  segment_db_->UpdateMultipleSegments(segments_to_update, {},
                                      base::DoNothing());
  db_->UpdateCallback(true);
  VerifyDb({std::make_pair(kSegmentId, kServerModelSource),
            std::make_pair(kSegmentId2, kServerModelSource)});
  VerifyResult(kSegmentId2, kServerModelSource, 0.9f);

  // Verify GetSegmentInfoForSegments.
  ExecuteAndVerifyGetSegmentInfoForSegments({kSegmentId2});
  ExecuteAndVerifyGetSegmentInfoForSegments({kSegmentId});
  ExecuteAndVerifyGetSegmentInfoForSegments({kSegmentId, kSegmentId2});
}
TEST_F(SegmentInfoDatabaseTest, WriteResult) {
  // Seed the backing store with a single segment before the database exists.
  db_entries_.emplace(ToString(kSegmentId, kServerModelSource),
                      CreateSegment(kSegmentId, kServerModelSource));
  VerifyDb({{kSegmentId, kServerModelSource}});

  SetUpDB();
  segment_db_->Initialize(base::DoNothing());
  db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
  // Until the initial load completes the cache knows nothing about the
  // segment...
  EXPECT_FALSE(segment_info_cache_->GetSegmentInfo(kSegmentId).has_value());
  db_->LoadCallback(true);
  // ...and once it completes, the stored entry is available from the cache.
  EXPECT_TRUE(segment_info_cache_->GetSegmentInfo(kSegmentId).has_value());

  // Write a result, overwrite it, then clear it; after each write the value
  // read back through GetSegmentInfo() must match what was written.
  for (const absl::optional<float>& score :
       {absl::optional<float>(0.4f), absl::optional<float>(0.9f),
        absl::optional<float>()}) {
    WriteResult(kSegmentId, kServerModelSource, score);
    VerifyResult(kSegmentId, kServerModelSource, score);
  }
}
TEST_F(SegmentInfoDatabaseTest, WriteTrainingData) {
  // Seed the backing store with a single segment and load it.
  db_entries_.emplace(ToString(kSegmentId, kServerModelSource),
                      CreateSegment(kSegmentId, kServerModelSource));
  SetUpDB();
  segment_db_->Initialize(base::DoNothing());
  db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
  db_->LoadCallback(true);
  EXPECT_TRUE(segment_info_cache_->GetSegmentInfo(kSegmentId).has_value());

  // Training inputs the segment is expected to hold, in insertion order.
  std::vector<ModelProvider::Request> expected_inputs;

  // First sample, request id 0.
  WriteTrainingData(kSegmentId, kServerModelSource, /*request_id=*/0,
                    /*data=*/0.4f);
  expected_inputs.push_back({0.4f});
  VerifyResult(kSegmentId, kServerModelSource, absl::nullopt, expected_inputs);

  // Second sample; its id is kept so it can be removed below.
  const int64_t second_request_id = 1;
  WriteTrainingData(kSegmentId, kServerModelSource, second_request_id,
                    /*data=*/0.9f);
  expected_inputs.push_back({0.9f});
  VerifyResult(kSegmentId, kServerModelSource, absl::nullopt, expected_inputs);

  // Fetch the second sample with deletion requested; expect only the first
  // sample to remain.
  segment_db_->GetTrainingData(
      kSegmentId, kServerModelSource,
      TrainingRequestId::FromUnsafeValue(second_request_id),
      /*delete_from_db=*/true, base::DoNothing());
  expected_inputs.pop_back();
  VerifyResult(kSegmentId, kServerModelSource, absl::nullopt, expected_inputs);
}
TEST_F(SegmentInfoDatabaseTest, WriteResultForTwoSegments) {
  // Seed the backing store with both segments and load them.
  for (SegmentId segment_id : {kSegmentId, kSegmentId2}) {
    db_entries_.emplace(ToString(segment_id, kServerModelSource),
                        CreateSegment(segment_id, kServerModelSource));
  }
  SetUpDB();
  segment_db_->Initialize(base::DoNothing());
  db_->InitStatusCallback(leveldb_proto::Enums::InitStatus::kOK);
  db_->LoadCallback(true);

  // Each segment gets its own score; writing one must not affect the other.
  const std::pair<SegmentId, float> writes[] = {{kSegmentId, 0.4f},
                                                {kSegmentId2, 0.9f}};
  for (const auto& [segment_id, score] : writes) {
    WriteResult(segment_id, kServerModelSource, score);
  }
  for (const auto& [segment_id, score] : writes) {
    VerifyResult(segment_id, kServerModelSource, score);
  }
}
} // namespace segmentation_platform