// blob: ffe54247ec0fe2ba14dd8641fc692139f5140fc1 [file] [log] [blame]
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_TEST_SEGMENT_INFO_DATABASE_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_TEST_SEGMENT_INFO_DATABASE_H_
#include <utility>
#include <vector>
#include "base/containers/flat_set.h"
#include "components/segmentation_platform/internal/database/segment_info_database.h"
#include "components/segmentation_platform/internal/database/ukm_types.h"
#include "components/segmentation_platform/internal/metadata/metadata_writer.h"
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
#include "components/segmentation_platform/public/proto/aggregation.pb.h"
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"
namespace segmentation_platform::test {
// Fake SegmentInfoDatabase holding sample entries, intended for use in tests.
// Entries are stored as (SegmentId, proto::SegmentInfo) pairs in
// |segment_infos_|.
class TestSegmentInfoDatabase : public SegmentInfoDatabase {
 public:
  TestSegmentInfoDatabase();
  ~TestSegmentInfoDatabase() override;

  // ---------------------------------------------------------------------
  // SegmentInfoDatabase implementation. Each method below overrides the
  // base-class virtual of the same name.
  // ---------------------------------------------------------------------
  void Initialize(SuccessCallback callback) override;

  void GetSegmentInfoForSegments(const base::flat_set<SegmentId>& segment_ids,
                                 MultipleSegmentInfoCallback callback) override;

  void GetSegmentInfo(SegmentId segment_id,
                      ModelSource model_source,
                      SegmentInfoCallback callback) override;

  absl::optional<SegmentInfo> GetCachedSegmentInfo(
      SegmentId segment_id,
      ModelSource model_source) override;

  void UpdateSegment(SegmentId segment_id,
                     absl::optional<proto::SegmentInfo> segment_info,
                     SuccessCallback callback) override;

  void SaveSegmentResult(SegmentId segment_id,
                         ModelSource model_source,
                         absl::optional<proto::PredictionResult> result,
                         SuccessCallback callback) override;

  void SaveTrainingData(SegmentId segment_id,
                        ModelSource model_source,
                        const proto::TrainingData& data,
                        SuccessCallback callback) override;

  void GetTrainingData(SegmentId segment_id,
                       ModelSource model_source,
                       TrainingRequestId request_id,
                       bool delete_from_db,
                       TrainingDataCallback callback) override;

  // ---------------------------------------------------------------------
  // Helpers for tests. Every helper takes the |segment_id| it applies to.
  // NOTE(review): the exact effect of each helper is defined in the matching
  // .cc file; judging by the names they populate the segment's
  // proto::SegmentInfo (features, prediction result, discrete mapping, bucket
  // duration) -- confirm against the implementation.
  // ---------------------------------------------------------------------
  void AddUserActionFeature(SegmentId segment_id,
                            const std::string& user_action,
                            uint64_t bucket_count,
                            uint64_t tensor_length,
                            proto::Aggregation aggregation);

  void AddHistogramValueFeature(SegmentId segment_id,
                                const std::string& histogram,
                                uint64_t bucket_count,
                                uint64_t tensor_length,
                                proto::Aggregation aggregation);

  void AddHistogramEnumFeature(SegmentId segment_id,
                               const std::string& histogram_name,
                               uint64_t bucket_count,
                               uint64_t tensor_length,
                               proto::Aggregation aggregation,
                               const std::vector<int32_t>& accepted_enum_ids);

  void AddSqlFeature(SegmentId segment_id,
                     const MetadataWriter::SqlFeature& feature);

  void AddPredictionResult(SegmentId segment_id,
                           float score,
                           base::Time timestamp);

  // |mappings| is an array of two-float rows; |num_pairs| presumably gives
  // the number of rows -- verify against the implementation.
  void AddDiscreteMapping(SegmentId segment_id,
                          const float mappings[][2],
                          int num_pairs,
                          const std::string& discrete_mapping_key);

  void SetBucketDuration(SegmentId segment_id,
                         uint64_t bucket_duration,
                         proto::TimeUnit time_unit);

  // Looks up the segment identified by |segment_id|; when no such segment
  // exists yet, one is created. Returns a pointer to its proto::SegmentInfo.
  proto::SegmentInfo* FindOrCreateSegment(SegmentId segment_id);

 private:
  // Backing store: one (id, info) pair per segment.
  std::vector<std::pair<SegmentId, proto::SegmentInfo>> segment_infos_;
};
} // namespace segmentation_platform::test
#endif // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_DATABASE_TEST_SEGMENT_INFO_DATABASE_H_