blob: 460668c68eba46e30fa788e398eb587b25510598 [file] [log] [blame]
// Copyright 2022 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/execution/default_model_manager.h"
#include "base/threading/thread_task_runner_handle.h"
#include "components/segmentation_platform/internal/database/segment_info_database.h"
namespace segmentation_platform {
DefaultModelManager::SegmentInfoWrapper::SegmentInfoWrapper() = default;
DefaultModelManager::SegmentInfoWrapper::~SegmentInfoWrapper() = default;
DefaultModelManager::DefaultModelManager(
ModelProviderFactory* model_provider_factory,
const std::vector<SegmentId>& segment_ids)
: model_provider_factory_(model_provider_factory) {
for (SegmentId segment_id : segment_ids) {
std::unique_ptr<ModelProvider> provider =
model_provider_factory->CreateDefaultProvider(segment_id);
if (!provider)
continue;
default_model_providers_.emplace(
std::make_pair(segment_id, std::move(provider)));
}
}
DefaultModelManager::~DefaultModelManager() = default;
ModelProvider* DefaultModelManager::GetDefaultProvider(SegmentId segment_id) {
auto it = default_model_providers_.find(segment_id);
if (it != default_model_providers_.end())
return it->second.get();
return nullptr;
}
void DefaultModelManager::GetAllSegmentInfoFromDefaultModel(
const std::vector<SegmentId>& segment_ids,
MultipleSegmentInfoCallback callback) {
auto result = std::make_unique<SegmentInfoList>();
std::deque<SegmentId> remaining_segment_ids(segment_ids.begin(),
segment_ids.end());
GetNextSegmentInfoFromDefaultModel(
std::move(result), std::move(remaining_segment_ids), std::move(callback));
}
void DefaultModelManager::GetNextSegmentInfoFromDefaultModel(
std::unique_ptr<SegmentInfoList> result,
std::deque<SegmentId> remaining_segment_ids,
MultipleSegmentInfoCallback callback) {
SegmentId segment_id = SegmentId::OPTIMIZATION_TARGET_UNKNOWN;
ModelProvider* default_provider = nullptr;
// Find the next available default provider.
while (!default_provider && !remaining_segment_ids.empty()) {
segment_id = remaining_segment_ids.front();
remaining_segment_ids.pop_front();
if (default_model_providers_.count(segment_id) == 1) {
default_provider = default_model_providers_[segment_id].get();
break;
}
}
if (!default_provider) {
// If there are no more default providers, return the result so far.
base::ThreadTaskRunnerHandle::Get()->PostTask(
FROM_HERE, base::BindOnce(std::move(callback), std::move(*result)));
return;
}
default_provider->InitAndFetchModel(base::BindRepeating(
&DefaultModelManager::OnFetchDefaultModel, weak_ptr_factory_.GetWeakPtr(),
base::Passed(&result), remaining_segment_ids, base::Passed(&callback)));
}
void DefaultModelManager::OnFetchDefaultModel(
std::unique_ptr<SegmentInfoList> result,
std::deque<SegmentId> remaining_segment_ids,
MultipleSegmentInfoCallback callback,
SegmentId segment_id,
proto::SegmentationModelMetadata metadata,
int64_t model_version) {
auto info = std::make_unique<SegmentInfoWrapper>();
info->segment_source = DefaultModelManager::SegmentSource::DEFAULT_MODEL;
info->segment_info.set_segment_id(segment_id);
info->segment_info.mutable_model_metadata()->CopyFrom(metadata);
info->segment_info.set_model_version(model_version);
result->push_back(std::move(info));
GetNextSegmentInfoFromDefaultModel(
std::move(result), std::move(remaining_segment_ids), std::move(callback));
}
void DefaultModelManager::GetAllSegmentInfoFromBothModels(
const std::vector<SegmentId>& segment_ids,
SegmentInfoDatabase* segment_database,
MultipleSegmentInfoCallback callback) {
segment_database->GetSegmentInfoForSegments(
segment_ids,
base::BindOnce(&DefaultModelManager::OnGetAllSegmentInfoFromDatabase,
weak_ptr_factory_.GetWeakPtr(), segment_ids,
std::move(callback)));
}
void DefaultModelManager::OnGetAllSegmentInfoFromDatabase(
const std::vector<SegmentId>& segment_ids,
MultipleSegmentInfoCallback callback,
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> segment_infos) {
GetAllSegmentInfoFromDefaultModel(
segment_ids,
base::BindOnce(&DefaultModelManager::OnGetAllSegmentInfoFromDefaultModel,
weak_ptr_factory_.GetWeakPtr(), std::move(callback),
std::move(segment_infos)));
}
void DefaultModelManager::OnGetAllSegmentInfoFromDefaultModel(
MultipleSegmentInfoCallback callback,
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> segment_infos_from_db,
SegmentInfoList segment_infos_from_default_model) {
SegmentInfoList merged_results;
if (segment_infos_from_db) {
for (auto it : *segment_infos_from_db) {
merged_results.push_back(std::make_unique<SegmentInfoWrapper>());
merged_results.back()->segment_source = SegmentSource::DATABASE;
merged_results.back()->segment_info.Swap(&it.second);
}
}
merged_results.insert(
merged_results.end(),
std::make_move_iterator(segment_infos_from_default_model.begin()),
std::make_move_iterator(segment_infos_from_default_model.end()));
std::move(callback).Run(std::move(merged_results));
}
void DefaultModelManager::SetDefaultProvidersForTesting(
std::map<SegmentId, std::unique_ptr<ModelProvider>>&& providers) {
default_model_providers_ = std::move(providers);
}
} // namespace segmentation_platform