blob: 2e9686a92ae04136d137f4cd26f8c927004aaf45 [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/execution/model_manager_impl.h"
#include <map>
#include <memory>
#include "base/functional/bind.h"
#include "base/functional/callback_helpers.h"
#include "base/time/clock.h"
#include "base/time/time.h"
#include "base/trace_event/typed_macros.h"
#include "components/segmentation_platform/internal/execution/model_manager.h"
#include "components/segmentation_platform/internal/metadata/metadata_utils.h"
#include "components/segmentation_platform/internal/proto/model_prediction.pb.h"
#include "components/segmentation_platform/internal/stats.h"
#include "components/segmentation_platform/public/model_provider.h"
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#include "third_party/abseil-cpp/absl/types/optional.h"
// Forward declaration only; no definition of this class is needed by the code
// in this file.
// NOTE(review): nothing in the visible part of this file refers to
// OptimizationGuideModelProvider — confirm whether this declaration is still
// required.
namespace optimization_guide {
class OptimizationGuideModelProvider;
} // namespace optimization_guide
namespace segmentation_platform {
// Constructs a manager for the given set of segments.
//
// The constructor only copies its arguments into members; its body is empty.
// Providers are created and models are requested later, in Initialize().
//
// `segment_ids`: the segments Initialize() creates providers for.
// `model_provider_factory`: used by Initialize() to create the server and
//     default model providers.
// `clock`: read when stamping the update time of a newly delivered model.
// `segment_database`: read and updated whenever a model is delivered or
//     removed.
// `model_updated_callback`: run after a server model has been stored or
//     deleted.
//
// NOTE(review): the three pointers are presumably non-owning and expected to
// outlive this object — confirm against the header and the call sites.
ModelManagerImpl::ModelManagerImpl(
const base::flat_set<SegmentId>& segment_ids,
ModelProviderFactory* model_provider_factory,
base::Clock* clock,
SegmentInfoDatabase* segment_database,
const SegmentationModelUpdatedCallback& model_updated_callback)
: segment_ids_(segment_ids),
model_provider_factory_(model_provider_factory),
clock_(clock),
segment_database_(segment_database),
model_updated_callback_(model_updated_callback) {}
// Creates the model providers for every configured segment and starts model
// delivery.
//
// For each segment id:
//  * A server-model provider is created and asked to fetch its model; updates
//    are routed to OnSegmentationModelUpdated() through a weak pointer. The
//    provider is then registered under (segment_id, SERVER_MODEL_SOURCE).
//  * A default-model provider is requested from the factory. If there is
//    none, UpdateSegment() is called with absl::nullopt for
//    (segment_id, DEFAULT_MODEL_SOURCE). Otherwise the provider is registered
//    under that key and its config is passed straight to
//    OnSegmentationModelUpdated().
void ModelManagerImpl::Initialize() {
  for (SegmentId segment_id : segment_ids_) {
    // --- Server model ---
    std::unique_ptr<ModelProvider> server_provider =
        model_provider_factory_->CreateProvider(segment_id);
    auto on_server_model_updated = base::BindRepeating(
        &ModelManagerImpl::OnSegmentationModelUpdated,
        weak_ptr_factory_.GetWeakPtr(), ModelSource::SERVER_MODEL_SOURCE);
    // Start the fetch before ownership of the provider moves into the map.
    server_provider->InitAndFetchModel(std::move(on_server_model_updated));
    model_providers_.emplace(
        std::make_pair(segment_id, ModelSource::SERVER_MODEL_SOURCE),
        std::move(server_provider));

    // --- Default model ---
    std::unique_ptr<DefaultModelProvider> fallback_provider =
        model_provider_factory_->CreateDefaultProvider(segment_id);
    if (fallback_provider) {
      // Read the config first: the provider is moved into the map right after.
      std::unique_ptr<DefaultModelProvider::ModelConfig> config =
          fallback_provider->GetModelConfig();
      model_providers_.emplace(
          std::make_pair(segment_id, ModelSource::DEFAULT_MODEL_SOURCE),
          std::move(fallback_provider));
      OnSegmentationModelUpdated(ModelSource::DEFAULT_MODEL_SOURCE, segment_id,
                                 std::move(config->metadata),
                                 config->model_version);
    } else {
      // No default model for this segment: update the database entry with no
      // segment info and ignore the result.
      segment_database_->UpdateSegment(segment_id,
                                       ModelSource::DEFAULT_MODEL_SOURCE,
                                       absl::nullopt, base::DoNothing());
    }
  }
}
// Defaulted destructor: no explicit teardown is performed here; members are
// released by their own destructors.
ModelManagerImpl::~ModelManagerImpl() = default;
// Looks up the provider registered for (`segment_id`, `model_source`).
//
// Returns a pointer to the provider held in `model_providers_`, or nullptr
// when no provider is registered for that pair. The map keeps ownership.
ModelProvider* ModelManagerImpl::GetModelProvider(
    proto::SegmentId segment_id,
    proto::ModelSource model_source) {
  const auto entry =
      model_providers_.find(std::make_pair(segment_id, model_source));
  return entry != model_providers_.end() ? entry->second.get() : nullptr;
}
// Test-only hook: replaces the callback that is run after a server model has
// been stored or deleted.
void ModelManagerImpl::SetSegmentationModelUpdatedCallbackForTesting(
    ModelManager::SegmentationModelUpdatedCallback model_updated_callback) {
  // The parameter is already a by-value copy owned by this function, so move
  // it into the member instead of copying it a second time.
  model_updated_callback_ = std::move(model_updated_callback);
}
// Entry point for model delivery; Initialize() wires it up for both server
// and default models.
//
// Every call is recorded in stats. Calls for OPTIMIZATION_TARGET_UNKNOWN are
// then dropped. After that:
//  * If `metadata` has a value, its feature name hashes are (re)computed and
//    it is handed to OnSegmentInfoFetchedForModelUpdate() together with the
//    currently cached SegmentInfo, which may be null.
//  * If `metadata` is empty and a cached SegmentInfo exists, the database
//    entry is updated with absl::nullopt and OnSegmentInfoDeleted() is called
//    with the cached model version once that completes. With no cached entry
//    nothing happens.
void ModelManagerImpl::OnSegmentationModelUpdated(
    proto::ModelSource model_source,
    proto::SegmentId segment_id,
    absl::optional<proto::SegmentationModelMetadata> metadata,
    int64_t model_version) {
  TRACE_EVENT("segmentation_platform",
              "ModelManagerImpl::OnSegmentationModelUpdated");
  stats::RecordModelDeliveryReceived(segment_id, model_source);
  if (segment_id == proto::SegmentId::OPTIMIZATION_TARGET_UNKNOWN) {
    return;
  }

  if (metadata) {
    // Set or overwrite name hashes for metadata features based on the name
    // field.
    metadata_utils::SetFeatureNameHashesFromName(&*metadata);
    const auto* cached_info =
        segment_database_->GetCachedSegmentInfo(segment_id, model_source);
    OnSegmentInfoFetchedForModelUpdate(segment_id, model_source,
                                       std::move(*metadata), model_version,
                                       cached_info);
    return;
  }

  // No metadata was delivered. Only touch the database when there is a cached
  // entry to remove.
  const auto* stale_info =
      segment_database_->GetCachedSegmentInfo(segment_id, model_source);
  if (stale_info) {
    // The cached version is read here, while binding, so the completion
    // callback can report which version was removed.
    segment_database_->UpdateSegment(
        segment_id, model_source, absl::nullopt,
        base::BindOnce(&ModelManagerImpl::OnSegmentInfoDeleted,
                       weak_ptr_factory_.GetWeakPtr(), segment_id, model_source,
                       stale_info->model_version()));
  }
}
// Builds the SegmentInfo for a newly delivered model, carries over state from
// the previously cached entry when the model version is unchanged, validates
// the result and stores it in the database.
//
// `metadata` is swapped into the new SegmentInfo. `old_segment_info` is the
// cached entry for this segment and source, or null when there is none.
// If validation fails, nothing is written and no callback is run. Otherwise
// OnUpdatedSegmentInfoStored() is invoked with the stored SegmentInfo and the
// previous model version (if any) when the database write completes.
void ModelManagerImpl::OnSegmentInfoFetchedForModelUpdate(
    proto::SegmentId segment_id,
    proto::ModelSource model_source,
    proto::SegmentationModelMetadata metadata,
    int64_t model_version,
    const proto::SegmentInfo* old_segment_info) {
  TRACE_EVENT("segmentation_platform",
              "ModelManagerImpl::OnSegmentInfoFetchedForModelUpdate");

  proto::SegmentInfo updated_info;
  updated_info.set_segment_id(segment_id);
  updated_info.set_model_source(model_source);
  updated_info.set_model_version(model_version);
  // Inject the delivered metadata into the new SegmentInfo.
  updated_info.mutable_model_metadata()->Swap(&metadata);

  // Defaults to the current time; replaced below when the same model version
  // is already stored with an update time.
  int64_t update_time_s = clock_->Now().ToDeltaSinceWindowsEpoch().InSeconds();

  absl::optional<int64_t> previous_version;
  if (old_segment_info != nullptr) {
    // The cached entry's ID should match the one that was looked up, otherwise
    // the DB has not upheld its contract. A mismatch is only recorded: the old
    // entry is still overwritten with one that has the right segment ID, so it
    // does not linger forever without being cleaned up.
    const bool ids_match =
        updated_info.segment_id() == old_segment_info->segment_id();
    stats::RecordModelDeliverySegmentIdMatches(updated_info.segment_id(),
                                               model_source, ids_match);

    if (old_segment_info->has_model_version()) {
      previous_version = old_segment_info->model_version();
    }

    // Same model version delivered again: keep the existing prediction result
    // and the original update time.
    if (previous_version.has_value() && *previous_version == model_version) {
      if (old_segment_info->has_prediction_result()) {
        updated_info.mutable_prediction_result()->CopyFrom(
            old_segment_info->prediction_result());
      }
      if (old_segment_info->has_model_update_time_s()) {
        update_time_s = old_segment_info->model_update_time_s();
      }
    }
  }
  updated_info.set_model_update_time_s(update_time_s);

  // The callback must not be invoked unless the metadata is valid, so stop
  // here (after recording the outcome) when validation fails.
  auto validation_result =
      metadata_utils::ValidateSegmentInfoMetadataAndFeatures(updated_info);
  stats::RecordModelDeliveryMetadataValidation(
      segment_id, model_source, /* processed = */ true, validation_result);
  if (validation_result !=
      metadata_utils::ValidationResult::kValidationSuccess) {
    return;
  }

  stats::RecordModelDeliveryMetadataFeatureCount(
      segment_id, model_source,
      updated_info.model_metadata().input_features_size());

  // Store the merged SegmentInfo. The completion callback gets its own copy,
  // which has to be bound in a separate statement *before* `updated_info` is
  // moved into the database call: argument evaluation order is unspecified.
  auto on_stored = base::BindOnce(
      &ModelManagerImpl::OnUpdatedSegmentInfoStored,
      weak_ptr_factory_.GetWeakPtr(), updated_info, previous_version);
  segment_database_->UpdateSegment(
      segment_id, model_source,
      absl::make_optional(std::move(updated_info)), std::move(on_stored));
}
// Completion callback for the database update issued when a model is removed.
//
// Records the delete result in stats. For non-default model sources it then
// runs `model_updated_callback_` with a SegmentInfo carrying only the segment
// id and model source, plus the version that was deleted. The callback runs
// whether or not `success` is true.
void ModelManagerImpl::OnSegmentInfoDeleted(SegmentId segment_id,
                                            proto::ModelSource model_source,
                                            int64_t deleted_version,
                                            bool success) {
  stats::RecordModelDeliveryDeleteResult(segment_id, model_source, success);

  // `model_updated_callback_` only supports server models.
  const bool is_default_model =
      model_source == proto::ModelSource::DEFAULT_MODEL_SOURCE;
  if (!is_default_model) {
    proto::SegmentInfo removed_info;
    removed_info.set_segment_id(segment_id);
    removed_info.set_model_source(model_source);
    model_updated_callback_.Run(std::move(removed_info), deleted_version);
  }
}
// Completion callback for the database write of an updated SegmentInfo.
//
// Records the save result in stats. `model_updated_callback_` is run with the
// stored `segment_info` and `old_model_version` only when the write succeeded
// and the model does not come from the default model source.
void ModelManagerImpl::OnUpdatedSegmentInfoStored(
    proto::SegmentInfo segment_info,
    absl::optional<int64_t> old_model_version,
    bool success) {
  TRACE_EVENT("segmentation_platform",
              "ModelManagerImpl::OnUpdatedSegmentInfoStored");
  stats::RecordModelDeliverySaveResult(segment_info.segment_id(),
                                       segment_info.model_source(), success);

  // Once the write has succeeded we are ready to receive requests for
  // execution, so the callback can be invoked.
  // TODO (ritikagup@) : Add handling for default models to run callback.
  // Remove the model-source condition when callback supports default models.
  if (success && segment_info.model_source() !=
                     proto::ModelSource::DEFAULT_MODEL_SOURCE) {
    model_updated_callback_.Run(std::move(segment_info), old_model_version);
  }
}
} // namespace segmentation_platform