// blob: 987d7c16dd88e78d3e00678661fdd0a050789c2c [file] [log] [blame]
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <memory>
#include <vector>
#include "base/containers/flat_set.h"
#include "base/functional/callback.h"
#include "components/leveldb_proto/public/proto_database.h"
#include "components/segmentation_platform/internal/database/segment_info_cache.h"
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#include "components/segmentation_platform/public/trigger.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
namespace segmentation_platform {
// Make these two proto names usable without the `proto::` qualifier in the
// declarations that follow.
using proto::SegmentId;
using proto::ModelSource;
namespace proto {
// Forward declarations of the proto message classes referenced below.
class PredictionResult;
class SegmentInfo;
}  // namespace proto
// Represents a DB layer that stores model metadata and prediction results to
// the disk. At startup, all stored data is loaded into cache. Following
// interactions with this class will be directly from/to the cache. The disk
// will be updated asynchronously by this class.
//
// The object owns both the on-disk proto database and the in-memory cache
// (see |database_| and |cache_|). It is not copyable.
//
// NOTE(review): in this view of the file the access specifiers, the closing
// `};` of the class, and the right-hand sides of three alias declarations are
// not visible. This looks like extraction damage rather than the real source;
// confirm against the upstream file before relying on this text.
class SegmentInfoDatabase {
// One-shot callback that receives a single bool. Presumably `true` means the
// operation succeeded (inferred from the alias name — confirm in the .cc).
using SuccessCallback = base::OnceCallback<void(bool)>;
// A list of (segment id, segment info) pairs.
using SegmentInfoList = std::vector<std::pair<SegmentId, proto::SegmentInfo>>;
// NOTE(review): the definition of this alias is not visible here. It is the
// callback type taken by GetSegmentInfoForSegments().
using MultipleSegmentInfoCallback =
// NOTE(review): the definition of this alias is not visible here. It is the
// callback type taken by GetSegmentInfo().
using SegmentInfoCallback =
// The leveldb_proto database type used for persistence; it stores
// proto::SegmentInfo messages.
using SegmentInfoProtoDb = leveldb_proto::ProtoDatabase<proto::SegmentInfo>;
// NOTE(review): the definition of this alias is not visible here. It is the
// callback type taken by GetTrainingData().
using TrainingDataCallback =
// Takes ownership of both |database| and |cache|.
explicit SegmentInfoDatabase(std::unique_ptr<SegmentInfoProtoDb> database,
std::unique_ptr<SegmentInfoCache> cache);
// Virtual, as are the public operations below, so the class can be subclassed
// and its methods overridden.
virtual ~SegmentInfoDatabase();
// Disallow copy/assign.
SegmentInfoDatabase(const SegmentInfoDatabase&) = delete;
SegmentInfoDatabase& operator=(const SegmentInfoDatabase&) = delete;
// Starts initialization and reports the outcome through |callback|.
// Presumably this opens |database_| and then loads every stored entry into
// |cache_| via OnDatabaseInitialized() / OnLoadAllEntries(), matching the
// class comment — confirm in the .cc.
virtual void Initialize(SuccessCallback callback);
// Called to get metadata for a given list of server model segments.
// The result is delivered through |callback|.
virtual void GetSegmentInfoForSegments(
const base::flat_set<SegmentId>& segment_ids,
MultipleSegmentInfoCallback callback);
// Called to get the metadata for a given segment. ModelSource defines whether
// to give metadata from server/default models for the given segment.
// The result is delivered through |callback|.
virtual void GetSegmentInfo(SegmentId segment_id,
ModelSource model_source,
SegmentInfoCallback callback);
// Returns the metadata for (|segment_id|, |model_source|) directly instead of
// through a callback; the optional allows for an empty result. Judging by the
// name this reads only from the cache — confirm in the .cc.
// NOTE(review): `SegmentInfo` is unqualified here while the rest of the class
// spells `proto::SegmentInfo`; this relies on a using-declaration that is not
// visible in this file view — confirm.
virtual absl::optional<SegmentInfo> GetCachedSegmentInfo(
SegmentId segment_id,
ModelSource model_source);
// Called to get the training data for a given segment with given model source
// and request ID. If delete_from_db is set to true, it will delete the
// corresponding entry in the cache and in the database.
// The result is delivered through |callback|.
virtual void GetTrainingData(SegmentId segment_id,
ModelSource model_source,
TrainingRequestId request_id,
bool delete_from_db,
TrainingDataCallback callback);
// Called to save or update metadata for a segment. The previous data is
// overwritten. If |segment_info| is empty, the segment will be deleted.
// Updates are written to the cache and callback is returned to the client.
// The database will be updated asynchronously after.
// TODO(shaktisahu): How does the client know if a segment is to be deleted?
virtual void UpdateSegment(SegmentId segment_id,
absl::optional<proto::SegmentInfo> segment_info,
SuccessCallback callback);
// Called to save or update metadata for multiple segments in a single
// database call. The previous data for all the provided segments is
// overwritten with new data. `segments_to_delete` includes list of
// segment ids to be deleted from the database.
virtual void UpdateMultipleSegments(
const SegmentInfoList& segments_to_update,
const std::vector<SegmentId>& segments_to_delete,
SuccessCallback callback);
// Called to write the model execution results for a given segment. It will
// first read the currently stored result, and then overwrite it with
// |result|. If |result| is null, the existing result will be deleted.
virtual void SaveSegmentResult(SegmentId segment_id,
absl::optional<proto::PredictionResult> result,
SuccessCallback callback);
// Called to write partial training data for a given segment. New training
// data are appended to the existing ones.
virtual void SaveTrainingData(SegmentId segment_id,
const proto::TrainingData& data,
SuccessCallback callback);
// Receives the leveldb_proto init |status| together with the pending
// |callback|. Presumably the continuation of Initialize() once the database
// has been opened — confirm in the .cc.
void OnDatabaseInitialized(SuccessCallback callback,
leveldb_proto::Enums::InitStatus status);
// Receives the outcome of loading every stored entry: |success| and the
// loaded protos in |all_infos|, plus the pending |callback|. Presumably this
// is where the cache gets populated — confirm in the .cc.
void OnLoadAllEntries(
SuccessCallback callback,
bool success,
std::unique_ptr<std::vector<proto::SegmentInfo>> all_infos);
// The on-disk store of SegmentInfo protos; owned by this object.
std::unique_ptr<SegmentInfoProtoDb> database_;
// The in-memory cache described in the class comment; owned by this object.
std::unique_ptr<SegmentInfoCache> cache_;
// Factory for weak pointers to this object.
// NOTE(review): no include of "base/memory/weak_ptr.h" is visible in this
// file view — confirm it is present upstream or arrives transitively.
base::WeakPtrFactory<SegmentInfoDatabase> weak_ptr_factory_{this};
} // namespace segmentation_platform