blob: 10f6e23019851485bd8608b395d3cc187c8c53d2 [file] [log] [blame]
// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/service_proxy_impl.h"
#include <inttypes.h>
#include <sstream>
#include "base/observer_list.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h"
#include "base/time/time.h"
#include "components/segmentation_platform/internal/database/signal_storage_config.h"
#include "components/segmentation_platform/internal/metadata/metadata_utils.h"
#include "components/segmentation_platform/internal/scheduler/execution_service.h"
#include "components/segmentation_platform/internal/segmentation_platform_service_impl.h"
#include "components/segmentation_platform/internal/selection/segment_selector_impl.h"
#include "components/segmentation_platform/public/config.h"
#include "base/logging.h"
namespace segmentation_platform {
namespace {
std::string SegmentMetadataToString(const proto::SegmentInfo& segment_info) {
if (!segment_info.has_model_metadata())
return std::string();
return "model_metadata: { " +
metadata_utils::SegmetationModelMetadataToString(
segment_info.model_metadata()) +
" }";
}
std::string PredictionResultToString(const proto::SegmentInfo& segment_info) {
if (!segment_info.has_prediction_result())
return std::string();
const auto prediction_result = segment_info.prediction_result();
base::Time time;
if (prediction_result.has_timestamp_us()) {
time = base::Time::FromDeltaSinceWindowsEpoch(
base::Microseconds(prediction_result.timestamp_us()));
}
std::ostringstream time_string;
time_string << time;
return base::StringPrintf(
"result: %f, time: %s",
prediction_result.has_result() ? prediction_result.result() : 0,
time_string.str().c_str());
}
} // namespace
ServiceProxyImpl::ServiceProxyImpl(
SegmentInfoDatabase* segment_db,
SignalStorageConfig* signal_storage_config,
std::vector<std::unique_ptr<Config>>* configs,
base::flat_map<std::string, std::unique_ptr<SegmentSelectorImpl>>*
segment_selectors)
: segment_db_(segment_db),
signal_storage_config_(signal_storage_config),
configs_(configs),
segment_selectors_(segment_selectors) {}
ServiceProxyImpl::~ServiceProxyImpl() = default;
void ServiceProxyImpl::AddObserver(ServiceProxy::Observer* observer) {
observers_.AddObserver(observer);
}
void ServiceProxyImpl::RemoveObserver(ServiceProxy::Observer* observer) {
observers_.RemoveObserver(observer);
}
void ServiceProxyImpl::OnServiceStatusChanged(bool is_initialized,
int status_flag) {
bool changed = (is_service_initialized_ != is_initialized) ||
(service_status_flag_ != status_flag);
is_service_initialized_ = is_initialized;
service_status_flag_ = status_flag;
UpdateObservers(changed);
}
void ServiceProxyImpl::UpdateObservers(bool update_service_status) {
if (observers_.empty())
return;
if (update_service_status) {
for (auto& obs : observers_)
obs.OnServiceStatusChanged(is_service_initialized_, service_status_flag_);
}
if (segment_db_ &&
(static_cast<int>(ServiceStatus::kSegmentationInfoDbInitialized) &
service_status_flag_)) {
segment_db_->GetAllSegmentInfo(
base::BindOnce(&ServiceProxyImpl::OnGetAllSegmentationInfo,
weak_ptr_factory_.GetWeakPtr()));
}
}
void ServiceProxyImpl::SetExecutionService(
ExecutionService* model_execution_scheduler) {
execution_service = model_execution_scheduler;
}
void ServiceProxyImpl::GetServiceStatus() {
UpdateObservers(true /* update_service_status */);
}
void ServiceProxyImpl::ExecuteModel(SegmentId segment_id) {
if (!execution_service ||
segment_id == SegmentId::OPTIMIZATION_TARGET_UNKNOWN) {
return;
}
segment_db_->GetSegmentInfo(
segment_id,
base::BindOnce(&ServiceProxyImpl::OnSegmentInfoFetchedForExecution,
weak_ptr_factory_.GetWeakPtr()));
}
void ServiceProxyImpl::OnSegmentInfoFetchedForExecution(
absl::optional<proto::SegmentInfo> segment_info) {
if (!segment_info)
return;
auto request = std::make_unique<ExecutionRequest>();
request->record_metrics_for_default = false;
request->save_result_to_db = true;
request->segment_info = &segment_info.value();
execution_service->RequestModelExecution(std::move(request));
}
void ServiceProxyImpl::OverwriteResult(SegmentId segment_id, float result) {
if (!execution_service)
return;
if (result < 0 || result > 1)
return;
if (segment_id != SegmentId::OPTIMIZATION_TARGET_UNKNOWN) {
execution_service->OverwriteModelExecutionResult(
segment_id, std::make_pair(result, ModelExecutionStatus::kSuccess));
}
}
void ServiceProxyImpl::SetSelectedSegment(const std::string& segmentation_key,
SegmentId segment_id) {
if (!segment_selectors_ ||
segment_selectors_->find(segmentation_key) == segment_selectors_->end()) {
return;
}
if (segment_id != SegmentId::OPTIMIZATION_TARGET_UNKNOWN) {
auto& selector = segment_selectors_->at(segmentation_key);
selector->UpdateSelectedSegment(segment_id);
}
}
void ServiceProxyImpl::OnGetAllSegmentationInfo(
std::unique_ptr<SegmentInfoDatabase::SegmentInfoList> segment_info) {
if (!configs_)
return;
// Convert the |segment_info| vector to a map for quick lookup.
base::flat_map<SegmentId, proto::SegmentInfo> segment_ids;
for (const auto& info : *segment_info) {
segment_ids[info.first] = info.second;
}
std::vector<ServiceProxy::ClientInfo> result;
for (const auto& config : *configs_) {
SegmentId selected = SegmentId::OPTIMIZATION_TARGET_UNKNOWN;
if (segment_selectors_ &&
segment_selectors_->find(config->segmentation_key) !=
segment_selectors_->end()) {
absl::optional<proto::SegmentId> target =
segment_selectors_->at(config->segmentation_key)
->GetCachedSegmentResult()
.segment;
if (target.has_value()) {
selected = *target;
}
}
result.emplace_back(config->segmentation_key, selected);
for (const auto& segment_id : config->segment_ids) {
if (!segment_ids.contains(segment_id))
continue;
const auto& info = segment_ids[segment_id];
result.back().segment_status.emplace_back(
segment_id, SegmentMetadataToString(info),
PredictionResultToString(info),
signal_storage_config_
? signal_storage_config_->MeetsSignalCollectionRequirement(
info.model_metadata())
: false);
}
}
for (auto& obs : observers_)
obs.OnClientInfoAvailable(result);
}
void ServiceProxyImpl::OnModelExecutionCompleted(SegmentId segment_id) {
// Update the observers with the new execution results.
UpdateObservers(false);
}
} // namespace segmentation_platform