blob: 38c8ea280bfde06455126361dc30f6d07ccfb685 [file] [log] [blame]
// Copyright 2019 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/optimization_guide/prediction/prediction_manager.h"
#include <utility>
#include "base/containers/flat_map.h"
#include "base/containers/flat_set.h"
#include "base/metrics/histogram_functions.h"
#include "base/metrics/histogram_macros.h"
#include "base/metrics/histogram_macros_local.h"
#include "base/rand_util.h"
#include "base/sequenced_task_runner.h"
#include "base/strings/stringprintf.h"
#include "base/task/post_task.h"
#include "base/task/thread_pool.h"
#include "base/time/default_clock.h"
#include "chrome/browser/browser_process.h"
#include "chrome/browser/engagement/site_engagement_service.h"
#include "chrome/browser/optimization_guide/optimization_guide_navigation_data.h"
#include "chrome/browser/optimization_guide/optimization_guide_permissions_util.h"
#include "chrome/browser/optimization_guide/optimization_guide_session_statistic.h"
#include "chrome/browser/optimization_guide/optimization_guide_util.h"
#include "chrome/browser/optimization_guide/prediction/prediction_model.h"
#include "chrome/browser/optimization_guide/prediction/prediction_model_fetcher.h"
#include "chrome/browser/profiles/profile.h"
#include "components/leveldb_proto/public/proto_database_provider.h"
#include "components/optimization_guide/optimization_guide_constants.h"
#include "components/optimization_guide/optimization_guide_decider.h"
#include "components/optimization_guide/optimization_guide_enums.h"
#include "components/optimization_guide/optimization_guide_features.h"
#include "components/optimization_guide/optimization_guide_prefs.h"
#include "components/optimization_guide/optimization_guide_store.h"
#include "components/optimization_guide/optimization_guide_switches.h"
#include "components/optimization_guide/store_update_data.h"
#include "components/optimization_guide/top_host_provider.h"
#include "components/prefs/pref_service.h"
#include "content/public/browser/navigation_handle.h"
#include "content/public/browser/web_contents.h"
namespace {
// Returns true if |optimization_target_decision| reflects that the model had
// already been evaluated.
bool ShouldUseCurrentOptimizationTargetDecision(
optimization_guide::OptimizationTargetDecision
optimization_target_decision) {
switch (optimization_target_decision) {
case optimization_guide::OptimizationTargetDecision::kPageLoadMatches:
case optimization_guide::OptimizationTargetDecision::kPageLoadDoesNotMatch:
case optimization_guide::OptimizationTargetDecision::
kModelPredictionHoldback:
return true;
case optimization_guide::OptimizationTargetDecision::
kModelNotAvailableOnClient:
case optimization_guide::OptimizationTargetDecision::kUnknown:
case optimization_guide::OptimizationTargetDecision::kDeciderNotInitialized:
return false;
}
}
// Delay between retries on failed fetch and store of prediction models and
// host model features from the remote Optimization Guide Service.
constexpr base::TimeDelta kFetchRetryDelay = base::TimeDelta::FromMinutes(16);
// The amount of time to wait after a successful fetch of models and host model
// features before requesting an update from the remote Optimization Guide
// Service.
constexpr base::TimeDelta kUpdateModelsAndFeaturesDelay =
base::TimeDelta::FromHours(24);
// Provide a random time delta in seconds before fetching models and host model
// features.
base::TimeDelta RandomFetchDelay() {
return base::TimeDelta::FromSeconds(base::RandInt(
optimization_guide::features::PredictionModelFetchRandomMinDelaySecs(),
optimization_guide::features::PredictionModelFetchRandomMaxDelaySecs()));
}
// Util class for recording the state of a prediction model. The result is
// recorded when it goes out of scope and its destructor is called.
class ScopedPredictionManagerModelStatusRecorder {
public:
explicit ScopedPredictionManagerModelStatusRecorder(
optimization_guide::proto::OptimizationTarget optimization_target)
: status_(optimization_guide::PredictionManagerModelStatus::kUnknown),
optimization_target_(optimization_target) {}
~ScopedPredictionManagerModelStatusRecorder() {
DCHECK_NE(status_,
optimization_guide::PredictionManagerModelStatus::kUnknown);
base::UmaHistogramEnumeration(
"OptimizationGuide.ShouldTargetNavigation.PredictionModelStatus",
status_);
base::UmaHistogramEnumeration(
"OptimizationGuide.ShouldTargetNavigation.PredictionModelStatus." +
GetStringNameForOptimizationTarget(optimization_target_),
status_);
}
void set_status(optimization_guide::PredictionManagerModelStatus status) {
status_ = status;
}
private:
optimization_guide::PredictionManagerModelStatus status_;
const optimization_guide::proto::OptimizationTarget optimization_target_;
};
// Util class for recording the construction and validation of a prediction
// model. The result is recorded when it goes out of scope and its destructor is
// called.
class ScopedPredictionModelConstructionAndValidationRecorder {
public:
explicit ScopedPredictionModelConstructionAndValidationRecorder(
optimization_guide::proto::OptimizationTarget optimization_target)
: validation_start_time_(base::TimeTicks::Now()),
optimization_target_(optimization_target) {}
~ScopedPredictionModelConstructionAndValidationRecorder() {
base::UmaHistogramBoolean("OptimizationGuide.IsPredictionModelValid",
is_valid_);
base::UmaHistogramBoolean(
"OptimizationGuide.IsPredictionModelValid." +
GetStringNameForOptimizationTarget(optimization_target_),
is_valid_);
// Only record the timing if the model is valid and was able to be
// constructed.
if (is_valid_) {
base::TimeDelta validation_latency =
base::TimeTicks::Now() - validation_start_time_;
base::UmaHistogramTimes(
"OptimizationGuide.PredictionModelValidationLatency",
validation_latency);
base::UmaHistogramTimes(
"OptimizationGuide.PredictionModelValidationLatency." +
GetStringNameForOptimizationTarget(optimization_target_),
validation_latency);
}
}
void set_is_valid(bool is_valid) { is_valid_ = is_valid; }
private:
bool is_valid_ = true;
const base::TimeTicks validation_start_time_;
const optimization_guide::proto::OptimizationTarget optimization_target_;
};
} // namespace
namespace optimization_guide {
PredictionManager::PredictionManager(
const std::vector<optimization_guide::proto::OptimizationTarget>&
optimization_targets_at_initialization,
const base::FilePath& profile_path,
leveldb_proto::ProtoDatabaseProvider* database_provider,
TopHostProvider* top_host_provider,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
PrefService* pref_service,
Profile* profile)
: PredictionManager(
optimization_targets_at_initialization,
std::make_unique<OptimizationGuideStore>(
database_provider,
profile_path.AddExtensionASCII(
optimization_guide::
kOptimizationGuidePredictionModelAndFeaturesStore),
base::ThreadPool::CreateSequencedTaskRunner(
{base::MayBlock(), base::TaskPriority::BEST_EFFORT})),
top_host_provider,
url_loader_factory,
pref_service,
profile) {}
PredictionManager::PredictionManager(
const std::vector<optimization_guide::proto::OptimizationTarget>&
optimization_targets_at_initialization,
std::unique_ptr<OptimizationGuideStore> model_and_features_store,
TopHostProvider* top_host_provider,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
PrefService* pref_service,
Profile* profile)
: host_model_features_cache_(
std::max(features::MaxHostModelFeaturesCacheSize(), size_t(1))),
session_fcp_(),
top_host_provider_(top_host_provider),
model_and_features_store_(std::move(model_and_features_store)),
url_loader_factory_(url_loader_factory),
pref_service_(pref_service),
profile_(profile),
clock_(base::DefaultClock::GetInstance()) {
Initialize(optimization_targets_at_initialization);
}
PredictionManager::~PredictionManager() {
g_browser_process->network_quality_tracker()
->RemoveEffectiveConnectionTypeObserver(this);
}
void PredictionManager::Initialize(const std::vector<proto::OptimizationTarget>&
optimization_targets_at_initialization) {
RegisterOptimizationTargets(optimization_targets_at_initialization);
g_browser_process->network_quality_tracker()
->AddEffectiveConnectionTypeObserver(this);
model_and_features_store_->Initialize(
switches::ShouldPurgeModelAndFeaturesStoreOnStartup(),
base::BindOnce(&PredictionManager::OnStoreInitialized,
ui_weak_ptr_factory_.GetWeakPtr()));
}
void PredictionManager::UpdateFCPSessionStatistics(base::TimeDelta fcp) {
previous_load_fcp_ms_ = static_cast<float>(fcp.InMilliseconds());
session_fcp_.AddSample(*previous_load_fcp_ms_);
pref_service_->SetDouble(prefs::kSessionStatisticFCPMean,
session_fcp_.GetMean());
pref_service_->SetDouble(prefs::kSessionStatisticFCPStdDev,
session_fcp_.GetStdDev());
}
void PredictionManager::RegisterOptimizationTargets(
const std::vector<proto::OptimizationTarget>& optimization_targets) {
SEQUENCE_CHECKER(sequence_checker_);
if (optimization_targets.size() == 0)
return;
base::flat_set<proto::OptimizationTarget> new_optimization_targets;
for (const auto& optimization_target : optimization_targets) {
if (optimization_target == proto::OPTIMIZATION_TARGET_UNKNOWN)
continue;
if (registered_optimization_targets_.find(optimization_target) !=
registered_optimization_targets_.end()) {
continue;
}
registered_optimization_targets_.insert(optimization_target);
new_optimization_targets.insert(optimization_target);
}
// Before loading/fetching models and features, the store must be ready.
if (!store_is_ready_)
return;
// Only proceed if there are newly registered targets to load/fetch models and
// features for. Otherwise, the registered targets will have models loaded
// when the store was initialized.
if (new_optimization_targets.size() == 0)
return;
// Start loading the host model features if they are not already.
if (!host_model_features_loaded_) {
LoadHostModelFeatures();
return;
}
// Otherwise, the host model features are loaded, so load prediction models
// for any newly registered targets.
LoadPredictionModels(new_optimization_targets);
}
base::Optional<float> PredictionManager::GetValueForClientFeature(
const std::string& model_feature,
content::NavigationHandle* navigation_handle) const {
SEQUENCE_CHECKER(sequence_checker_);
proto::ClientModelFeature client_model_feature;
if (!proto::ClientModelFeature_Parse(model_feature, &client_model_feature))
return base::nullopt;
base::Optional<float> value;
switch (client_model_feature) {
case proto::CLIENT_MODEL_FEATURE_UNKNOWN: {
return base::nullopt;
}
case proto::CLIENT_MODEL_FEATURE_EFFECTIVE_CONNECTION_TYPE: {
value = static_cast<float>(current_effective_connection_type_);
break;
}
case proto::CLIENT_MODEL_FEATURE_PAGE_TRANSITION: {
value = static_cast<float>(navigation_handle->GetPageTransition());
break;
}
case proto::CLIENT_MODEL_FEATURE_SITE_ENGAGEMENT_SCORE: {
Profile* profile = Profile::FromBrowserContext(
navigation_handle->GetWebContents()->GetBrowserContext());
SiteEngagementService* engagement_service =
SiteEngagementService::Get(profile);
// Precision loss is acceptable/expected for prediction models.
value = static_cast<float>(
engagement_service->GetScore(navigation_handle->GetURL()));
break;
}
case proto::CLIENT_MODEL_FEATURE_SAME_ORIGIN_NAVIGATION: {
OptimizationGuideNavigationData* nav_data =
OptimizationGuideNavigationData::GetFromNavigationHandle(
navigation_handle);
bool is_same_origin = nav_data && nav_data->is_same_origin_navigation();
LOCAL_HISTOGRAM_BOOLEAN(
"OptimizationGuide.PredictionManager.IsSameOrigin", is_same_origin);
value = static_cast<float>(is_same_origin);
break;
}
case proto::CLIENT_MODEL_FEATURE_FIRST_CONTENTFUL_PAINT_SESSION_MEAN: {
value = session_fcp_.GetNumberOfSamples() == 0
? static_cast<float>(pref_service_->GetDouble(
prefs::kSessionStatisticFCPMean))
: session_fcp_.GetMean();
break;
}
case proto::
CLIENT_MODEL_FEATURE_FIRST_CONTENTFUL_PAINT_SESSION_STANDARD_DEVIATION: {
value = session_fcp_.GetNumberOfSamples() == 0
? static_cast<float>(pref_service_->GetDouble(
prefs::kSessionStatisticFCPStdDev))
: session_fcp_.GetStdDev();
break;
}
case proto::
CLIENT_MODEL_FEATURE_FIRST_CONTENTFUL_PAINT_PREVIOUS_PAGE_LOAD: {
value = previous_load_fcp_ms_.value_or(static_cast<float>(
pref_service_->GetDouble(prefs::kSessionStatisticFCPMean)));
break;
}
default: {
return base::nullopt;
}
}
OptimizationGuideNavigationData* navigation_data =
OptimizationGuideNavigationData::GetFromNavigationHandle(
navigation_handle);
if (value && navigation_data) {
navigation_data->SetValueForModelFeature(client_model_feature, *value);
return value;
}
return base::nullopt;
}
base::flat_map<std::string, float> PredictionManager::BuildFeatureMap(
content::NavigationHandle* navigation_handle,
const base::flat_set<std::string>& model_features) {
SEQUENCE_CHECKER(sequence_checker_);
base::flat_map<std::string, float> feature_map;
if (model_features.size() == 0)
return feature_map;
const base::flat_map<std::string, float>* host_model_features = nullptr;
std::string host = navigation_handle->GetURL().host();
auto it = host_model_features_cache_.Get(host);
if (it != host_model_features_cache_.end())
host_model_features = &(it->second);
UMA_HISTOGRAM_BOOLEAN(
"OptimizationGuide.PredictionManager.HasHostModelFeaturesForHost",
host_model_features != nullptr);
// If the feature is not implemented by the client, it is assumed that it is a
// host model feature we have in the map. If it is not in either, a default is
// created for it. This ensures that the prediction model will have values for
// every feature that it requires to be evaluated.
for (const auto& model_feature : model_features) {
base::Optional<float> value =
GetValueForClientFeature(model_feature, navigation_handle);
if (value) {
feature_map[model_feature] = *value;
continue;
}
if (!host_model_features || !host_model_features->contains(model_feature)) {
feature_map[model_feature] = -1.0;
continue;
}
feature_map[model_feature] =
host_model_features->find(model_feature)->second;
}
return feature_map;
}
OptimizationTargetDecision PredictionManager::ShouldTargetNavigation(
content::NavigationHandle* navigation_handle,
proto::OptimizationTarget optimization_target) {
SEQUENCE_CHECKER(sequence_checker_);
DCHECK(navigation_handle->GetURL().SchemeIsHTTPOrHTTPS());
OptimizationGuideNavigationData* navigation_data =
OptimizationGuideNavigationData::GetFromNavigationHandle(
navigation_handle);
if (navigation_data) {
base::Optional<optimization_guide::OptimizationTargetDecision>
optimization_target_decision =
navigation_data->GetDecisionForOptimizationTarget(
optimization_target);
if (optimization_target_decision.has_value() &&
ShouldUseCurrentOptimizationTargetDecision(
*optimization_target_decision)) {
return *optimization_target_decision;
}
}
if (!registered_optimization_targets_.contains(optimization_target))
return OptimizationTargetDecision::kUnknown;
ScopedPredictionManagerModelStatusRecorder model_status_recorder(
optimization_target);
auto it = optimization_target_prediction_model_map_.find(optimization_target);
if (it == optimization_target_prediction_model_map_.end()) {
if (store_is_ready_ && model_and_features_store_) {
OptimizationGuideStore::EntryKey model_entry_key;
if (model_and_features_store_->FindPredictionModelEntryKey(
optimization_target, &model_entry_key)) {
model_status_recorder.set_status(
PredictionManagerModelStatus::kStoreAvailableModelNotLoaded);
} else {
model_status_recorder.set_status(
PredictionManagerModelStatus::kStoreAvailableNoModelForTarget);
}
} else {
model_status_recorder.set_status(
PredictionManagerModelStatus::kStoreUnavailableModelUnknown);
}
return OptimizationTargetDecision::kModelNotAvailableOnClient;
}
model_status_recorder.set_status(
PredictionManagerModelStatus::kModelAvailable);
PredictionModel* prediction_model = it->second.get();
base::flat_map<std::string, float> feature_map =
BuildFeatureMap(navigation_handle, prediction_model->GetModelFeatures());
base::TimeTicks model_evaluation_start_time = base::TimeTicks::Now();
double prediction_score = 0.0;
optimization_guide::OptimizationTargetDecision target_decision =
prediction_model->Predict(feature_map, &prediction_score);
if (target_decision != OptimizationTargetDecision::kUnknown) {
UmaHistogramTimes(
"OptimizationGuide.PredictionModelEvaluationLatency." +
GetStringNameForOptimizationTarget(optimization_target),
base::TimeTicks::Now() - model_evaluation_start_time);
}
if (navigation_data) {
navigation_data->SetModelVersionForOptimizationTarget(
optimization_target, prediction_model->GetVersion());
navigation_data->SetModelPredictionScoreForOptimizationTarget(
optimization_target, prediction_score);
}
if (optimization_guide::features::
ShouldOverrideOptimizationTargetDecisionForMetricsPurposes(
optimization_target)) {
return optimization_guide::OptimizationTargetDecision::
kModelPredictionHoldback;
}
return target_decision;
}
void PredictionManager::OnEffectiveConnectionTypeChanged(
net::EffectiveConnectionType effective_connection_type) {
SEQUENCE_CHECKER(sequence_checker_);
current_effective_connection_type_ = effective_connection_type;
}
PredictionModel* PredictionManager::GetPredictionModelForTesting(
proto::OptimizationTarget optimization_target) const {
auto it = optimization_target_prediction_model_map_.find(optimization_target);
if (it != optimization_target_prediction_model_map_.end())
return it->second.get();
return nullptr;
}
const HostModelFeaturesMRUCache*
PredictionManager::GetHostModelFeaturesForTesting() const {
return &host_model_features_cache_;
}
void PredictionManager::SetPredictionModelFetcherForTesting(
std::unique_ptr<PredictionModelFetcher> prediction_model_fetcher) {
prediction_model_fetcher_ = std::move(prediction_model_fetcher);
}
void PredictionManager::FetchModelsAndHostModelFeatures() {
SEQUENCE_CHECKER(sequence_checker_);
if (!IsUserPermittedToFetchFromRemoteOptimizationGuide(profile_))
return;
ScheduleModelsAndHostModelFeaturesFetch();
// Models and host model features should not be fetched if there are no
// optimization targets registered.
if (registered_optimization_targets_.size() == 0)
return;
std::vector<std::string> top_hosts;
// If the top host provider is not available, the user has likely not seen the
// Lite mode infobar, so top hosts cannot be provided. However, prediction
// models are allowed to be fetched.
if (top_host_provider_) {
top_hosts = top_host_provider_->GetTopHosts();
// Remove hosts that are already available in the host model features cache.
// The request should still be made in case there is a new model or a model
// that does not rely on host model features to be fetched.
auto it = top_hosts.begin();
while (it != top_hosts.end()) {
if (host_model_features_cache_.Peek(*it) !=
host_model_features_cache_.end()) {
it = top_hosts.erase(it);
continue;
}
++it;
}
}
if (!prediction_model_fetcher_) {
prediction_model_fetcher_ = std::make_unique<PredictionModelFetcher>(
url_loader_factory_,
features::GetOptimizationGuideServiceGetModelsURL());
}
std::vector<proto::ModelInfo> models_info = std::vector<proto::ModelInfo>();
proto::ModelInfo base_model_info;
for (auto client_model_feature = proto::ClientModelFeature_MIN + 1;
client_model_feature <= proto::ClientModelFeature_MAX;
client_model_feature++) {
if (proto::ClientModelFeature_IsValid(client_model_feature)) {
base_model_info.add_supported_model_features(
static_cast<proto::ClientModelFeature>(client_model_feature));
}
}
// Only Decision Trees are currently supported.
base_model_info.add_supported_model_types(proto::MODEL_TYPE_DECISION_TREE);
// For now, we will fetch for all registered optimization targets.
for (const auto& optimization_target : registered_optimization_targets_) {
proto::ModelInfo model_info(base_model_info);
model_info.set_optimization_target(optimization_target);
auto it =
optimization_target_prediction_model_map_.find(optimization_target);
if (it != optimization_target_prediction_model_map_.end())
model_info.set_version(it->second.get()->GetVersion());
models_info.push_back(model_info);
}
prediction_model_fetcher_->FetchOptimizationGuideServiceModels(
models_info, top_hosts, optimization_guide::proto::CONTEXT_BATCH_UPDATE,
base::BindOnce(&PredictionManager::OnModelsAndHostFeaturesFetched,
ui_weak_ptr_factory_.GetWeakPtr()));
}
void PredictionManager::OnModelsAndHostFeaturesFetched(
base::Optional<std::unique_ptr<proto::GetModelsResponse>>
get_models_response_data) {
SEQUENCE_CHECKER(sequence_checker_);
if (!get_models_response_data)
return;
// Update host model features, even if empty so the store metadata
// that contains the update time for new models and features to be fetched
// from the remote Optimization Guide Service is updated.
UpdateHostModelFeatures((*get_models_response_data)->host_model_features());
if ((*get_models_response_data)->models_size() > 0) {
// Stash the response so the models can be stored once the host
// model features are stored.
get_models_response_data_to_store_ = std::move(*get_models_response_data);
}
fetch_timer_.Stop();
fetch_timer_.Start(
FROM_HERE, kUpdateModelsAndFeaturesDelay, this,
&PredictionManager::ScheduleModelsAndHostModelFeaturesFetch);
}
void PredictionManager::UpdateHostModelFeatures(
const google::protobuf::RepeatedPtrField<proto::HostModelFeatures>&
host_model_features) {
SEQUENCE_CHECKER(sequence_checker_);
std::unique_ptr<StoreUpdateData> host_model_features_update_data =
StoreUpdateData::CreateHostModelFeaturesStoreUpdateData(
/*update_time=*/clock_->Now() + kUpdateModelsAndFeaturesDelay,
/*expiry_time=*/clock_->Now() +
features::StoredHostModelFeaturesFreshnessDuration());
for (const auto& host_model_features : host_model_features) {
if (ProcessAndStoreHostModelFeatures(host_model_features)) {
host_model_features_update_data->CopyHostModelFeaturesIntoUpdateData(
host_model_features);
}
}
model_and_features_store_->UpdateHostModelFeatures(
std::move(host_model_features_update_data),
base::BindOnce(&PredictionManager::OnHostModelFeaturesStored,
ui_weak_ptr_factory_.GetWeakPtr()));
}
std::unique_ptr<PredictionModel> PredictionManager::CreatePredictionModel(
const proto::PredictionModel& model) const {
SEQUENCE_CHECKER(sequence_checker_);
return PredictionModel::Create(
std::make_unique<proto::PredictionModel>(model));
}
void PredictionManager::UpdatePredictionModels(
const google::protobuf::RepeatedPtrField<proto::PredictionModel>&
prediction_models) {
SEQUENCE_CHECKER(sequence_checker_);
std::unique_ptr<StoreUpdateData> prediction_model_update_data =
StoreUpdateData::CreatePredictionModelStoreUpdateData();
bool models_to_store = false;
for (const auto& model : prediction_models) {
if (ProcessAndStorePredictionModel(model)) {
prediction_model_update_data->CopyPredictionModelIntoUpdateData(model);
models_to_store = true;
base::UmaHistogramSparse(
"OptimizationGuide.PredictionModelUpdateVersion." +
GetStringNameForOptimizationTarget(
model.model_info().optimization_target()),
model.model_info().version());
}
}
if (models_to_store) {
model_and_features_store_->UpdatePredictionModels(
std::move(prediction_model_update_data),
base::BindOnce(&PredictionManager::OnPredictionModelsStored,
ui_weak_ptr_factory_.GetWeakPtr()));
}
}
void PredictionManager::OnPredictionModelsStored() {
SEQUENCE_CHECKER(sequence_checker_);
LOCAL_HISTOGRAM_BOOLEAN(
"OptimizationGuide.PredictionManager.PredictionModelsStored", true);
}
void PredictionManager::OnHostModelFeaturesStored() {
SEQUENCE_CHECKER(sequence_checker_);
LOCAL_HISTOGRAM_BOOLEAN(
"OptimizationGuide.PredictionManager.HostModelFeaturesStored", true);
if (get_models_response_data_to_store_ &&
get_models_response_data_to_store_->models_size() > 0) {
UpdatePredictionModels(get_models_response_data_to_store_->models());
}
// Clear any data remaining in the stored get models response.
get_models_response_data_to_store_.reset();
// Purge any expired host model features from the store.
model_and_features_store_->PurgeExpiredHostModelFeatures();
fetch_timer_.Stop();
ScheduleModelsAndHostModelFeaturesFetch();
}
void PredictionManager::OnStoreInitialized() {
SEQUENCE_CHECKER(sequence_checker_);
store_is_ready_ = true;
// Only load host model features if there are optimization targets registered.
if (registered_optimization_targets_.size() == 0)
return;
// The store is ready so start loading host model features and the models for
// the registered optimization targets. Once the host model features are
// loaded, prediction models for the registered optimization targets will be
// loaded.
LoadHostModelFeatures();
MaybeScheduleModelAndHostModelFeaturesFetch();
}
void PredictionManager::LoadHostModelFeatures() {
SEQUENCE_CHECKER(sequence_checker_);
// Load the host model features first, each prediction model requires the set
// of host model features to be known before creation.
model_and_features_store_->LoadAllHostModelFeatures(
base::BindOnce(&PredictionManager::OnLoadHostModelFeatures,
ui_weak_ptr_factory_.GetWeakPtr()));
}
void PredictionManager::OnLoadHostModelFeatures(
std::unique_ptr<std::vector<proto::HostModelFeatures>>
all_host_model_features) {
SEQUENCE_CHECKER(sequence_checker_);
// If the store returns an empty vector of host model features, the store
// contains no host model features. However, the load is otherwise complete
// and prediction models can be loaded but they will require no host model
// feature information.
host_model_features_loaded_ = true;
if (all_host_model_features) {
for (const auto& host_model_features : *all_host_model_features)
ProcessAndStoreHostModelFeatures(host_model_features);
}
UMA_HISTOGRAM_COUNTS_1000(
"OptimizationGuide.PredictionManager.HostModelFeaturesMapSize",
host_model_features_cache_.size());
// Load the prediction models for all the registered optimization targets now
// that it is not blocked by loading the host model features.
LoadPredictionModels(registered_optimization_targets_);
}
void PredictionManager::LoadPredictionModels(
const base::flat_set<proto::OptimizationTarget>& optimization_targets) {
SEQUENCE_CHECKER(sequence_checker_);
DCHECK(host_model_features_loaded_);
OptimizationGuideStore::EntryKey model_entry_key;
for (const auto& optimization_target : optimization_targets) {
// The prediction model for this optimization target has already been
// loaded.
if (optimization_target_prediction_model_map_.contains(optimization_target))
continue;
if (!model_and_features_store_->FindPredictionModelEntryKey(
optimization_target, &model_entry_key)) {
continue;
}
model_and_features_store_->LoadPredictionModel(
model_entry_key,
base::BindOnce(&PredictionManager::OnLoadPredictionModel,
ui_weak_ptr_factory_.GetWeakPtr()));
}
}
void PredictionManager::OnLoadPredictionModel(
std::unique_ptr<proto::PredictionModel> model) {
SEQUENCE_CHECKER(sequence_checker_);
if (!model)
return;
if (ProcessAndStorePredictionModel(*model)) {
base::UmaHistogramSparse("OptimizationGuide.PredictionModelLoadedVersion." +
GetStringNameForOptimizationTarget(
model->model_info().optimization_target()),
model->model_info().version());
}
}
bool PredictionManager::ProcessAndStorePredictionModel(
const proto::PredictionModel& model) {
SEQUENCE_CHECKER(sequence_checker_);
if (!model.model_info().has_optimization_target())
return false;
if (!model.has_model())
return false;
if (!registered_optimization_targets_.contains(
model.model_info().optimization_target())) {
return false;
}
ScopedPredictionModelConstructionAndValidationRecorder
prediction_model_recorder(model.model_info().optimization_target());
std::unique_ptr<PredictionModel> prediction_model =
CreatePredictionModel(model);
if (!prediction_model) {
prediction_model_recorder.set_is_valid(false);
return false;
}
auto it = optimization_target_prediction_model_map_.find(
model.model_info().optimization_target());
if (it == optimization_target_prediction_model_map_.end()) {
optimization_target_prediction_model_map_.emplace(
model.model_info().optimization_target(), std::move(prediction_model));
return true;
}
if (it->second->GetVersion() != prediction_model->GetVersion()) {
it->second = std::move(prediction_model);
return true;
}
return false;
}
bool PredictionManager::ProcessAndStoreHostModelFeatures(
const proto::HostModelFeatures& host_model_features) {
SEQUENCE_CHECKER(sequence_checker_);
if (!host_model_features.has_host())
return false;
if (host_model_features.model_features_size() == 0)
return false;
base::flat_map<std::string, float> model_features_for_host;
model_features_for_host.reserve(host_model_features.model_features_size());
for (const auto& model_feature : host_model_features.model_features()) {
if (!model_feature.has_feature_name())
continue;
switch (model_feature.feature_value_case()) {
case proto::ModelFeature::kDoubleValue:
// Loss of precision from double is acceptable for features supported
// by the prediction models.
model_features_for_host.emplace(
model_feature.feature_name(),
static_cast<float>(model_feature.double_value()));
break;
case proto::ModelFeature::kInt64Value:
model_features_for_host.emplace(
model_feature.feature_name(),
static_cast<float>(model_feature.int64_value()));
break;
case proto::ModelFeature::FEATURE_VALUE_NOT_SET:
NOTREACHED();
break;
}
}
if (model_features_for_host.size() == 0)
return false;
host_model_features_cache_.Put(host_model_features.host(),
model_features_for_host);
return true;
}
void PredictionManager::MaybeScheduleModelAndHostModelFeaturesFetch() {
if (optimization_guide::switches::
ShouldOverrideFetchModelsAndFeaturesTimer()) {
SetLastModelAndFeaturesFetchAttemptTime(clock_->Now());
fetch_timer_.Start(FROM_HERE, base::TimeDelta::FromSeconds(1), this,
&PredictionManager::FetchModelsAndHostModelFeatures);
} else {
ScheduleModelsAndHostModelFeaturesFetch();
}
}
base::Time PredictionManager::GetLastFetchAttemptTime() const {
SEQUENCE_CHECKER(squence_checker_);
return base::Time::FromDeltaSinceWindowsEpoch(
base::TimeDelta::FromMicroseconds(
pref_service_->GetInt64(prefs::kModelAndFeaturesLastFetchAttempt)));
}
void PredictionManager::ScheduleModelsAndHostModelFeaturesFetch() {
DCHECK(!fetch_timer_.IsRunning());
DCHECK(store_is_ready_);
const base::TimeDelta time_until_update_time =
model_and_features_store_->GetHostModelFeaturesUpdateTime() -
clock_->Now();
const base::TimeDelta time_until_retry =
GetLastFetchAttemptTime() + kFetchRetryDelay - clock_->Now();
base::TimeDelta fetcher_delay =
std::max(time_until_update_time, time_until_retry);
if (fetcher_delay <= base::TimeDelta()) {
SetLastModelAndFeaturesFetchAttemptTime(clock_->Now());
fetch_timer_.Start(FROM_HERE, RandomFetchDelay(), this,
&PredictionManager::FetchModelsAndHostModelFeatures);
return;
}
fetch_timer_.Start(
FROM_HERE, fetcher_delay, this,
&PredictionManager::ScheduleModelsAndHostModelFeaturesFetch);
}
void PredictionManager::SetLastModelAndFeaturesFetchAttemptTime(
base::Time last_attempt_time) {
SEQUENCE_CHECKER(sequence_checker_);
pref_service_->SetInt64(
prefs::kModelAndFeaturesLastFetchAttempt,
last_attempt_time.ToDeltaSinceWindowsEpoch().InMicroseconds());
}
void PredictionManager::SetClockForTesting(const base::Clock* clock) {
clock_ = clock;
}
void PredictionManager::ClearHostModelFeatures() {
host_model_features_cache_.Clear();
if (model_and_features_store_)
model_and_features_store_->ClearHostModelFeaturesFromDatabase();
}
base::Optional<base::flat_map<std::string, float>>
PredictionManager::GetHostModelFeaturesForHost(const std::string& host) const {
auto it = host_model_features_cache_.Peek(host);
if (it == host_model_features_cache_.end())
return base::nullopt;
return it->second;
}
} // namespace optimization_guide