blob: a31dd67f998cd0e137d793a3bc297a30c0dd138e [file] [log] [blame]
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_SEGMENT_INFO_DATABASE_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_SEGMENT_INFO_DATABASE_H_
#include <vector>
#include "base/callback.h"
#include "components/leveldb_proto/public/proto_database.h"
#include "components/optimization_guide/proto/models.pb.h"
#include "components/segmentation_platform/internal/proto/model_metadata.pb.h"
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
using optimization_guide::proto::OptimizationTarget;
namespace segmentation_platform {
namespace proto {
// Protobuf message types referenced by the database API below.
class PredictionResult;
class SegmentInfo;
}  // namespace proto
// Key used to find the discrete mapping for segmentation.
constexpr char kSegmentationDiscreteMappingKey[]{"segmentation"};
// Storage layer for segmentation: persists each segment's model metadata and
// its model execution (prediction) results to disk through a
// leveldb_proto::ProtoDatabase of proto::SegmentInfo entries. Every operation
// reports its outcome through a callback rather than a return value.
class SegmentInfoDatabase {
 public:
  // Type of the owned backing store; each entry is a proto::SegmentInfo.
  using SegmentInfoProtoDb = leveldb_proto::ProtoDatabase<proto::SegmentInfo>;

  // Receives whether the requested operation succeeded.
  using SuccessCallback = base::OnceCallback<void(bool)>;

  // Receives one (segment id, SegmentInfo) pair for each stored segment.
  using AllSegmentInfoCallback = base::OnceCallback<void(
      std::vector<std::pair<OptimizationTarget, proto::SegmentInfo>>)>;

  // Receives the SegmentInfo stored for a single segment, if there is one.
  // NOTE(review): presumably absl::nullopt both when the segment is absent and
  // when the read fails -- confirm in the .cc.
  using SegmentInfoCallback =
      base::OnceCallback<void(absl::optional<proto::SegmentInfo>)>;

  // Takes ownership of |database|.
  explicit SegmentInfoDatabase(std::unique_ptr<SegmentInfoProtoDb> database);

  // Neither copyable nor copy-assignable.
  SegmentInfoDatabase(const SegmentInfoDatabase&) = delete;
  SegmentInfoDatabase& operator=(const SegmentInfoDatabase&) = delete;

  virtual ~SegmentInfoDatabase();

  // Initializes the underlying proto database. |callback| is told whether
  // initialization succeeded.
  virtual void Initialize(SuccessCallback callback);

  // Convenience call that loads every stored segment and hands the combined
  // list to |callback|.
  virtual void GetAllSegmentInfo(AllSegmentInfoCallback callback);

  // Fetches the metadata stored for |segment_id| and passes it to |callback|.
  virtual void GetSegmentInfo(OptimizationTarget segment_id,
                              SegmentInfoCallback callback);

  // Writes |segment_info| as the entry for |segment_id|, replacing whatever
  // was stored before. An empty |segment_info| deletes the segment instead.
  // TODO(shaktisahu): How does the client know if a segment is to be deleted?
  virtual void UpdateSegment(OptimizationTarget segment_id,
                             absl::optional<proto::SegmentInfo> segment_info,
                             SuccessCallback callback);

  // Records |result| as the model execution result for |segment_id|. The
  // segment's currently stored data is read first and then rewritten with
  // |result|; a null |result| deletes the existing result.
  // NOTE(review): |result| is a raw pointer that, judging by
  // OnGetSegmentInfoForUpdatingResults(), is carried across that read, so the
  // caller presumably must keep it alive until |callback| runs. Taking it by
  // value would remove the hazard -- confirm against the .cc.
  virtual void SaveSegmentResult(OptimizationTarget segment_id,
                                 proto::PredictionResult* result,
                                 SuccessCallback callback);

 private:
  // Completion handler for Initialize(); receives the proto DB init status.
  void OnDatabaseInitialized(SuccessCallback callback,
                             leveldb_proto::Enums::InitStatus status);

  // Completion handler for GetAllSegmentInfo().
  void OnAllSegmentInfoLoaded(
      AllSegmentInfoCallback callback,
      bool success,
      std::unique_ptr<std::vector<proto::SegmentInfo>> all_infos);

  // Completion handler for GetSegmentInfo().
  void OnGetSegmentInfo(SegmentInfoCallback callback,
                        bool success,
                        std::unique_ptr<proto::SegmentInfo> info);

  // Continuation of SaveSegmentResult() once the segment's stored
  // |segment_info| has been read.
  void OnGetSegmentInfoForUpdatingResults(
      proto::PredictionResult* result,
      SuccessCallback callback,
      absl::optional<proto::SegmentInfo> segment_info);

  // Owned proto database holding every SegmentInfo entry.
  std::unique_ptr<SegmentInfoProtoDb> database_;

  // Kept as the last member so outstanding weak pointers are invalidated
  // before the remaining members are destroyed.
  base::WeakPtrFactory<SegmentInfoDatabase> weak_ptr_factory_{this};
};
} // namespace segmentation_platform
#endif // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_SEGMENT_INFO_DATABASE_H_