blob: 51ecde2a5874f2be7002fc649beb2e2a299f384b [file] [log] [blame]
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/database/test_segment_info_database.h"
#include "base/containers/contains.h"
#include "base/metrics/metrics_hashes.h"
#include "base/ranges/algorithm.h"
#include "components/segmentation_platform/internal/constants.h"
#include "components/segmentation_platform/internal/metadata/metadata_writer.h"
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#include "components/segmentation_platform/public/proto/types.pb.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
namespace segmentation_platform::test {
// The base class receives null for both of its constructor arguments. The
// methods of this fake keep all state in the in-memory |segment_infos_| list.
TestSegmentInfoDatabase::TestSegmentInfoDatabase()
    : SegmentInfoDatabase(nullptr, nullptr) {}

TestSegmentInfoDatabase::~TestSegmentInfoDatabase() = default;
// Reports successful initialization synchronously. All state of this fake
// lives in |segment_infos_|, so there is nothing to load first.
void TestSegmentInfoDatabase::Initialize(SuccessCallback callback) {
  constexpr bool kInitialized = true;
  std::move(callback).Run(kInitialized);
}
// Collects a copy of every stored (id, info) entry whose id appears in
// |segment_ids|, preserving storage order, and passes the list to |callback|
// synchronously. Ids with no stored entry are simply absent from the result.
void TestSegmentInfoDatabase::GetSegmentInfoForSegments(
    const base::flat_set<SegmentId>& segment_ids,
    MultipleSegmentInfoCallback callback) {
  auto matches = std::make_unique<SegmentInfoDatabase::SegmentInfoList>();
  for (const auto& entry : segment_infos_) {
    if (!base::Contains(segment_ids, entry.first))
      continue;
    matches->emplace_back(entry);
  }
  std::move(callback).Run(std::move(matches));
}
// Asynchronous-style wrapper around GetCachedSegmentInfo(). The lookup result
// (absl::nullopt when the segment is unknown) is delivered to |callback|
// before this method returns.
void TestSegmentInfoDatabase::GetSegmentInfo(SegmentId segment_id,
                                             ModelSource model_source,
                                             SegmentInfoCallback callback) {
  auto segment_info = GetCachedSegmentInfo(segment_id, model_source);
  std::move(callback).Run(std::move(segment_info));
}
// Returns a copy of the first stored SegmentInfo whose id equals |segment_id|,
// or absl::nullopt when there is none.
// NOTE(review): |model_source| is not consulted; this fake keys entries by
// segment id only.
absl::optional<SegmentInfo> TestSegmentInfoDatabase::GetCachedSegmentInfo(
    SegmentId segment_id,
    ModelSource model_source) {
  for (const auto& [stored_id, stored_info] : segment_infos_) {
    if (stored_id == segment_id)
      return stored_info;
  }
  return absl::nullopt;
}
// Writes or deletes the entry for |segment_id|, then reports success.
// - With a value: the entry is created if needed and overwritten with a copy
//   of |segment_info|.
// - Without a value: every entry with id |segment_id| is removed.
void TestSegmentInfoDatabase::UpdateSegment(
    SegmentId segment_id,
    absl::optional<proto::SegmentInfo> segment_info,
    SuccessCallback callback) {
  if (!segment_info.has_value()) {
    auto first_removed =
        base::ranges::remove(segment_infos_, segment_id,
                             &std::pair<SegmentId, proto::SegmentInfo>::first);
    segment_infos_.erase(first_removed, segment_infos_.end());
  } else {
    FindOrCreateSegment(segment_id)->CopyFrom(*segment_info);
  }
  std::move(callback).Run(true);
}
// Stores |result| as the prediction result of |segment_id|, creating the
// segment entry if needed. A missing |result| clears any existing prediction
// result. The by-value |result| is swapped in rather than copied. Always
// reports success.
void TestSegmentInfoDatabase::SaveSegmentResult(
    SegmentId segment_id,
    absl::optional<proto::PredictionResult> result,
    SuccessCallback callback) {
  proto::SegmentInfo* segment = FindOrCreateSegment(segment_id);
  if (result.has_value()) {
    segment->mutable_prediction_result()->Swap(&*result);
  } else {
    segment->clear_prediction_result();
  }
  std::move(callback).Run(true);
}
// Appends a copy of |data| to the training data of |segment_id|, creating the
// segment entry if needed, and reports success.
void TestSegmentInfoDatabase::SaveTrainingData(SegmentId segment_id,
                                               const proto::TrainingData& data,
                                               SuccessCallback callback) {
  *FindOrCreateSegment(segment_id)->add_training_data() = data;
  std::move(callback).Run(true);
}
// Finds the first training-data record of |segment_id| whose request id
// matches |request_id| and passes a copy of it to |callback|. When
// |delete_from_db| is set, the matched record is also removed from the stored
// segment. Passes absl::nullopt when the segment or the record is missing.
// NOTE(review): |model_source| is not consulted; entries are keyed by segment
// id only.
void TestSegmentInfoDatabase::GetTrainingData(SegmentId segment_id,
                                              ModelSource model_source,
                                              TrainingRequestId request_id,
                                              bool delete_from_db,
                                              TrainingDataCallback callback) {
  absl::optional<proto::TrainingData> found;
  auto segment_it =
      base::ranges::find(segment_infos_, segment_id,
                         &std::pair<SegmentId, proto::SegmentInfo>::first);
  if (segment_it == segment_infos_.end()) {
    std::move(callback).Run(found);
    return;
  }

  proto::SegmentInfo& segment = segment_it->second;
  const auto wanted_request_id = request_id.GetUnsafeValue();
  for (int index = 0; index < segment.training_data_size(); ++index) {
    if (segment.training_data(index).request_id() != wanted_request_id)
      continue;
    found = segment.training_data(index);
    if (delete_from_db)
      segment.mutable_training_data()->DeleteSubrange(index, 1);
    break;
  }
  std::move(callback).Run(found);
}
// Appends one USER_ACTION UMA feature named |name| to the model metadata of
// |segment_id|, creating the segment entry if needed. No enum ids are used.
void TestSegmentInfoDatabase::AddUserActionFeature(
    SegmentId segment_id,
    const std::string& name,
    uint64_t bucket_count,
    uint64_t tensor_length,
    proto::Aggregation aggregation) {
  MetadataWriter::UMAFeature uma_features[] = {
      {.signal_type = proto::SignalType::USER_ACTION,
       .name = name.c_str(),
       .bucket_count = bucket_count,
       .tensor_length = tensor_length,
       .aggregation = aggregation,
       .accepted_enum_ids = nullptr}};
  MetadataWriter metadata_writer(
      FindOrCreateSegment(segment_id)->mutable_model_metadata());
  metadata_writer.AddUmaFeatures(uma_features, std::size(uma_features));
}
// Appends one HISTOGRAM_VALUE UMA feature named |name| to the model metadata
// of |segment_id|, creating the segment entry if needed. No enum ids are used.
void TestSegmentInfoDatabase::AddHistogramValueFeature(
    SegmentId segment_id,
    const std::string& name,
    uint64_t bucket_count,
    uint64_t tensor_length,
    proto::Aggregation aggregation) {
  MetadataWriter::UMAFeature uma_features[] = {
      {.signal_type = proto::SignalType::HISTOGRAM_VALUE,
       .name = name.c_str(),
       .bucket_count = bucket_count,
       .tensor_length = tensor_length,
       .aggregation = aggregation,
       .accepted_enum_ids = nullptr}};
  MetadataWriter metadata_writer(
      FindOrCreateSegment(segment_id)->mutable_model_metadata());
  metadata_writer.AddUmaFeatures(uma_features, std::size(uma_features));
}
// Appends one HISTOGRAM_ENUM UMA feature named |name| to the model metadata of
// |segment_id|, creating the segment entry if needed. |accepted_enum_ids| is
// passed as a pointer/size pair into the writer.
void TestSegmentInfoDatabase::AddHistogramEnumFeature(
    SegmentId segment_id,
    const std::string& name,
    uint64_t bucket_count,
    uint64_t tensor_length,
    proto::Aggregation aggregation,
    const std::vector<int32_t>& accepted_enum_ids) {
  MetadataWriter::UMAFeature uma_features[] = {
      {.signal_type = proto::SignalType::HISTOGRAM_ENUM,
       .name = name.c_str(),
       .bucket_count = bucket_count,
       .tensor_length = tensor_length,
       .aggregation = aggregation,
       .enum_ids_size = accepted_enum_ids.size(),
       .accepted_enum_ids = accepted_enum_ids.data()}};
  MetadataWriter metadata_writer(
      FindOrCreateSegment(segment_id)->mutable_model_metadata());
  metadata_writer.AddUmaFeatures(uma_features, std::size(uma_features));
}
// Appends a copy of the single SQL |feature| to the model metadata of
// |segment_id|, creating the segment entry if needed.
void TestSegmentInfoDatabase::AddSqlFeature(
    SegmentId segment_id,
    const MetadataWriter::SqlFeature& feature) {
  MetadataWriter metadata_writer(
      FindOrCreateSegment(segment_id)->mutable_model_metadata());
  MetadataWriter::SqlFeature sql_features[] = {feature};
  metadata_writer.AddSqlFeatures(sql_features, std::size(sql_features));
}
// Replaces the prediction result of |segment_id| with the single value |score|.
// The segment entry is created if needed. The timestamp is recorded as
// microseconds since the Windows epoch.
void TestSegmentInfoDatabase::AddPredictionResult(SegmentId segment_id,
                                                  float score,
                                                  base::Time timestamp) {
  proto::PredictionResult& prediction =
      *FindOrCreateSegment(segment_id)->mutable_prediction_result();
  prediction.clear_result();
  prediction.add_result(score);

  const int64_t timestamp_us =
      timestamp.ToDeltaSinceWindowsEpoch().InMicroseconds();
  prediction.set_timestamp_us(timestamp_us);
}
// Appends |num_pairs| entries to the discrete mapping stored under
// |discrete_mapping_key| in the metadata of |segment_id|. The segment and the
// map entry are created if needed. Each row of |mappings| is
// {min_result, rank}; the rank is converted from float by the proto setter.
void TestSegmentInfoDatabase::AddDiscreteMapping(
    SegmentId segment_id,
    const float mappings[][2],
    int num_pairs,
    const std::string& discrete_mapping_key) {
  auto& mapping_for_key =
      (*FindOrCreateSegment(segment_id)
            ->mutable_model_metadata()
            ->mutable_discrete_mappings())[discrete_mapping_key];
  for (int row = 0; row < num_pairs; ++row) {
    const float min_result = mappings[row][0];
    const float rank = mappings[row][1];
    auto* new_entry = mapping_for_key.add_entries();
    new_entry->set_min_result(min_result);
    new_entry->set_rank(rank);
  }
}
// Sets the bucket duration and its time unit in the model metadata of
// |segment_id|, creating the segment entry if needed.
void TestSegmentInfoDatabase::SetBucketDuration(SegmentId segment_id,
                                                uint64_t bucket_duration,
                                                proto::TimeUnit time_unit) {
  auto* metadata = FindOrCreateSegment(segment_id)->mutable_model_metadata();
  metadata->set_bucket_duration(bucket_duration);
  metadata->set_time_unit(time_unit);
}
// Returns the stored SegmentInfo for |segment_id|. If none exists, appends a
// new entry whose segment_id field is set and returns that instead. The
// returned pointer refers into |segment_infos_|, so callers should use it
// before the list is modified again.
proto::SegmentInfo* TestSegmentInfoDatabase::FindOrCreateSegment(
    SegmentId segment_id) {
  auto existing =
      base::ranges::find(segment_infos_, segment_id,
                         &std::pair<SegmentId, proto::SegmentInfo>::first);
  if (existing != segment_infos_.end())
    return &existing->second;

  segment_infos_.emplace_back(segment_id, proto::SegmentInfo());
  proto::SegmentInfo* created = &segment_infos_.back().second;
  created->set_segment_id(segment_id);
  return created;
}
} // namespace segmentation_platform::test