blob: 81e7fd233d14ef00dfa65b95832d1f221cb388c9 [file] [log] [blame]
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/database/segment_info_database.h"
#include "base/callback_helpers.h"
#include "base/containers/contains.h"
#include "base/strings/string_number_conversions.h"
#include "base/threading/thread_task_runner_handle.h"
namespace segmentation_platform {
namespace {
// Returns the decimal string form of |segment_id|. This string is what the
// proto database uses as the key for a segment.
std::string ToString(SegmentId segment_id) {
  const int numeric_id = static_cast<int>(segment_id);
  return base::NumberToString(numeric_id);
}
std::vector<std::string> SegmentIdsToString(
base::flat_set<SegmentId> segment_ids) {
std::vector<std::string> result;
for (SegmentId segment_id : segment_ids) {
result.emplace_back(ToString(segment_id));
}
return result;
}
} // namespace
// Takes ownership of |database|, the proto store that persists SegmentInfo
// entries, and of |cache|, which lookups consult before reading |database|.
SegmentInfoDatabase::SegmentInfoDatabase(
    std::unique_ptr<SegmentInfoProtoDb> database,
    std::unique_ptr<SegmentInfoCache> cache)
    : database_(std::move(database)),
      cache_(std::move(cache)) {}

SegmentInfoDatabase::~SegmentInfoDatabase() = default;
// Initializes the underlying proto database with simple options.
// OnDatabaseInitialized() then runs |callback| with true only when the init
// status is kOK. The completion is bound to a WeakPtr, so it is dropped if
// |this| is destroyed before the database finishes initializing.
void SegmentInfoDatabase::Initialize(SuccessCallback callback) {
  auto on_initialized =
      base::BindOnce(&SegmentInfoDatabase::OnDatabaseInitialized,
                     weak_ptr_factory_.GetWeakPtr(), std::move(callback));
  database_->Init(leveldb_proto::CreateSimpleOptions(),
                  std::move(on_initialized));
}
void SegmentInfoDatabase::OnMultipleSegmentInfoLoaded(
std::unique_ptr<SegmentInfoList> segments_so_far,
MultipleSegmentInfoCallback callback,
bool success,
std::unique_ptr<std::vector<proto::SegmentInfo>> all_infos) {
if (success && all_infos) {
for (auto& info : *all_infos.get()) {
cache_->UpdateSegmentInfo(info.segment_id(), info);
segments_so_far->emplace_back(
std::make_pair(info.segment_id(), std::move(info)));
}
}
std::move(callback).Run(std::move(segments_so_far));
}
// Looks up the SegmentInfo for every id in |segment_ids|.
// - Ids the cache can answer are served directly.
// - If the cache answers every id, |callback| is posted to the current thread's
//   task runner rather than run synchronously.
// - Otherwise the remaining ids are read from the database with a key filter,
//   and OnMultipleSegmentInfoLoaded() merges them with the cached entries.
void SegmentInfoDatabase::GetSegmentInfoForSegments(
    const base::flat_set<SegmentId>& segment_ids,
    MultipleSegmentInfoCallback callback) {
  base::flat_set<SegmentId> ids_needing_update;
  auto segments_so_far =
      cache_->GetSegmentInfoForSegments(segment_ids, ids_needing_update);

  if (!ids_needing_update.empty()) {
    // The database is keyed by the string form of each segment id.
    std::vector<std::string> db_keys = SegmentIdsToString(ids_needing_update);
    auto key_filter = base::BindRepeating(
        [](const std::vector<std::string>& wanted_keys,
           const std::string& key) { return base::Contains(wanted_keys, key); },
        std::move(db_keys));
    database_->LoadEntriesWithFilter(
        std::move(key_filter),
        base::BindOnce(&SegmentInfoDatabase::OnMultipleSegmentInfoLoaded,
                       weak_ptr_factory_.GetWeakPtr(),
                       std::move(segments_so_far), std::move(callback)));
    return;
  }

  base::ThreadTaskRunnerHandle::Get()->PostTask(
      FROM_HERE,
      base::BindOnce(std::move(callback), std::move(segments_so_far)));
}
// Fetches the SegmentInfo for |segment_id|.
// - If the cache reports any state other than kNotCached, |callback| runs
//   immediately with the cached value.
// - Otherwise the entry is read from the database, and OnGetSegmentInfo()
//   completes the request.
// NOTE(review): unlike GetSegmentInfoForSegments(), the cache-hit path runs
// |callback| synchronously instead of posting it. Callers must tolerate
// re-entrancy here.
void SegmentInfoDatabase::GetSegmentInfo(SegmentId segment_id,
                                         SegmentInfoCallback callback) {
  std::pair<SegmentInfoCache::CachedItemState, absl::optional<SegmentInfo>>
      cached = cache_->GetSegmentInfo(segment_id);

  if (cached.first == SegmentInfoCache::CachedItemState::kNotCached) {
    database_->GetEntry(
        ToString(segment_id),
        base::BindOnce(&SegmentInfoDatabase::OnGetSegmentInfo,
                       weak_ptr_factory_.GetWeakPtr(), segment_id,
                       std::move(callback)));
    return;
  }

  std::move(callback).Run(std::move(cached.second));
}
// Completes GetSegmentInfo() after a database read.
// The value read, or nullopt when the read failed or returned no entry, is
// stored in the cache and then passed to |callback|. The optional is built
// once and moved into the callback, rather than copying |*info| twice.
// NOTE(review): a failed read (|success| == false) is cached exactly like a
// missing entry. If the cache treats nullopt as a cached state, later lookups
// will not retry the database. Confirm this is intended.
void SegmentInfoDatabase::OnGetSegmentInfo(
    SegmentId segment_id,
    SegmentInfoCallback callback,
    bool success,
    std::unique_ptr<proto::SegmentInfo> info) {
  absl::optional<proto::SegmentInfo> result;
  if (success && info)
    result = std::move(*info);
  cache_->UpdateSegmentInfo(segment_id, result);
  std::move(callback).Run(std::move(result));
}
// Stores |segment_info| for |segment_id| in the cache and the database.
// - When |segment_info| is nullopt, the cache is updated with nullopt and the
//   database entry is deleted.
// - |callback| is forwarded to the database's UpdateEntries() call.
void SegmentInfoDatabase::UpdateSegment(
    SegmentId segment_id,
    absl::optional<proto::SegmentInfo> segment_info,
    SuccessCallback callback) {
  // Update the cache first, while |segment_info| still holds the full value.
  // The proto can then be moved into the database write below instead of
  // being copied.
  cache_->UpdateSegmentInfo(segment_id, segment_info);

  auto entries_to_save = std::make_unique<
      std::vector<std::pair<std::string, proto::SegmentInfo>>>();
  auto keys_to_delete = std::make_unique<std::vector<std::string>>();
  std::string key = ToString(segment_id);
  if (segment_info.has_value())
    entries_to_save->emplace_back(std::move(key), std::move(*segment_info));
  else
    keys_to_delete->push_back(std::move(key));

  database_->UpdateEntries(std::move(entries_to_save),
                           std::move(keys_to_delete), std::move(callback));
}
// Replaces the stored prediction result of |segment_id|, or clears it when
// |result| is nullopt. The current SegmentInfo is fetched first, and
// OnGetSegmentInfoForUpdatingResults() applies the change and reports through
// |callback|.
void SegmentInfoDatabase::SaveSegmentResult(
    SegmentId segment_id,
    absl::optional<proto::PredictionResult> result,
    SuccessCallback callback) {
  // |result| is owned by value here, so move it into the bound state instead
  // of copying the proto.
  GetSegmentInfo(
      segment_id,
      base::BindOnce(&SegmentInfoDatabase::OnGetSegmentInfoForUpdatingResults,
                     weak_ptr_factory_.GetWeakPtr(), std::move(result),
                     std::move(callback)));
}
// Applies a SaveSegmentResult() request to the freshly fetched
// |segment_info|.
// - If no SegmentInfo exists, |callback| receives false and nothing is
//   written.
// - Otherwise the prediction result is replaced, or cleared when |result| is
//   nullopt. The updated info is then written to the cache and the database,
//   and |callback| is forwarded to UpdateEntries().
void SegmentInfoDatabase::OnGetSegmentInfoForUpdatingResults(
    absl::optional<proto::PredictionResult> result,
    SuccessCallback callback,
    absl::optional<proto::SegmentInfo> segment_info) {
  // Without metadata there is nothing to attach the result to, so drop it.
  if (!segment_info) {
    std::move(callback).Run(false);
    return;
  }

  proto::SegmentInfo& info = *segment_info;
  if (result)
    info.mutable_prediction_result()->CopyFrom(*result);
  else
    info.clear_prediction_result();

  cache_->UpdateSegmentInfo(info.segment_id(), segment_info);

  std::string key = ToString(info.segment_id());
  auto entries_to_save = std::make_unique<
      std::vector<std::pair<std::string, proto::SegmentInfo>>>();
  entries_to_save->emplace_back(std::move(key), std::move(info));
  database_->UpdateEntries(std::move(entries_to_save),
                           std::make_unique<std::vector<std::string>>(),
                           std::move(callback));
}
// Completes Initialize(): |callback| receives true only if the proto database
// reported an init status of kOK.
void SegmentInfoDatabase::OnDatabaseInitialized(
    SuccessCallback callback,
    leveldb_proto::Enums::InitStatus status) {
  const bool initialized = status == leveldb_proto::Enums::InitStatus::kOK;
  std::move(callback).Run(initialized);
}
} // namespace segmentation_platform