blob: 3679bef6c4e11e53545be0f21e908fc7eb2d120b [file] [log] [blame]
// Copyright 2022 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/data_collection/training_data_cache.h"
#include "base/time/time.h"
namespace segmentation_platform {
using proto::SegmentId;
using TrainingData = proto::TrainingData;
TrainingDataCache::TrainingDataCache(SegmentInfoDatabase* segment_info_database)
: segment_info_database_(segment_info_database) {}
TrainingDataCache::~TrainingDataCache() = default;
void TrainingDataCache::StoreInputs(proto::SegmentId segment_id,
const proto::TrainingData& data,
bool save_to_db) {
if (save_to_db) {
segment_info_database_->SaveTrainingData(segment_id, std::move(data),
base::DoNothing());
} else {
cache[segment_id][TrainingRequestId::FromUnsafeValue(data.request_id())] =
std::move(data);
}
}
void TrainingDataCache::GetInputsAndDelete(SegmentId segment_id,
TrainingRequestId request_id,
TrainingDataCallback callback) {
absl::optional<TrainingData> result;
if (cache.contains(segment_id) && cache[segment_id].contains(request_id)) {
// TrainingRequestId found from cache, return and delete the cache entry.
auto& segment_data = cache[segment_id];
auto it = segment_data.find(request_id);
result = std::move(it->second);
segment_data.erase(it);
std::move(callback).Run(result);
} else {
// TODO (ritikagup@) : Add handling for default models, if required.
segment_info_database_->GetTrainingData(
segment_id, proto::ModelSource::SERVER_MODEL_SOURCE, request_id,
/*delete_from_db=*/true, std::move(callback));
}
}
absl::optional<TrainingRequestId> TrainingDataCache::GetRequestId(
proto::SegmentId segment_id) {
// TODO(haileywang): Add a metric to record how many request at a given time
// every time this function is triggered.
absl::optional<TrainingRequestId> request_id;
auto it = cache.find(segment_id);
if (it == cache.end() or it->second.size() == 0) {
return request_id;
}
return it->second.begin()->first;
}
TrainingRequestId TrainingDataCache::GenerateNextId() {
return request_id_generator.GenerateNextId();
}
} // namespace segmentation_platform