| // Copyright 2019 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "chrome/browser/optimization_guide/prediction/prediction_manager.h" |
| |
| #include <memory> |
| #include <utility> |
| |
| #include "base/containers/flat_map.h" |
| #include "base/containers/flat_set.h" |
| #include "base/containers/flat_tree.h" |
| #include "base/metrics/histogram_functions.h" |
| #include "base/metrics/histogram_macros.h" |
| #include "base/metrics/histogram_macros_local.h" |
| #include "base/path_service.h" |
| #include "base/rand_util.h" |
| #include "base/sequence_checker.h" |
| #include "base/sequenced_task_runner.h" |
| #include "base/task/post_task.h" |
| #include "base/task/thread_pool.h" |
| #include "base/time/default_clock.h" |
| #include "chrome/browser/browser_process.h" |
| #include "chrome/browser/download/background_download_service_factory.h" |
| #include "chrome/browser/optimization_guide/prediction/prediction_model_download_manager.h" |
| #include "chrome/browser/profiles/profile.h" |
| #include "chrome/browser/profiles/profile_key.h" |
| #include "chrome/common/chrome_paths.h" |
| #include "components/optimization_guide/content/browser/optimization_guide_decider.h" |
| #include "components/optimization_guide/core/model_info.h" |
| #include "components/optimization_guide/core/optimization_guide_constants.h" |
| #include "components/optimization_guide/core/optimization_guide_enums.h" |
| #include "components/optimization_guide/core/optimization_guide_features.h" |
| #include "components/optimization_guide/core/optimization_guide_permissions_util.h" |
| #include "components/optimization_guide/core/optimization_guide_prefs.h" |
| #include "components/optimization_guide/core/optimization_guide_store.h" |
| #include "components/optimization_guide/core/optimization_guide_switches.h" |
| #include "components/optimization_guide/core/optimization_guide_util.h" |
| #include "components/optimization_guide/core/optimization_target_model_observer.h" |
| #include "components/optimization_guide/core/prediction_model.h" |
| #include "components/optimization_guide/core/prediction_model_fetcher.h" |
| #include "components/optimization_guide/core/store_update_data.h" |
| #include "components/optimization_guide/proto/models.pb.h" |
| #include "components/prefs/pref_service.h" |
| #include "components/site_engagement/content/site_engagement_service.h" |
| #include "content/public/browser/navigation_handle.h" |
| #include "content/public/browser/network_service_instance.h" |
| #include "content/public/browser/web_contents.h" |
| #include "mojo/public/cpp/bindings/remote.h" |
| #include "services/network/public/cpp/shared_url_loader_factory.h" |
| |
| namespace { |
| |
| // Returns the randomized delay to wait before fetching models and host model |
| // features. The delay is a whole number of seconds chosen by base::RandInt() |
| // from the min/max bounds supplied by the optimization guide feature params. |
| base::TimeDelta RandomFetchDelay() { |
|   const int min_delay_secs = |
|       optimization_guide::features::PredictionModelFetchRandomMinDelaySecs(); |
|   const int max_delay_secs = |
|       optimization_guide::features::PredictionModelFetchRandomMaxDelaySecs(); |
|   return base::Seconds(base::RandInt(min_delay_secs, max_delay_secs)); |
| } |
| |
| // Helper that reports the status of the prediction model for one optimization |
| // target. The status supplied through set_status() is written to UMA, both to |
| // the aggregate histogram and to the per-target one, when the recorder is |
| // destroyed. A status other than kUnknown is expected to have been set by then |
| // (checked with a DCHECK). |
| class ScopedPredictionManagerModelStatusRecorder { |
|  public: |
|   explicit ScopedPredictionManagerModelStatusRecorder( |
|       optimization_guide::proto::OptimizationTarget optimization_target) |
|       : optimization_target_(optimization_target) {} |
| |
|   ~ScopedPredictionManagerModelStatusRecorder() { |
|     using optimization_guide::PredictionManagerModelStatus; |
|     DCHECK_NE(status_, PredictionManagerModelStatus::kUnknown); |
| |
|     const std::string target_name = |
|         optimization_guide::GetStringNameForOptimizationTarget( |
|             optimization_target_); |
|     base::UmaHistogramEnumeration( |
|         "OptimizationGuide.ShouldTargetNavigation.PredictionModelStatus", |
|         status_); |
|     base::UmaHistogramEnumeration( |
|         "OptimizationGuide.ShouldTargetNavigation.PredictionModelStatus." + |
|             target_name, |
|         status_); |
|   } |
| |
|   // Sets the status that will be recorded on destruction. |
|   void set_status(optimization_guide::PredictionManagerModelStatus status) { |
|     status_ = status; |
|   } |
| |
|  private: |
|   // Stays kUnknown until set_status() is called. |
|   optimization_guide::PredictionManagerModelStatus status_ = |
|       optimization_guide::PredictionManagerModelStatus::kUnknown; |
|   // Target whose name is used as the per-target histogram suffix. |
|   const optimization_guide::proto::OptimizationTarget optimization_target_; |
| }; |
| |
| // Helper that reports whether a prediction model was valid and, if it was, how |
| // long construction and validation took. Timing starts when the recorder is |
| // constructed; everything is written to UMA (aggregate and per-target |
| // histograms) when the recorder is destroyed. The model is treated as valid |
| // unless set_is_valid(false) is called. |
| class ScopedPredictionModelConstructionAndValidationRecorder { |
|  public: |
|   explicit ScopedPredictionModelConstructionAndValidationRecorder( |
|       optimization_guide::proto::OptimizationTarget optimization_target) |
|       : validation_start_time_(base::TimeTicks::Now()), |
|         optimization_target_(optimization_target) {} |
| |
|   ~ScopedPredictionModelConstructionAndValidationRecorder() { |
|     const std::string target_name = |
|         optimization_guide::GetStringNameForOptimizationTarget( |
|             optimization_target_); |
| |
|     base::UmaHistogramBoolean("OptimizationGuide.IsPredictionModelValid", |
|                               is_valid_); |
|     base::UmaHistogramBoolean( |
|         "OptimizationGuide.IsPredictionModelValid." + target_name, is_valid_); |
| |
|     // Latency is only meaningful for a model that was actually constructed. |
|     if (!is_valid_) |
|       return; |
| |
|     const base::TimeDelta elapsed = |
|         base::TimeTicks::Now() - validation_start_time_; |
|     base::UmaHistogramTimes( |
|         "OptimizationGuide.PredictionModelValidationLatency", elapsed); |
|     base::UmaHistogramTimes( |
|         "OptimizationGuide.PredictionModelValidationLatency." + target_name, |
|         elapsed); |
|   } |
| |
|   // Overrides the validity value that will be recorded on destruction. |
|   void set_is_valid(bool is_valid) { is_valid_ = is_valid; } |
| |
|  private: |
|   bool is_valid_ = true; |
|   // Captured at construction; the reference point for the latency histograms. |
|   const base::TimeTicks validation_start_time_; |
|   const optimization_guide::proto::OptimizationTarget optimization_target_; |
| }; |
| |
| // Records the version of |model_info| to the sparse per-target |
| // "PredictionModelUpdateVersion" histogram. |
| void RecordModelUpdateVersion( |
|     const optimization_guide::proto::ModelInfo& model_info) { |
|   const std::string histogram_name = |
|       "OptimizationGuide.PredictionModelUpdateVersion." + |
|       optimization_guide::GetStringNameForOptimizationTarget( |
|           model_info.optimization_target()); |
|   base::UmaHistogramSparse(histogram_name, model_info.version()); |
| } |
| |
| // Records |changed| to the per-target "ModelTypeChanged" boolean histogram for |
| // |optimization_target|. |
| void RecordModelTypeChanged( |
|     optimization_guide::proto::OptimizationTarget optimization_target, |
|     bool changed) { |
|   const std::string histogram_name = |
|       "OptimizationGuide.PredictionManager.ModelTypeChanged." + |
|       optimization_guide::GetStringNameForOptimizationTarget( |
|           optimization_target); |
|   base::UmaHistogramBoolean(histogram_name, changed); |
| } |
| |
| // Returns true when models and host model features may be requested from the |
| // remote Optimization Guide Service: remote fetching must be enabled and |
| // |profile| must not be off the record. |
| bool ShouldFetchModels(Profile* profile) { |
|   if (!optimization_guide::features::IsRemoteFetchingEnabled()) |
|     return false; |
|   return !profile->IsOffTheRecord(); |
| } |
| |
| // Builds a PredictionModel proto for |optimization_target| from the model |
| // override returned by GetModelOverrideForOptimizationTarget(). Returns |
| // nullptr when no override exists for the target. The resulting proto carries |
| // the target, a fixed version of 123, the override's metadata (if any), and |
| // the override's file path stored in the model's download_url field. |
| std::unique_ptr<optimization_guide::proto::PredictionModel> |
| BuildPredictionModelFromCommandLineForOptimizationTarget( |
|     optimization_guide::proto::OptimizationTarget optimization_target) { |
|   using ModelOverride = |
|       std::pair<std::string, absl::optional<optimization_guide::proto::Any>>; |
| |
|   const absl::optional<ModelOverride> model_override = |
|       optimization_guide::GetModelOverrideForOptimizationTarget( |
|           optimization_target); |
|   if (!model_override.has_value()) |
|     return nullptr; |
| |
|   const std::string& model_file_path = model_override->first; |
|   const absl::optional<optimization_guide::proto::Any>& metadata = |
|       model_override->second; |
| |
|   auto model = std::make_unique<optimization_guide::proto::PredictionModel>(); |
|   optimization_guide::proto::ModelInfo* model_info = |
|       model->mutable_model_info(); |
|   model_info->set_optimization_target(optimization_target); |
|   model_info->set_version(123); |
|   if (metadata.has_value()) |
|     *model_info->mutable_model_metadata() = *metadata; |
|   model->mutable_model()->set_download_url(model_file_path); |
|   return model; |
| } |
| |
| } // namespace |
| |
| namespace optimization_guide { |
| |
| // Bundles the state that accompanies a single prediction decision: the target |
| // being evaluated, the callback that receives the decision, the model version |
| // and the time at which model evaluation began. Instances are not copyable. |
| struct PredictionDecisionParams { |
|   PredictionDecisionParams(proto::OptimizationTarget optimization_target, |
|                            OptimizationTargetDecisionCallback callback, |
|                            int64_t version, |
|                            base::TimeTicks model_evaluation_start_time) |
|       : optimization_target(optimization_target), |
|         callback(std::move(callback)), |
|         version(version), |
|         model_evaluation_start_time(model_evaluation_start_time) {} |
| |
|   PredictionDecisionParams(const PredictionDecisionParams&) = delete; |
|   PredictionDecisionParams& operator=(const PredictionDecisionParams&) = delete; |
| |
|   ~PredictionDecisionParams() = default; |
| |
|   // The optimization target the prediction is for. |
|   proto::OptimizationTarget optimization_target; |
|   // Invoked once an OptimizationTargetDecision has been made. |
|   OptimizationTargetDecisionCallback callback; |
|   // Version of the model. |
|   int64_t version; |
|   // When evaluation of the model was started. |
|   base::TimeTicks model_evaluation_start_time; |
| }; |
| |
| // Constructs the manager and immediately starts initializing |
| // |model_and_features_store|, which must be non-null. |
| // |
| // |model_and_features_store|, |pref_service| and |profile| are stored as raw |
| // pointers; |url_loader_factory| is retained for creating the model fetcher. |
| // The host model features cache is sized from the feature param, with a |
| // minimum capacity of one entry. |
| PredictionManager::PredictionManager( |
|     OptimizationGuideStore* model_and_features_store, |
|     scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory, |
|     PrefService* pref_service, |
|     Profile* profile) |
|     : host_model_features_cache_( |
|           std::max(features::MaxHostModelFeaturesCacheSize(), size_t(1))), |
|       prediction_model_download_manager_(nullptr), |
|       model_and_features_store_(model_and_features_store), |
|       // The parameter is taken by value, so move it into the member rather |
|       // than copying it and paying for an extra ref increment/decrement. |
|       url_loader_factory_(std::move(url_loader_factory)), |
|       pref_service_(pref_service), |
|       profile_(profile), |
|       clock_(base::DefaultClock::GetInstance()) { |
|   DCHECK(model_and_features_store_); |
| |
|   Initialize(); |
| } |
| |
| // Stops observing the download manager, if one exists. |
| PredictionManager::~PredictionManager() { |
|   if (!prediction_model_download_manager_) |
|     return; |
|   prediction_model_download_manager_->RemoveObserver(this); |
| } |
| |
| // Starts initialization of the model and features store, purging it first if |
| // the corresponding switch asks for that. OnStoreInitialized() is bound to a |
| // weak pointer and passed to the store as the completion callback. |
| void PredictionManager::Initialize() { |
|   const bool purge_existing_data = |
|       switches::ShouldPurgeModelAndFeaturesStoreOnStartup(); |
|   auto on_store_initialized = |
|       base::BindOnce(&PredictionManager::OnStoreInitialized, |
|                      ui_weak_ptr_factory_.GetWeakPtr()); |
|   model_and_features_store_->Initialize(purge_existing_data, |
|                                         std::move(on_store_initialized)); |
| } |
| |
| // Registers each (target, metadata) pair that is neither |
| // OPTIMIZATION_TARGET_UNKNOWN nor already registered. If the store is ready |
| // and at least one target was newly registered, this may schedule a model |
| // fetch and then either starts loading host model features (when they are not |
| // loaded yet) or loads the prediction models for the new targets. |
| void PredictionManager::RegisterOptimizationTargets( |
|     const std::vector< |
|         std::pair<proto::OptimizationTarget, absl::optional<proto::Any>>>& |
|         optimization_targets_and_metadata) { |
|   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
|   if (optimization_targets_and_metadata.empty()) |
|     return; |
| |
|   base::flat_set<proto::OptimizationTarget> newly_registered_targets; |
|   for (const auto& target_and_metadata : optimization_targets_and_metadata) { |
|     const proto::OptimizationTarget target = target_and_metadata.first; |
|     const bool is_unknown_target = |
|         target == proto::OPTIMIZATION_TARGET_UNKNOWN; |
|     if (is_unknown_target || |
|         registered_optimization_targets_and_metadata_.contains(target)) { |
|       continue; |
|     } |
| |
|     registered_optimization_targets_and_metadata_.emplace( |
|         target_and_metadata); |
|     newly_registered_targets.insert(target); |
|     if (switches::IsDebugLogsEnabled()) { |
|       DVLOG(0) << "OptimizationGuide: Registered new OptimizationTarget: " |
|                << proto::OptimizationTarget_Name(target); |
|     } |
|   } |
| |
|   // Nothing can be loaded or fetched until the store is ready, and targets |
|   // registered before that point are handled when the store initializes. If |
|   // no target is new there is nothing further to do either. |
|   if (!store_is_ready_ || newly_registered_targets.empty()) |
|     return; |
| |
|   // Make sure a fetch is pending, if one is allowed. |
|   if (!fetch_timer_.IsRunning()) |
|     MaybeScheduleModelFetch(); |
| |
|   if (host_model_features_loaded_) { |
|     // Host model features are already available, so the models for the new |
|     // targets can be loaded right away. |
|     LoadPredictionModels(newly_registered_targets); |
|   } else { |
|     // Host model features have to be loaded first. |
|     LoadHostModelFeatures(); |
|   } |
| } |
| |
| // Adds |observer| for |optimization_target| and registers the target together |
| // with |model_metadata|. Only one observer registration per target is |
| // accepted: a second attempt hits a DCHECK and, where DCHECKs are compiled |
| // out, is logged and ignored. If model info for the target is already known, |
| // |observer| is notified of it immediately. |
| void PredictionManager::AddObserverForOptimizationTargetModel( |
|     proto::OptimizationTarget optimization_target, |
|     const absl::optional<proto::Any>& model_metadata, |
|     OptimizationTargetModelObserver* observer) { |
|   const bool target_already_observed = |
|       registered_observers_for_optimization_targets_.find( |
|           optimization_target) != |
|       registered_observers_for_optimization_targets_.end(); |
|   DCHECK(!target_already_observed); |
| |
|   // DCHECKs do not fire in the wild, so bail out explicitly instead of |
|   // registering a second observer for the target; otherwise file reads may |
|   // blow up. |
|   if (target_already_observed) { |
|     DLOG(ERROR) << "Did not add observer for optimization target " |
|                 << static_cast<int>(optimization_target) |
|                 << " since an observer for the target was already registered"; |
|     return; |
|   } |
| |
|   registered_observers_for_optimization_targets_[optimization_target] |
|       .AddObserver(observer); |
|   if (switches::IsDebugLogsEnabled()) { |
|     DVLOG(0) << "OptimizationGuide: Observer added for OptimizationTarget: " |
|              << proto::OptimizationTarget_Name(optimization_target); |
|   } |
| |
|   // Tell the new observer about the model that is already available, if any. |
|   const auto known_model_it = |
|       optimization_target_model_info_map_.find(optimization_target); |
|   if (known_model_it != optimization_target_model_info_map_.end()) { |
|     const auto& known_model_info = *known_model_it->second; |
|     observer->OnModelUpdated(optimization_target, known_model_info); |
|     if (switches::IsDebugLogsEnabled()) { |
|       DVLOG(0) |
|           << "OptimizationGuide: OnModelFileUpdated for OptimizationTarget: " |
|           << proto::OptimizationTarget_Name(optimization_target) |
|           << "\nFile path: " |
|           << known_model_info.GetModelFilePath().AsUTF8Unsafe() |
|           << "\nHas metadata: " << (model_metadata ? "True" : "False"); |
|     } |
|   } |
| |
|   RegisterOptimizationTargets({{optimization_target, model_metadata}}); |
| } |
| |
| // Removes |observer| from the observers of |optimization_target|. Does |
| // nothing when no observers were ever added for that target. |
| void PredictionManager::RemoveObserverForOptimizationTargetModel( |
|     proto::OptimizationTarget optimization_target, |
|     OptimizationTargetModelObserver* observer) { |
|   auto target_observers_it = |
|       registered_observers_for_optimization_targets_.find(optimization_target); |
|   if (target_observers_it != |
|       registered_observers_for_optimization_targets_.end()) { |
|     target_observers_it->second.RemoveObserver(observer); |
|   } |
| } |
| |
| // Builds the feature-name -> value map used to evaluate a model for the |
| // navigation in |navigation_handle|. Every name in |model_features| gets an |
| // entry: the value cached for the navigation's host when one exists, and |
| // -1.0 otherwise. Returns an empty map when |model_features| is empty. Also |
| // records whether host model features were cached for the host. |
| base::flat_map<std::string, float> PredictionManager::BuildFeatureMap( |
|     content::NavigationHandle* navigation_handle, |
|     const base::flat_set<std::string>& model_features) { |
|   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
|   if (model_features.empty()) |
|     return {}; |
| |
|   const std::string host = navigation_handle->GetURL().host(); |
|   const auto cache_it = host_model_features_cache_.Get(host); |
|   const base::flat_map<std::string, float>* cached_host_features = |
|       cache_it == host_model_features_cache_.end() ? nullptr |
|                                                    : &cache_it->second; |
| |
|   UMA_HISTOGRAM_BOOLEAN( |
|       "OptimizationGuide.PredictionManager.HasHostModelFeaturesForHost", |
|       cached_host_features != nullptr); |
| |
|   // A feature that is not present in the host's cached features falls back |
|   // to a default, so that the model always receives a value for every |
|   // feature it needs. |
|   std::vector<std::pair<std::string, float>> entries; |
|   entries.reserve(model_features.size()); |
|   for (const std::string& feature_name : model_features) { |
|     float feature_value = -1.0f; |
|     if (cached_host_features) { |
|       const auto found = cached_host_features->find(feature_name); |
|       if (found != cached_host_features->end()) |
|         feature_value = found->second; |
|     } |
|     entries.emplace_back(feature_name, feature_value); |
|   } |
|   // |model_features| is a flat_set, so |entries| was produced in sorted order |
|   // without duplicates. |
|   return {base::sorted_unique, std::move(entries)}; |
| } |
| |
| // Evaluates the prediction model for |optimization_target| against the |
| // navigation in |navigation_handle|, whose URL must be HTTP(S). |
| // |
| // Returns kUnknown for unregistered targets, kModelNotAvailableOnClient when |
| // no model is loaded for the target, kModelPredictionHoldback when the |
| // decision is overridden for metrics purposes, and the model's own decision |
| // otherwise. The model status is recorded for every registered target, and |
| // the evaluation latency whenever the model produced a decision other than |
| // kUnknown. |
| OptimizationTargetDecision PredictionManager::ShouldTargetNavigation( |
|     content::NavigationHandle* navigation_handle, |
|     proto::OptimizationTarget optimization_target) { |
|   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
|   DCHECK(navigation_handle->GetURL().SchemeIsHTTPOrHTTPS()); |
| |
|   if (!registered_optimization_targets_and_metadata_.contains( |
|           optimization_target)) { |
|     return OptimizationTargetDecision::kUnknown; |
|   } |
| |
|   // Reports the status set below when this function returns. |
|   ScopedPredictionManagerModelStatusRecorder status_recorder( |
|       optimization_target); |
| |
|   const auto loaded_model_it = |
|       optimization_target_prediction_model_map_.find(optimization_target); |
|   if (loaded_model_it == optimization_target_prediction_model_map_.end()) { |
|     // No model is loaded; work out why for the status histogram. |
|     PredictionManagerModelStatus missing_model_status = |
|         PredictionManagerModelStatus::kStoreUnavailableModelUnknown; |
|     if (store_is_ready_ && model_and_features_store_) { |
|       OptimizationGuideStore::EntryKey model_entry_key; |
|       missing_model_status = |
|           model_and_features_store_->FindPredictionModelEntryKey( |
|               optimization_target, &model_entry_key) |
|               ? PredictionManagerModelStatus::kStoreAvailableModelNotLoaded |
|               : PredictionManagerModelStatus::kStoreAvailableNoModelForTarget; |
|     } |
|     status_recorder.set_status(missing_model_status); |
|     return OptimizationTargetDecision::kModelNotAvailableOnClient; |
|   } |
| |
|   status_recorder.set_status(PredictionManagerModelStatus::kModelAvailable); |
|   PredictionModel* model = loaded_model_it->second.get(); |
| |
|   const base::flat_map<std::string, float> feature_map = |
|       BuildFeatureMap(navigation_handle, model->GetModelFeatures()); |
| |
|   const base::TimeTicks evaluation_start = base::TimeTicks::Now(); |
|   double prediction_score = 0.0; |
|   const OptimizationTargetDecision decision = |
|       model->Predict(feature_map, &prediction_score); |
|   if (decision != OptimizationTargetDecision::kUnknown) { |
|     base::UmaHistogramTimes( |
|         "OptimizationGuide.PredictionModelEvaluationLatency." + |
|             GetStringNameForOptimizationTarget(optimization_target), |
|         base::TimeTicks::Now() - evaluation_start); |
|   } |
| |
|   if (features::ShouldOverrideOptimizationTargetDecisionForMetricsPurposes( |
|           optimization_target)) { |
|     return OptimizationTargetDecision::kModelPredictionHoldback; |
|   } |
|   return decision; |
| } |
| |
| // Returns the set of optimization targets that are currently registered. |
| // |
| // The targets are gathered into a vector first and the flat_set is built from |
| // it in one step; inserting into a flat_set one element at a time can shift |
| // the backing storage on every insertion. |
| base::flat_set<proto::OptimizationTarget> |
| PredictionManager::GetRegisteredOptimizationTargets() const { |
|   std::vector<proto::OptimizationTarget> targets; |
|   targets.reserve(registered_optimization_targets_and_metadata_.size()); |
|   for (const auto& target_and_metadata : |
|        registered_optimization_targets_and_metadata_) { |
|     targets.push_back(target_and_metadata.first); |
|   } |
|   return base::flat_set<proto::OptimizationTarget>(std::move(targets)); |
| } |
| |
| // Test-only accessor: returns the prediction model loaded for |
| // |optimization_target|, or nullptr when none is loaded. |
| PredictionModel* PredictionManager::GetPredictionModelForTesting( |
|     proto::OptimizationTarget optimization_target) const { |
|   const auto model_it = |
|       optimization_target_prediction_model_map_.find(optimization_target); |
|   if (model_it == optimization_target_prediction_model_map_.end()) |
|     return nullptr; |
|   return model_it->second.get(); |
| } |
| |
| // Test-only accessor: returns a pointer to the host model features cache |
| // owned by this manager. |
| const HostModelFeaturesMRUCache* |
| PredictionManager::GetHostModelFeaturesForTesting() const { |
|   const HostModelFeaturesMRUCache& cache = host_model_features_cache_; |
|   return &cache; |
| } |
| |
| // Test-only setter: takes ownership of |prediction_model_fetcher| and uses it |
| // in place of any fetcher this manager currently holds. |
| void PredictionManager::SetPredictionModelFetcherForTesting( |
|     std::unique_ptr<PredictionModelFetcher> prediction_model_fetcher) { |
|   prediction_model_fetcher_ = std::move(prediction_model_fetcher); |
| } |
| |
| // Test-only setter: takes ownership of |prediction_model_download_manager| |
| // and uses it in place of any download manager this manager currently holds. |
| // This manager is not added as an observer of the new download manager here. |
| void PredictionManager::SetPredictionModelDownloadManagerForTesting( |
|     std::unique_ptr<PredictionModelDownloadManager> |
|         prediction_model_download_manager) { |
|   prediction_model_download_manager_ = |
|       std::move(prediction_model_download_manager); |
| } |
| |
| // Requests models for every registered optimization target from the remote |
| // Optimization Guide Service. |
| // |
| // Nothing is requested when a model override is present, when fetching is not |
| // allowed for the profile, when no targets are registered, or when a download |
| // manager exists but is not available for downloads. Otherwise pending |
| // downloads are cancelled, the request is issued through the model fetcher |
| // (created on first use) and the next fetch is scheduled. |
| void PredictionManager::FetchModels() { |
|   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
|   // Models and host model features are not fetched when a model override is |
|   // in use, when remote fetching is not permitted, or when there are no |
|   // registered optimization targets to fetch for. |
|   if (switches::IsModelOverridePresent() || !ShouldFetchModels(profile_) || |
|       registered_optimization_targets_and_metadata_.empty()) { |
|     return; |
|   } |
| |
|   if (prediction_model_download_manager_) { |
|     const bool can_download = |
|         prediction_model_download_manager_->IsAvailableForDownloads(); |
|     base::UmaHistogramBoolean( |
|         "OptimizationGuide.PredictionManager." |
|         "DownloadServiceAvailabilityBlockedFetch", |
|         !can_download); |
|     // Models cannot be downloaded from the server, so do not refresh them. |
|     if (!can_download) |
|       return; |
| |
|     prediction_model_download_manager_->CancelAllPendingDownloads(); |
|   } |
| |
|   // NOTE: ALL PRECONDITIONS FOR THIS FUNCTION MUST BE CHECKED ABOVE THIS LINE. |
|   // It is assumed that if we proceed past here, that a fetch will at least be |
|   // attempted. |
| |
|   // Active field trials convey some sort of user information, so they are |
|   // only attached to the request when the user has opted into the right |
|   // permissions. |
|   std::vector<proto::FieldTrial> active_field_trials; |
|   if (IsUserPermittedToFetchFromRemoteOptimizationGuide( |
|           profile_->IsOffTheRecord(), pref_service_)) { |
|     const google::protobuf::RepeatedPtrField<proto::FieldTrial> |
|         allowed_field_trials = GetActiveFieldTrialsAllowedForFetch(); |
|     active_field_trials.assign(allowed_field_trials.begin(), |
|                                allowed_field_trials.end()); |
|   } |
| |
|   if (!prediction_model_fetcher_) { |
|     prediction_model_fetcher_ = std::make_unique<PredictionModelFetcher>( |
|         url_loader_factory_, |
|         features::GetOptimizationGuideServiceGetModelsURL(), |
|         content::GetNetworkConnectionTracker()); |
|   } |
| |
|   // Template carrying the model types this client supports; it is copied for |
|   // each registered target below. |
|   proto::ModelInfo base_model_info; |
|   base_model_info.add_supported_model_types(proto::MODEL_TYPE_DECISION_TREE); |
|   if (features::IsModelDownloadingEnabled()) { |
|     // TODO(crbug/1204614): Remove v2.3* and 2.4 when server supports 2.7. |
|     for (const auto tflite_model_type : |
|          {proto::MODEL_TYPE_TFLITE_2_3_0, proto::MODEL_TYPE_TFLITE_2_3_0_1, |
|           proto::MODEL_TYPE_TFLITE_2_4, proto::MODEL_TYPE_TFLITE_2_7}) { |
|       base_model_info.add_supported_model_types(tflite_model_type); |
|     } |
|   } |
| |
|   // For now, a model is requested for every registered optimization target. |
|   std::vector<proto::ModelInfo> models_info; |
|   std::string debug_msg; |
|   for (const auto& target_and_metadata : |
|        registered_optimization_targets_and_metadata_) { |
|     const proto::OptimizationTarget target = target_and_metadata.first; |
| |
|     proto::ModelInfo model_info(base_model_info); |
|     model_info.set_optimization_target(target); |
|     if (target_and_metadata.second.has_value()) |
|       *model_info.mutable_model_metadata() = *target_and_metadata.second; |
| |
|     // Report the version that is already held, if any. The version from the |
|     // model info map is applied last and therefore wins when both maps have |
|     // an entry for the target. |
|     const auto prediction_model_it = |
|         optimization_target_prediction_model_map_.find(target); |
|     if (prediction_model_it != |
|         optimization_target_prediction_model_map_.end()) { |
|       model_info.set_version(prediction_model_it->second->GetVersion()); |
|     } |
|     const auto model_info_it = |
|         optimization_target_model_info_map_.find(target); |
|     if (model_info_it != optimization_target_model_info_map_.end()) |
|       model_info.set_version(model_info_it->second->GetVersion()); |
| |
|     models_info.push_back(model_info); |
|     if (switches::IsDebugLogsEnabled()) { |
|       debug_msg += |
|           "\nOptimization Target: " + |
|           proto::OptimizationTarget_Name(model_info.optimization_target()); |
|     } |
|   } |
|   if (switches::IsDebugLogsEnabled() && !debug_msg.empty()) { |
|     DVLOG(0) << "OptimizationGuide: Fetching models for Optimization Targets: " |
|              << debug_msg; |
|   } |
| |
|   const bool fetch_initiated = |
|       prediction_model_fetcher_->FetchOptimizationGuideServiceModels( |
|           models_info, active_field_trials, proto::CONTEXT_BATCH_UPDATE, |
|           g_browser_process->GetApplicationLocale(), |
|           base::BindOnce(&PredictionManager::OnModelsFetched, |
|                          ui_weak_ptr_factory_.GetWeakPtr())); |
|   if (fetch_initiated) |
|     SetLastModelFetchAttemptTime(clock_->Now()); |
| |
|   // The next fetch is scheduled even when this one was not initiated: the |
|   // cause may be a network condition, and checking again shortly afterwards |
|   // is only a timer firing rather than an actual query to the server. |
|   ScheduleModelsFetch(); |
| } |
| |
| // Handles the result of a models fetch. Does nothing when no response was |
| // received. Otherwise records the fetch success time, passes the response's |
| // host model features on to be stored, keeps the response around if it |
| // contains models, and restarts the fetch timer with the regular fetch |
| // interval. |
| void PredictionManager::OnModelsFetched( |
|     absl::optional<std::unique_ptr<proto::GetModelsResponse>> |
|         get_models_response_data) { |
|   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
|   if (!get_models_response_data) |
|     return; |
| |
|   SetLastModelFetchSuccessTime(clock_->Now()); |
| |
|   std::unique_ptr<proto::GetModelsResponse>& response = |
|       *get_models_response_data; |
| |
|   // The host model features are updated even when there are none, so that |
|   // the store metadata holding the time of the next update from the remote |
|   // Optimization Guide Service is refreshed. |
|   UpdateHostModelFeatures(response->host_model_features()); |
| |
|   // Hold on to the response so that its models can be stored once the host |
|   // model features have been stored. |
|   if (response->models_size() > 0) |
|     get_models_response_data_to_store_ = std::move(response); |
| |
|   fetch_timer_.Stop(); |
|   fetch_timer_.Start(FROM_HERE, features::PredictionModelFetchInterval(), this, |
|                      &PredictionManager::ScheduleModelsFetch); |
| } |
| |
| // Processes |host_model_features| and writes them to the store. Only the |
| // entries accepted by ProcessAndStoreHostModelFeatures() are copied into the |
| // store update. The update is submitted even when it contains no entries, |
| // and OnHostModelFeaturesStored() is bound as the completion callback. |
| void PredictionManager::UpdateHostModelFeatures( |
|     const google::protobuf::RepeatedPtrField<proto::HostModelFeatures>& |
|         host_model_features) { |
|   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
|   const auto next_update_time = |
|       clock_->Now() + features::PredictionModelFetchInterval(); |
|   const auto expiry_time = |
|       clock_->Now() + features::StoredHostModelFeaturesFreshnessDuration(); |
|   std::unique_ptr<StoreUpdateData> update_data = |
|       StoreUpdateData::CreateHostModelFeaturesStoreUpdateData( |
|           /*update_time=*/next_update_time, /*expiry_time=*/expiry_time); |
| |
|   for (const auto& features_for_host : host_model_features) { |
|     if (!ProcessAndStoreHostModelFeatures(features_for_host)) |
|       continue; |
|     update_data->CopyHostModelFeaturesIntoUpdateData(features_for_host); |
|   } |
| |
|   model_and_features_store_->UpdateHostModelFeatures( |
|       std::move(update_data), |
|       base::BindOnce(&PredictionManager::OnHostModelFeaturesStored, |
|                      ui_weak_ptr_factory_.GetWeakPtr())); |
| } |
| |
| // Creates a PredictionModel from the |model| proto by delegating to |
| // PredictionModel::Create() and returns its result unchanged. |
| std::unique_ptr<PredictionModel> PredictionManager::CreatePredictionModel( |
|     const proto::PredictionModel& model) const { |
|   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
|   std::unique_ptr<PredictionModel> created_model = |
|       PredictionModel::Create(model); |
|   return created_model; |
| } |
| |
| // Processes the models returned by a fetch. |
| // |
| // For each entry in |prediction_models|: |
| //  - entries without a model are skipped, since the client already has that |
| //    version; |
| //  - entries whose model has a download URL are handed to the download |
| //    manager (when there is one and the URL is valid) and are stored later, |
| //    once the download has completed; |
| //  - all other entries are copied into a store update, have their version |
| //    recorded and are loaded immediately. |
| // The store is only written when at least one entry fell into the last |
| // category. The debug summary is logged whenever it is non-empty, so targets |
| // that only required a download are reported as well. |
| void PredictionManager::UpdatePredictionModels( |
|     const google::protobuf::RepeatedPtrField<proto::PredictionModel>& |
|         prediction_models) { |
|   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
|   std::unique_ptr<StoreUpdateData> prediction_model_update_data = |
|       StoreUpdateData::CreatePredictionModelStoreUpdateData( |
|           clock_->Now() + features::StoredModelsInactiveDuration()); |
|   bool has_models_to_update = false; |
|   std::string debug_msg; |
|   for (const auto& model : prediction_models) { |
|     // We already have this updated model, so don't update in store. |
|     if (!model.has_model()) |
|       continue; |
| |
|     if (!model.model().download_url().empty()) { |
|       if (prediction_model_download_manager_) { |
|         const GURL download_url(model.model().download_url()); |
|         const bool is_download_url_valid = download_url.is_valid(); |
|         if (is_download_url_valid) |
|           prediction_model_download_manager_->StartDownload(download_url); |
|         base::UmaHistogramBoolean( |
|             "OptimizationGuide.PredictionManager.IsDownloadUrlValid", |
|             is_download_url_valid); |
|         if (is_download_url_valid && switches::IsDebugLogsEnabled()) { |
|           debug_msg += "\nOptimization Target: " + |
|                        proto::OptimizationTarget_Name( |
|                            model.model_info().optimization_target()); |
|           debug_msg += "\nModel Download Was Required."; |
|         } |
|       } |
| |
|       // Skip over models that have a download URL since they will be updated |
|       // once the download has completed successfully. |
|       continue; |
|     } |
| |
|     has_models_to_update = true; |
|     // Storing the model regardless of whether the model is valid or not. Model |
|     // will be removed from store if it fails to load. |
|     prediction_model_update_data->CopyPredictionModelIntoUpdateData(model); |
|     RecordModelUpdateVersion(model.model_info()); |
|     OnLoadPredictionModel(std::make_unique<proto::PredictionModel>(model)); |
| |
|     if (switches::IsDebugLogsEnabled()) { |
|       debug_msg += "\nOptimization Target: " + |
|                    proto::OptimizationTarget_Name( |
|                        model.model_info().optimization_target()); |
|       debug_msg += "\nNew Version: " + |
|                    base::NumberToString(model.model_info().version()); |
|       debug_msg += "\nModel Download Not Required."; |
|     } |
|   } |
| |
|   // Logged independently of |has_models_to_update|; otherwise the entries for |
|   // download-only responses would never be printed. |
|   if (switches::IsDebugLogsEnabled() && !debug_msg.empty()) { |
|     DVLOG(0) << "OptimizationGuide: Models Fetched for Optimization Targets: " |
|              << debug_msg; |
|   } |
| |
|   if (has_models_to_update) { |
|     model_and_features_store_->UpdatePredictionModels( |
|         std::move(prediction_model_update_data), |
|         base::BindOnce(&PredictionManager::OnPredictionModelsStored, |
|                        ui_weak_ptr_factory_.GetWeakPtr())); |
|   } |
| } |
| |
| // Handles a model whose files are ready. Ignored entirely when a model |
| // override is present. Otherwise the model's version is recorded, the model |
| // is written to the store, and, if its optimization target is registered, it |
| // is loaded. |model| is expected to carry both a version and an optimization |
| // target. |
| void PredictionManager::OnModelReady(const proto::PredictionModel& model) { |
|   if (switches::IsModelOverridePresent()) |
|     return; |
| |
|   const proto::ModelInfo& model_info = model.model_info(); |
|   DCHECK(model_info.has_version() && model_info.has_optimization_target()); |
| |
|   RecordModelUpdateVersion(model_info); |
|   if (switches::IsDebugLogsEnabled()) { |
|     DVLOG(0) << "Optimization Guide: Model Files Downloaded: " |
|              << "\nOptimization Target: " |
|              << proto::OptimizationTarget_Name( |
|                     model_info.optimization_target()) |
|              << "\nNew Version: " |
|              << base::NumberToString(model_info.version()); |
|   } |
| |
|   // Persist the received model. |
|   std::unique_ptr<StoreUpdateData> update_data = |
|       StoreUpdateData::CreatePredictionModelStoreUpdateData( |
|           clock_->Now() + features::StoredModelsInactiveDuration()); |
|   update_data->CopyPredictionModelIntoUpdateData(model); |
|   model_and_features_store_->UpdatePredictionModels( |
|       std::move(update_data), |
|       base::BindOnce(&PredictionManager::OnPredictionModelsStored, |
|                      ui_weak_ptr_factory_.GetWeakPtr())); |
| |
|   // Only models for registered targets are loaded. |
|   if (!registered_optimization_targets_and_metadata_.contains( |
|           model_info.optimization_target())) { |
|     return; |
|   } |
|   OnLoadPredictionModel(std::make_unique<proto::PredictionModel>(model)); |
| } |
| |
| // Calls OnModelUpdated(|optimization_target|, |model_info|) on every observer |
| // registered for |optimization_target|. Does nothing when no observers were |
| // added for the target. With debug logs enabled, one message is logged per |
| // notified observer. |
| void PredictionManager::NotifyObserversOfNewModel( |
|     proto::OptimizationTarget optimization_target, |
|     const ModelInfo& model_info) const { |
|   const auto target_observers_it = |
|       registered_observers_for_optimization_targets_.find(optimization_target); |
|   if (target_observers_it == |
|       registered_observers_for_optimization_targets_.end()) { |
|     return; |
|   } |
| |
|   for (auto& observer : target_observers_it->second) { |
|     observer.OnModelUpdated(optimization_target, model_info); |
|     if (!switches::IsDebugLogsEnabled()) |
|       continue; |
| |
|     DVLOG(0) |
|         << "OptimizationGuide: OnModelFileUpdated for OptimizationTarget: " |
|         << proto::OptimizationTarget_Name(optimization_target) |
|         << "\nFile path: " << model_info.GetModelFilePath().AsUTF8Unsafe() |
|         << "\nHas metadata: " |
|         << (model_info.GetModelMetadata() ? "True" : "False"); |
|   } |
| } |
| |
| // Completion callback for prediction model store updates. Only records a |
| // local histogram sample noting that the models were stored. |
| void PredictionManager::OnPredictionModelsStored() { |
|   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
|   LOCAL_HISTOGRAM_BOOLEAN( |
|       "OptimizationGuide.PredictionManager.PredictionModelsStored", true); |
| } |
| |
| // Completion callback for host model feature store updates. Records a local |
| // histogram sample, processes any models held back from the last fetch |
| // response and drops that response, purges expired host model features and |
| // inactive models from the store, and reschedules the models fetch. |
| void PredictionManager::OnHostModelFeaturesStored() { |
|   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
|   LOCAL_HISTOGRAM_BOOLEAN( |
|       "OptimizationGuide.PredictionManager.HostModelFeaturesStored", true); |
| |
|   if (get_models_response_data_to_store_) { |
|     if (get_models_response_data_to_store_->models_size() > 0) |
|       UpdatePredictionModels(get_models_response_data_to_store_->models()); |
|   } |
|   // Whatever is left of the stashed response is no longer needed. |
|   get_models_response_data_to_store_.reset(); |
| |
|   // Drop expired host model features and inactive models from the store. |
|   model_and_features_store_->PurgeExpiredHostModelFeatures(); |
|   model_and_features_store_->PurgeInactiveModels(); |
| |
|   fetch_timer_.Stop(); |
|   ScheduleModelsFetch(); |
| } |
| |
| void PredictionManager::OnStoreInitialized() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| store_is_ready_ = true; |
| |
| // Create the download manager here if we are allowed to. |
| if (features::IsModelDownloadingEnabled() && !profile_->IsOffTheRecord() && |
| !prediction_model_download_manager_) { |
| prediction_model_download_manager_ = |
| std::make_unique<PredictionModelDownloadManager>( |
| BackgroundDownloadServiceFactory::GetForKey( |
| profile_->GetProfileKey()), |
| base::ThreadPool::CreateSequencedTaskRunner( |
| {base::MayBlock(), base::TaskPriority::BEST_EFFORT, |
| base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN})); |
| prediction_model_download_manager_->AddObserver(this); |
| } |
| |
| // Only load host model features if there are optimization targets registered. |
| if (registered_optimization_targets_and_metadata_.empty()) |
| return; |
| |
| // The store is ready so start loading host model features and the models for |
| // the registered optimization targets. Once the host model features are |
| // loaded, prediction models for the registered optimization targets will be |
| // loaded. |
| LoadHostModelFeatures(); |
| |
| MaybeScheduleModelFetch(); |
| } |
| |
| void PredictionManager::LoadHostModelFeatures() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| // Load the host model features first, each prediction model requires the set |
| // of host model features to be known before creation. |
| model_and_features_store_->LoadAllHostModelFeatures( |
| base::BindOnce(&PredictionManager::OnLoadHostModelFeatures, |
| ui_weak_ptr_factory_.GetWeakPtr())); |
| } |
| |
| void PredictionManager::OnLoadHostModelFeatures( |
| std::unique_ptr<std::vector<proto::HostModelFeatures>> |
| all_host_model_features) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| // If the store returns an empty vector of host model features, the store |
| // contains no host model features. However, the load is otherwise complete |
| // and prediction models can be loaded but they will require no host model |
| // feature information. |
| host_model_features_loaded_ = true; |
| if (all_host_model_features) { |
| for (const auto& host_model_features : *all_host_model_features) |
| ProcessAndStoreHostModelFeatures(host_model_features); |
| } |
| UMA_HISTOGRAM_COUNTS_1000( |
| "OptimizationGuide.PredictionManager.HostModelFeaturesMapSize", |
| host_model_features_cache_.size()); |
| |
| // Load the prediction models for all the registered optimization targets now |
| // that it is not blocked by loading the host model features. |
| LoadPredictionModels(GetRegisteredOptimizationTargets()); |
| } |
| |
| void PredictionManager::LoadPredictionModels( |
| const base::flat_set<proto::OptimizationTarget>& optimization_targets) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| DCHECK(host_model_features_loaded_); |
| |
| if (switches::IsModelOverridePresent()) { |
| for (proto::OptimizationTarget optimization_target : optimization_targets) { |
| std::unique_ptr<proto::PredictionModel> prediction_model = |
| BuildPredictionModelFromCommandLineForOptimizationTarget( |
| optimization_target); |
| OnLoadPredictionModel(std::move(prediction_model)); |
| } |
| return; |
| } |
| |
| OptimizationGuideStore::EntryKey model_entry_key; |
| for (const auto& optimization_target : optimization_targets) { |
| // The prediction model for this optimization target has already been |
| // loaded. |
| if (optimization_target_prediction_model_map_.contains( |
| optimization_target)) { |
| continue; |
| } |
| if (!model_and_features_store_->FindPredictionModelEntryKey( |
| optimization_target, &model_entry_key)) { |
| continue; |
| } |
| model_and_features_store_->LoadPredictionModel( |
| model_entry_key, |
| base::BindOnce(&PredictionManager::OnLoadPredictionModel, |
| ui_weak_ptr_factory_.GetWeakPtr())); |
| } |
| } |
| |
| void PredictionManager::OnLoadPredictionModel( |
| std::unique_ptr<proto::PredictionModel> model) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| if (!model) |
| return; |
| |
| bool success = ProcessAndStoreLoadedModel(*model); |
| OnProcessLoadedModel(*model, success); |
| } |
| |
| void PredictionManager::OnProcessLoadedModel( |
| const proto::PredictionModel& model, |
| bool success) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| if (success) { |
| base::UmaHistogramSparse( |
| "OptimizationGuide.PredictionModelLoadedVersion." + |
| optimization_guide::GetStringNameForOptimizationTarget( |
| model.model_info().optimization_target()), |
| model.model_info().version()); |
| return; |
| } |
| |
| // Remove model from store if it exists. |
| OptimizationGuideStore::EntryKey model_entry_key; |
| if (model_and_features_store_->FindPredictionModelEntryKey( |
| model.model_info().optimization_target(), &model_entry_key)) { |
| LOCAL_HISTOGRAM_BOOLEAN( |
| "OptimizationGuide.PredictionModelRemoved." + |
| optimization_guide::GetStringNameForOptimizationTarget( |
| model.model_info().optimization_target()), |
| true); |
| model_and_features_store_->RemovePredictionModelFromEntryKey( |
| model_entry_key); |
| } |
| } |
| |
| bool PredictionManager::ProcessAndStoreLoadedModel( |
| const proto::PredictionModel& model) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| if (!model.model_info().has_optimization_target()) |
| return false; |
| if (!model.model_info().has_version()) |
| return false; |
| if (!model.has_model()) |
| return false; |
| if (!registered_optimization_targets_and_metadata_.contains( |
| model.model_info().optimization_target())) { |
| return false; |
| } |
| |
| ScopedPredictionModelConstructionAndValidationRecorder |
| prediction_model_recorder(model.model_info().optimization_target()); |
| std::unique_ptr<ModelInfo> model_info = ModelInfo::Create(model); |
| std::unique_ptr<PredictionModel> prediction_model = |
| model_info ? nullptr : CreatePredictionModel(model); |
| if (!model_info && !prediction_model) { |
| prediction_model_recorder.set_is_valid(false); |
| return false; |
| } |
| |
| proto::OptimizationTarget optimization_target = |
| model.model_info().optimization_target(); |
| |
| // See if we should update the loaded model. |
| if (!ShouldUpdateStoredModelForTarget(optimization_target, |
| model.model_info().version())) { |
| return true; |
| } |
| |
| // Update prediction model file if that is what we have loaded. |
| if (model_info) { |
| StoreLoadedModelInfo(optimization_target, std::move(model_info)); |
| } |
| |
| // Update prediction model if that is what we have loaded. |
| if (prediction_model) { |
| StoreLoadedPredictionModel(optimization_target, |
| std::move(prediction_model)); |
| } |
| |
| return true; |
| } |
| |
| bool PredictionManager::ShouldUpdateStoredModelForTarget( |
| proto::OptimizationTarget optimization_target, |
| int64_t new_version) const { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| auto model_meta_it = |
| optimization_target_model_info_map_.find(optimization_target); |
| if (model_meta_it != optimization_target_model_info_map_.end()) |
| return model_meta_it->second->GetVersion() != new_version; |
| |
| auto model_it = |
| optimization_target_prediction_model_map_.find(optimization_target); |
| if (model_it != optimization_target_prediction_model_map_.end()) |
| return model_it->second->GetVersion() != new_version; |
| |
| return true; |
| } |
| |
| void PredictionManager::StoreLoadedModelInfo( |
| proto::OptimizationTarget optimization_target, |
| std::unique_ptr<ModelInfo> model_info) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| DCHECK(model_info); |
| |
| bool has_model_for_target = |
| optimization_target_prediction_model_map_.contains(optimization_target); |
| RecordModelTypeChanged(optimization_target, has_model_for_target); |
| if (has_model_for_target) { |
| // Remove prediction model if we received the update as a model file. In |
| // practice, this shouldn't happen. |
| optimization_target_prediction_model_map_.erase(optimization_target); |
| } |
| |
| // Notify observers of new model file path. |
| base::ThreadTaskRunnerHandle::Get()->PostTask( |
| FROM_HERE, base::BindOnce(&PredictionManager::NotifyObserversOfNewModel, |
| ui_weak_ptr_factory_.GetWeakPtr(), |
| optimization_target, *model_info)); |
| |
| optimization_target_model_info_map_.insert_or_assign(optimization_target, |
| std::move(model_info)); |
| } |
| |
| void PredictionManager::StoreLoadedPredictionModel( |
| proto::OptimizationTarget optimization_target, |
| std::unique_ptr<PredictionModel> prediction_model) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| bool has_model_file_for_target = |
| optimization_target_model_info_map_.contains(optimization_target); |
| RecordModelTypeChanged(optimization_target, has_model_file_for_target); |
| if (has_model_file_for_target) { |
| // Remove prediction model file from map if we received the update as a |
| // PredictionModel. In practice, this shouldn't happen. |
| optimization_target_model_info_map_.erase(optimization_target); |
| } |
| optimization_target_prediction_model_map_.insert_or_assign( |
| optimization_target, std::move(prediction_model)); |
| } |
| |
| bool PredictionManager::ProcessAndStoreHostModelFeatures( |
| const proto::HostModelFeatures& host_model_features) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| if (!host_model_features.has_host()) |
| return false; |
| if (host_model_features.model_features_size() == 0) |
| return false; |
| |
| base::flat_map<std::string, float> model_features_for_host; |
| model_features_for_host.reserve(host_model_features.model_features_size()); |
| for (const auto& model_feature : host_model_features.model_features()) { |
| if (!model_feature.has_feature_name()) |
| continue; |
| switch (model_feature.feature_value_case()) { |
| case proto::ModelFeature::kDoubleValue: |
| // Loss of precision from double is acceptable for features supported |
| // by the prediction models. |
| model_features_for_host.emplace( |
| model_feature.feature_name(), |
| static_cast<float>(model_feature.double_value())); |
| break; |
| case proto::ModelFeature::kInt64Value: |
| model_features_for_host.emplace( |
| model_feature.feature_name(), |
| static_cast<float>(model_feature.int64_value())); |
| break; |
| case proto::ModelFeature::FEATURE_VALUE_NOT_SET: |
| NOTREACHED(); |
| break; |
| } |
| } |
| if (model_features_for_host.empty()) |
| return false; |
| host_model_features_cache_.Put(host_model_features.host(), |
| model_features_for_host); |
| return true; |
| } |
| |
| void PredictionManager::MaybeScheduleModelFetch() { |
| if (!ShouldFetchModels(profile_)) |
| return; |
| |
| if (switches::ShouldOverrideFetchModelsAndFeaturesTimer()) { |
| fetch_timer_.Start(FROM_HERE, base::Seconds(1), this, |
| &PredictionManager::FetchModels); |
| } else { |
| ScheduleModelsFetch(); |
| } |
| } |
| |
| base::Time PredictionManager::GetLastFetchAttemptTime() const { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| return base::Time::FromDeltaSinceWindowsEpoch(base::Microseconds( |
| pref_service_->GetInt64(prefs::kModelAndFeaturesLastFetchAttempt))); |
| } |
| |
| base::Time PredictionManager::GetLastFetchSuccessTime() const { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| return base::Time::FromDeltaSinceWindowsEpoch(base::Microseconds( |
| pref_service_->GetInt64(prefs::kModelLastFetchSuccess))); |
| } |
| |
| void PredictionManager::ScheduleModelsFetch() { |
| DCHECK(!fetch_timer_.IsRunning()); |
| DCHECK(store_is_ready_); |
| const base::TimeDelta time_until_update_time = |
| GetLastFetchSuccessTime() + features::PredictionModelFetchInterval() - |
| clock_->Now(); |
| const base::TimeDelta time_until_retry = |
| GetLastFetchAttemptTime() + features::PredictionModelFetchRetryDelay() - |
| clock_->Now(); |
| base::TimeDelta fetcher_delay = |
| std::max(time_until_update_time, time_until_retry); |
| if (fetcher_delay <= base::TimeDelta()) { |
| fetch_timer_.Start(FROM_HERE, RandomFetchDelay(), this, |
| &PredictionManager::FetchModels); |
| return; |
| } |
| fetch_timer_.Start(FROM_HERE, fetcher_delay, this, |
| &PredictionManager::ScheduleModelsFetch); |
| } |
| |
| void PredictionManager::SetLastModelFetchAttemptTime( |
| base::Time last_attempt_time) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| pref_service_->SetInt64( |
| prefs::kModelAndFeaturesLastFetchAttempt, |
| last_attempt_time.ToDeltaSinceWindowsEpoch().InMicroseconds()); |
| } |
| |
| void PredictionManager::SetLastModelFetchSuccessTime( |
| base::Time last_success_time) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| pref_service_->SetInt64( |
| prefs::kModelLastFetchSuccess, |
| last_success_time.ToDeltaSinceWindowsEpoch().InMicroseconds()); |
| } |
| |
| void PredictionManager::SetClockForTesting(const base::Clock* clock) { |
| clock_ = clock; |
| } |
| |
| void PredictionManager::ClearHostModelFeatures() { |
| host_model_features_cache_.Clear(); |
| if (model_and_features_store_) |
| model_and_features_store_->ClearHostModelFeaturesFromDatabase(); |
| } |
| |
| absl::optional<base::flat_map<std::string, float>> |
| PredictionManager::GetHostModelFeaturesForHost(const std::string& host) const { |
| auto it = host_model_features_cache_.Peek(host); |
| if (it == host_model_features_cache_.end()) |
| return absl::nullopt; |
| return it->second; |
| } |
| |
| void PredictionManager::OverrideTargetModelForTesting( |
| proto::OptimizationTarget optimization_target, |
| std::unique_ptr<ModelInfo> model_info) { |
| if (!model_info) { |
| return; |
| } |
| |
| ModelInfo model_info_copy = *model_info; |
| |
| optimization_target_model_info_map_.insert_or_assign(optimization_target, |
| std::move(model_info)); |
| |
| NotifyObserversOfNewModel(optimization_target, model_info_copy); |
| } |
| |
| } // namespace optimization_guide |