blob: 0235eef07d4747aec8a0aeeec75d52cc77bf6e86 [file] [log] [blame]
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_EXECUTION_DEFAULT_MODEL_MANAGER_H_
#define COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_EXECUTION_DEFAULT_MODEL_MANAGER_H_
#include <cstdint>
#include <deque>
#include <map>
#include <memory>
#include <set>
#include <vector>

#include "base/containers/flat_map.h"
#include "base/containers/flat_set.h"
#include "base/functional/callback.h"
#include "base/logging.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "components/segmentation_platform/internal/database/segment_info_database.h"
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
#include "components/segmentation_platform/public/model_provider.h"
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
namespace segmentation_platform {

using proto::SegmentId;

class SegmentInfoDatabase;

// DefaultModelManager provides support to query all default models available.
// It also provides useful methods to combine results from both the database and
// the default model.
class DefaultModelManager {
 public:
  // `model_provider_factory` is kept (not owned) in `model_provider_factory_`.
  // NOTE(review): presumably a default ModelProvider is created through the
  // factory for each ID in `segment_ids`; confirm against the .cc file.
  DefaultModelManager(ModelProviderFactory* model_provider_factory,
                      const base::flat_set<SegmentId>& segment_ids);
  virtual ~DefaultModelManager();

  // Disallow copy/assign.
  DefaultModelManager(const DefaultModelManager&) = delete;
  DefaultModelManager& operator=(const DefaultModelManager&) = delete;

  // Identifies where an entry of a SegmentInfoList was obtained from: the
  // segment info database or a default model.
  enum class SegmentSource {
    DATABASE,
    DEFAULT_MODEL,
  };

  // Pairs a segment info with the source it was obtained from. Non-copyable;
  // instances are held through std::unique_ptr in SegmentInfoList.
  struct SegmentInfoWrapper {
    SegmentInfoWrapper();
    ~SegmentInfoWrapper();
    SegmentInfoWrapper(const SegmentInfoWrapper&) = delete;
    SegmentInfoWrapper& operator=(const SegmentInfoWrapper&) = delete;

    // Given a default so the field is never read while indeterminate (the
    // constructor is defined out-of-line); producers are expected to set the
    // actual source.
    SegmentSource segment_source = SegmentSource::DATABASE;
    proto::SegmentInfo segment_info;
  };

  using SegmentInfoList = std::vector<std::unique_ptr<SegmentInfoWrapper>>;

  // Callback for returning a list of segment infos associated with IDs.
  // The same segment ID can be repeated multiple times.
  using MultipleSegmentInfoCallback = base::OnceCallback<void(SegmentInfoList)>;

  // Utility function to get the segment info from both the database and the
  // default model for a given set of segment IDs. The result can contain
  // the same segment ID multiple times.
  // NOTE(review): `segment_database` is a raw pointer used by an asynchronous
  // API; presumably it must stay valid until `callback` runs — confirm.
  virtual void GetAllSegmentInfoFromBothModels(
      const base::flat_set<SegmentId>& segment_ids,
      SegmentInfoDatabase* segment_database,
      MultipleSegmentInfoCallback callback);

  // Called to get the segment info from the default model for a given set of
  // segment IDs. Results are delivered through `callback`.
  virtual void GetAllSegmentInfoFromDefaultModel(
      const base::flat_set<SegmentId>& segment_ids,
      MultipleSegmentInfoCallback callback);

  // Returns the default provider for `segment_id`, or `nullptr` when
  // unavailable.
  ModelProvider* GetDefaultProvider(SegmentId segment_id);

  // Test-only: installs `providers` as the default model providers.
  // NOTE(review): presumably replaces any existing providers; confirm.
  void SetDefaultProvidersForTesting(
      std::map<SegmentId, std::unique_ptr<ModelProvider>>&& providers);

 private:
  // Helper for GetAllSegmentInfoFromDefaultModel().
  // NOTE(review): appears to handle `remaining_segment_ids` one at a time,
  // accumulating into `result` and running `callback` when none remain;
  // confirm against the .cc file.
  void GetNextSegmentInfoFromDefaultModel(
      std::unique_ptr<SegmentInfoList> result,
      std::deque<SegmentId> remaining_segment_ids,
      MultipleSegmentInfoCallback callback);

  // Receives the default model `metadata` and `model_version` for
  // `segment_id`. Presumably records them in `result` and continues with
  // `remaining_segment_ids` — confirm.
  void OnFetchDefaultModel(std::unique_ptr<SegmentInfoList> result,
                           std::deque<SegmentId> remaining_segment_ids,
                           MultipleSegmentInfoCallback callback,
                           SegmentId segment_id,
                           proto::SegmentationModelMetadata metadata,
                           int64_t model_version);

  // Continuation of GetAllSegmentInfoFromBothModels() once `segment_infos`
  // have been read from the database (presumed from the name; confirm).
  void OnGetAllSegmentInfoFromDatabase(
      const base::flat_set<SegmentId>& segment_ids,
      MultipleSegmentInfoCallback callback,
      std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> segment_infos);

  // Receives both the database results and the default-model results.
  // NOTE(review): presumably combines them and runs `callback`; confirm.
  void OnGetAllSegmentInfoFromDefaultModel(
      MultipleSegmentInfoCallback callback,
      std::unique_ptr<SegmentInfoDatabase::SegmentInfoList>
          segment_infos_from_db,
      SegmentInfoList segment_infos_from_default_model);

  // Default model providers, keyed by segment ID.
  std::map<SegmentId, std::unique_ptr<ModelProvider>> default_model_providers_;

  // Factory passed to the constructor; non-owning.
  const raw_ptr<ModelProviderFactory, DanglingUntriaged>
      model_provider_factory_;

  // Declared last so that, by reverse-declaration destruction order, weak
  // pointers are invalidated before the other members are destroyed.
  base::WeakPtrFactory<DefaultModelManager> weak_ptr_factory_{this};
};

}  // namespace segmentation_platform
#endif // COMPONENTS_SEGMENTATION_PLATFORM_INTERNAL_EXECUTION_DEFAULT_MODEL_MANAGER_H_