| // Copyright 2019 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "components/optimization_guide/core/prediction_manager.h" |
| |
| #include <memory> |
| #include <utility> |
| #include <vector> |
| |
| #include "base/containers/contains.h" |
| #include "base/containers/fixed_flat_set.h" |
| #include "base/containers/flat_map.h" |
| #include "base/containers/flat_set.h" |
| #include "base/containers/flat_tree.h" |
| #include "base/functional/callback.h" |
| #include "base/metrics/histogram_functions.h" |
| #include "base/metrics/histogram_macros.h" |
| #include "base/metrics/histogram_macros_local.h" |
| #include "base/not_fatal_until.h" |
| #include "base/observer_list.h" |
| #include "base/path_service.h" |
| #include "base/sequence_checker.h" |
| #include "base/task/sequenced_task_runner.h" |
| #include "base/task/single_thread_task_runner.h" |
| #include "base/task/thread_pool.h" |
| #include "base/time/time.h" |
| #include "base/uuid.h" |
| #include "components/optimization_guide/core/model_info.h" |
| #include "components/optimization_guide/core/model_util.h" |
| #include "components/optimization_guide/core/optimization_guide_constants.h" |
| #include "components/optimization_guide/core/optimization_guide_enums.h" |
| #include "components/optimization_guide/core/optimization_guide_features.h" |
| #include "components/optimization_guide/core/optimization_guide_logger.h" |
| #include "components/optimization_guide/core/optimization_guide_permissions_util.h" |
| #include "components/optimization_guide/core/optimization_guide_prefs.h" |
| #include "components/optimization_guide/core/optimization_guide_store.h" |
| #include "components/optimization_guide/core/optimization_guide_switches.h" |
| #include "components/optimization_guide/core/optimization_guide_util.h" |
| #include "components/optimization_guide/core/optimization_target_model_observer.h" |
| #include "components/optimization_guide/core/prediction_model_download_manager.h" |
| #include "components/optimization_guide/core/prediction_model_fetcher_impl.h" |
| #include "components/optimization_guide/core/prediction_model_override.h" |
| #include "components/optimization_guide/core/prediction_model_store.h" |
| #include "components/optimization_guide/core/store_update_data.h" |
| #include "components/optimization_guide/optimization_guide_internals/webui/optimization_guide_internals.mojom.h" |
| #include "components/optimization_guide/proto/models.pb.h" |
| #include "components/prefs/pref_service.h" |
| #include "google_apis/google_api_keys.h" |
| #include "mojo/public/cpp/bindings/remote.h" |
| #include "services/network/public/cpp/shared_url_loader_factory.h" |
| |
| namespace optimization_guide { |
| |
| namespace { |
| |
| proto::ModelCacheKey GetModelCacheKey(const std::string& locale) { |
| proto::ModelCacheKey model_cache_key; |
| model_cache_key.set_locale(locale); |
| return model_cache_key; |
| } |
| |
| // Util class for recording the construction and validation of a prediction |
| // model. The result is recorded when it goes out of scope and its destructor is |
| // called. |
| class ScopedPredictionModelConstructionAndValidationRecorder { |
| public: |
| explicit ScopedPredictionModelConstructionAndValidationRecorder( |
| proto::OptimizationTarget optimization_target) |
| : validation_start_time_(base::TimeTicks::Now()), |
| optimization_target_(optimization_target) {} |
| |
| ~ScopedPredictionModelConstructionAndValidationRecorder() { |
| base::UmaHistogramBoolean( |
| "OptimizationGuide.IsPredictionModelValid." + |
| GetStringNameForOptimizationTarget(optimization_target_), |
| is_valid_); |
| |
| // Only record the timing if the model is valid and was able to be |
| // constructed. |
| if (is_valid_) { |
| base::TimeDelta validation_latency = |
| base::TimeTicks::Now() - validation_start_time_; |
| base::UmaHistogramTimes( |
| "OptimizationGuide.PredictionModelValidationLatency." + |
| GetStringNameForOptimizationTarget(optimization_target_), |
| validation_latency); |
| } |
| } |
| |
| void set_is_valid(bool is_valid) { is_valid_ = is_valid; } |
| |
| private: |
| bool is_valid_ = true; |
| const base::TimeTicks validation_start_time_; |
| const proto::OptimizationTarget optimization_target_; |
| }; |
| |
| void RecordModelUpdateVersion(const proto::ModelInfo& model_info) { |
| base::UmaHistogramSparse( |
| "OptimizationGuide.PredictionModelUpdateVersion." + |
| GetStringNameForOptimizationTarget(model_info.optimization_target()), |
| model_info.version()); |
| } |
| |
| void RecordLifecycleState(proto::OptimizationTarget optimization_target, |
| ModelDeliveryEvent event) { |
| base::UmaHistogramEnumeration( |
| "OptimizationGuide.PredictionManager.ModelDeliveryEvents." + |
| GetStringNameForOptimizationTarget(optimization_target), |
| event); |
| } |
| |
| // Returns whether models should be fetched from the |
| // remote Optimization Guide Service. |
| bool ShouldFetchModels(bool off_the_record, |
| bool component_updates_enabled, |
| bool should_check_google_api_key_configuration) { |
| return features::IsRemoteFetchingEnabled() && !off_the_record && |
| features::IsModelDownloadingEnabled() && component_updates_enabled && |
| (!should_check_google_api_key_configuration || |
| google_apis::HasAPIKeyConfigured()); |
| } |
| |
| // Returns whether the model metadata proto is on the server allowlist. |
| bool IsModelMetadataTypeOnServerAllowlist(const proto::Any& model_metadata) { |
| return model_metadata.type_url() == |
| "type.googleapis.com/" |
| "google.internal.chrome.optimizationguide.v1." |
| "OnDeviceTailSuggestModelMetadata" || |
| model_metadata.type_url() == |
| "type.googleapis.com/" |
| "google.internal.chrome.optimizationguide.v1." |
| "PageTopicsModelMetadata" || |
| model_metadata.type_url() == |
| "type.googleapis.com/" |
| "google.internal.chrome.optimizationguide.v1." |
| "SegmentationModelMetadata" || |
| model_metadata.type_url() == |
| "type.googleapis.com/" |
| "google.privacy.webpermissionpredictions.v1." |
| "WebPermissionPredictionsModelMetadata" || |
| model_metadata.type_url() == |
| "type.googleapis.com/" |
| "google.internal.chrome.optimizationguide.v1." |
| "ClientSidePhishingModelMetadata" || |
| model_metadata.type_url() == |
| "type.googleapis.com/" |
| "lens.prime.csc.VisualSearchModelMetadata" || |
| model_metadata.type_url() == |
| "type.googleapis.com/" |
| "google.internal.chrome.optimizationguide.v1." |
| "OnDeviceBaseModelMetadata" || |
| model_metadata.type_url() == |
| "type.googleapis.com/" |
| "google.internal.chrome.optimizationguide.v1." |
| "AutofillFieldClassificationModelMetadata" || |
| model_metadata.type_url() == |
| "type.googleapis.com/" |
| "google.internal.chrome.optimizationguide.v1." |
| "AutocompleteScoringModelMetadata" || |
| model_metadata.type_url() == |
| "type.googleapis.com/" |
| "google.privacy.webpermissionpredictions.v1." |
| "WebPermissionPredictionsClientInfo"; |
| } |
| |
| void RecordModelAvailableAtRegistration( |
| proto::OptimizationTarget optimization_target, |
| bool model_available_at_registration) { |
| base::UmaHistogramBoolean( |
| "OptimizationGuide.PredictionManager.ModelAvailableAtRegistration." + |
| GetStringNameForOptimizationTarget(optimization_target), |
| model_available_at_registration); |
| } |
| |
| } // namespace |
| |
| PredictionManager::ModelRegistrationInfo::ModelRegistrationInfo( |
| std::optional<proto::Any> metadata) |
| : metadata(metadata) {} |
| |
| PredictionManager::ModelRegistrationInfo::~ModelRegistrationInfo() = default; |
| |
| PredictionManager::PredictionManager( |
| PredictionModelStore* prediction_model_store, |
| scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory, |
| PrefService* pref_service, |
| bool off_the_record, |
| const std::string& application_locale, |
| const base::FilePath& models_dir_path, |
| OptimizationGuideLogger* optimization_guide_logger, |
| BackgroundDownloadServiceProvider background_download_service_provider, |
| ComponentUpdatesEnabledProvider component_updates_enabled_provider) |
| : prediction_model_download_manager_(nullptr), |
| prediction_model_store_(prediction_model_store), |
| url_loader_factory_(url_loader_factory), |
| optimization_guide_logger_(optimization_guide_logger), |
| component_updates_enabled_provider_(component_updates_enabled_provider), |
| prediction_model_fetch_timer_( |
| pref_service, |
| base::BindRepeating( |
| &PredictionManager::FetchModels, |
| // Its safe to use `base::Unretained(this)` here since |
| // `prediction_model_fetch_timer_` is owned by `this`. |
| base::Unretained(this))), |
| off_the_record_(off_the_record), |
| application_locale_(application_locale), |
| model_cache_key_(GetModelCacheKey(application_locale_)), |
| models_dir_path_(models_dir_path), |
| should_check_google_api_key_configuration_( |
| !switches::ShouldSkipGoogleApiKeyConfigurationCheck()) { |
| DCHECK(prediction_model_store_); |
| Initialize(std::move(background_download_service_provider)); |
| } |
| |
| PredictionManager::~PredictionManager() { |
| if (prediction_model_download_manager_) { |
| prediction_model_download_manager_->RemoveObserver(this); |
| } |
| } |
| |
| void PredictionManager::Initialize( |
| BackgroundDownloadServiceProvider background_download_service_provider) { |
| LoadPredictionModels(GetRegisteredOptimizationTargets()); |
| LOCAL_HISTOGRAM_BOOLEAN( |
| "OptimizationGuide.PredictionManager.StoreInitialized", true); |
| } |
| |
| void PredictionManager::AddObserverForOptimizationTargetModel( |
| proto::OptimizationTarget optimization_target, |
| const std::optional<proto::Any>& model_metadata, |
| OptimizationTargetModelObserver* observer) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| // A limited number of targets support multiple registrations. In general |
| // multiple registrations are disallowed to mitigate the risk of subtle, |
| // conflicting behavior between two different uses of the same model file. If |
| // adding a target to this set, please document below why it's necessary. |
| constexpr auto kAllowedMultipleRegistrations = |
| base::MakeFixedFlatSet<proto::OptimizationTarget>({ |
| // In addition to use by Translate's language detection features, this |
| // model is also needed by the On-Device Model service process, and |
| // ModelExecutionManager monitors for updates on its behalf. |
| proto::OptimizationTarget::OPTIMIZATION_TARGET_LANGUAGE_DETECTION, |
| }); |
| |
| DCHECK(base::Contains(kAllowedMultipleRegistrations, optimization_target) || |
| !base::Contains(model_registration_info_map_, optimization_target)); |
| DCHECK(!model_metadata || |
| IsModelMetadataTypeOnServerAllowlist(*model_metadata)); |
| |
| // As DCHECKS don't run in the wild, just do not register the observer if |
| // something is already registered for the type. Otherwise, file reads may |
| // blow up. |
| if (!base::Contains(kAllowedMultipleRegistrations, optimization_target) && |
| base::Contains(model_registration_info_map_, optimization_target)) { |
| DLOG(ERROR) << "Did not add observer for optimization target " |
| << static_cast<int>(optimization_target) |
| << " since an observer for the target was already registered "; |
| return; |
| } |
| |
| auto [it, registered] = model_registration_info_map_.emplace( |
| std::piecewise_construct, std::forward_as_tuple(optimization_target), |
| std::forward_as_tuple(model_metadata)); |
| DCHECK(registered || |
| base::Contains(kAllowedMultipleRegistrations, optimization_target)); |
| it->second.model_observers.AddObserver(observer); |
| if (optimization_guide_logger_->ShouldEnableDebugLogs()) { |
| OPTIMIZATION_GUIDE_LOGGER( |
| optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT, |
| optimization_guide_logger_) |
| << "Observer added for OptimizationTarget: " << optimization_target; |
| } |
| |
| // Notify observer of existing model file path. |
| auto model_it = optimization_target_model_info_map_.find(optimization_target); |
| if (model_it != optimization_target_model_info_map_.end()) { |
| observer->OnModelUpdated(optimization_target, *model_it->second); |
| if (optimization_guide_logger_->ShouldEnableDebugLogs()) { |
| OPTIMIZATION_GUIDE_LOGGER( |
| optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT, |
| optimization_guide_logger_) |
| << "OnModelFileUpdated for OptimizationTarget: " |
| << optimization_target << "\nFile path: " |
| << model_it->second->GetModelFilePath().AsUTF8Unsafe() |
| << "\nHas metadata: " << (model_metadata ? "True" : "False"); |
| } |
| RecordLifecycleState(optimization_target, |
| ModelDeliveryEvent::kModelDeliveredAtRegistration); |
| } |
| base::UmaHistogramMediumTimes( |
| "OptimizationGuide.PredictionManager.RegistrationTimeSinceServiceInit." + |
| GetStringNameForOptimizationTarget(optimization_target), |
| !init_time_.is_null() ? base::TimeTicks::Now() - init_time_ |
| : base::TimeDelta()); |
| |
| if (optimization_guide_logger_->ShouldEnableDebugLogs()) { |
| OPTIMIZATION_GUIDE_LOGGER( |
| optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT, |
| optimization_guide_logger_) |
| << "Registered new OptimizationTarget: " << optimization_target; |
| } |
| |
| if (ShouldFetchModels(off_the_record_, |
| component_updates_enabled_provider_.Run(), |
| should_check_google_api_key_configuration_)) { |
| prediction_model_fetch_timer_.ScheduleFetchOnModelRegistration(); |
| } |
| |
| // Otherwise, load prediction models for any newly registered targets. |
| LoadPredictionModels({optimization_target}); |
| } |
| |
| void PredictionManager::RemoveObserverForOptimizationTargetModel( |
| proto::OptimizationTarget optimization_target, |
| OptimizationTargetModelObserver* observer) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| auto registration_info = |
| model_registration_info_map_.find(optimization_target); |
| CHECK(registration_info != model_registration_info_map_.end(), |
| base::NotFatalUntil::M130); |
| |
| auto& observers = registration_info->second.model_observers; |
| DCHECK(observers.HasObserver(observer)); |
| observers.RemoveObserver(observer); |
| if (observers.empty()) { |
| model_registration_info_map_.erase(registration_info); |
| } |
| } |
| |
| base::flat_set<proto::OptimizationTarget> |
| PredictionManager::GetRegisteredOptimizationTargets() const { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| base::flat_set<proto::OptimizationTarget> optimization_targets; |
| for (const auto& registration_info : model_registration_info_map_) { |
| optimization_targets.insert(registration_info.first); |
| } |
| return optimization_targets; |
| } |
| |
| void PredictionManager::SetPredictionModelFetcherForTesting( |
| std::unique_ptr<PredictionModelFetcher> prediction_model_fetcher) { |
| prediction_model_fetcher_ = std::move(prediction_model_fetcher); |
| } |
| |
| void PredictionManager::SetPredictionModelDownloadManagerForTesting( |
| std::unique_ptr<PredictionModelDownloadManager> |
| prediction_model_download_manager) { |
| prediction_model_download_manager_ = |
| std::move(prediction_model_download_manager); |
| } |
| |
| void PredictionManager::FetchModels() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| // The histogram that gets recorded here is used for integration tests that |
| // pass in a model override. For simplicity, we place the recording of this |
| // histogram here rather than somewhere else earlier in the session |
| // initialization flow since the model engine version needs to continuously be |
| // updated for the fetch. |
| proto::ModelInfo base_model_info; |
| // There should only be one supported model engine version at a time. |
| base_model_info.add_supported_model_engine_versions( |
| proto::MODEL_ENGINE_VERSION_TFLITE_2_20_0); |
| // This histogram is used for integration tests. Do not remove. |
| // Update this to be 10000 if/when we exceed 100 model engine versions. |
| LOCAL_HISTOGRAM_COUNTS_100( |
| "OptimizationGuide.PredictionManager.SupportedModelEngineVersion", |
| static_cast<int>( |
| *base_model_info.supported_model_engine_versions().begin())); |
| |
| if (!ShouldFetchModels(off_the_record_, |
| component_updates_enabled_provider_.Run(), |
| should_check_google_api_key_configuration_)) { |
| return; |
| } |
| |
| if (prediction_model_fetch_timer_.IsFirstModelFetch()) { |
| DCHECK(!init_time_.is_null()); |
| base::UmaHistogramMediumTimes( |
| "OptimizationGuide.PredictionManager.FirstModelFetchSinceServiceInit", |
| base::TimeTicks::Now() - init_time_); |
| } |
| |
| // Models should not be fetched if there are no optimization targets |
| // registered. |
| if (model_registration_info_map_.empty()) { |
| return; |
| } |
| |
| // We should have already created a prediction model download manager if we |
| // initiated the fetching of models. |
| DCHECK(prediction_model_download_manager_); |
| if (prediction_model_download_manager_) { |
| bool download_service_available = |
| prediction_model_download_manager_->IsAvailableForDownloads(); |
| base::UmaHistogramBoolean( |
| "OptimizationGuide.PredictionManager." |
| "DownloadServiceAvailabilityBlockedFetch", |
| !download_service_available); |
| if (!download_service_available) { |
| for (const auto& registration_info : model_registration_info_map_) { |
| RecordLifecycleState(registration_info.first, |
| ModelDeliveryEvent::kDownloadServiceUnavailable); |
| } |
| // We cannot download any models from the server, so don't refresh them. |
| return; |
| } |
| |
| prediction_model_download_manager_->CancelAllPendingDownloads(); |
| } |
| |
| std::vector<proto::ModelInfo> models_info = std::vector<proto::ModelInfo>(); |
| models_info.reserve(model_registration_info_map_.size()); |
| |
| // For now, we will fetch for all registered optimization targets. |
| auto overrides = PredictionModelOverrides::ParseFromCommandLine( |
| base::CommandLine::ForCurrentProcess()); |
| for (const auto& registration_info : model_registration_info_map_) { |
| if (overrides.Get(registration_info.first)) { |
| // Do not download models that were overridden. |
| continue; |
| } |
| |
| proto::ModelInfo model_info(base_model_info); |
| model_info.set_optimization_target(registration_info.first); |
| if (registration_info.second.metadata) { |
| *model_info.mutable_model_metadata() = *registration_info.second.metadata; |
| } |
| |
| auto model_it = |
| optimization_target_model_info_map_.find(registration_info.first); |
| if (model_it != optimization_target_model_info_map_.end()) { |
| model_info.set_version(model_it->second.get()->GetVersion()); |
| } |
| |
| models_info.push_back(model_info); |
| if (optimization_guide_logger_->ShouldEnableDebugLogs()) { |
| OPTIMIZATION_GUIDE_LOGGER( |
| optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT, |
| optimization_guide_logger_) |
| << "Fetching models for Optimization Target " |
| << model_info.optimization_target(); |
| } |
| RecordLifecycleState(registration_info.first, |
| ModelDeliveryEvent::kGetModelsRequest); |
| } |
| if (models_info.empty()) { |
| return; |
| } |
| |
| // NOTE: ALL PRECONDITIONS FOR THIS FUNCTION MUST BE CHECKED ABOVE THIS LINE. |
| // It is assumed that if we proceed past here, that a fetch will at least be |
| // attempted. |
| |
| if (!prediction_model_fetcher_) { |
| prediction_model_fetcher_ = std::make_unique<PredictionModelFetcherImpl>( |
| url_loader_factory_, |
| features::GetOptimizationGuideServiceGetModelsURL()); |
| } |
| |
| bool fetch_initiated = |
| prediction_model_fetcher_->FetchOptimizationGuideServiceModels( |
| models_info, proto::CONTEXT_BATCH_UPDATE_MODELS, application_locale_, |
| base::BindOnce(&PredictionManager::OnModelsFetched, |
| ui_weak_ptr_factory_.GetWeakPtr(), models_info)); |
| |
| if (fetch_initiated) { |
| prediction_model_fetch_timer_.NotifyModelFetchAttempt(); |
| } |
| // Schedule the next fetch regardless since we may not have initiated a fetch |
| // due to a network condition and trying in the next minute to see if that is |
| // unblocked is only a timer firing and not an actual query to the server. |
| prediction_model_fetch_timer_.SchedulePeriodicModelsFetch(); |
| } |
| |
| void PredictionManager::OnModelsFetched( |
| const std::vector<proto::ModelInfo> models_request_info, |
| std::optional<std::unique_ptr<proto::GetModelsResponse>> |
| get_models_response_data) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| if (!get_models_response_data) { |
| for (const auto& model_info : models_request_info) { |
| RecordLifecycleState(model_info.optimization_target(), |
| ModelDeliveryEvent::kGetModelsResponseFailure); |
| } |
| return; |
| } |
| |
| if ((*get_models_response_data)->models_size() > 0 || |
| models_request_info.size() > 0) { |
| UpdatePredictionModels(models_request_info, |
| (*get_models_response_data)->models()); |
| } |
| |
| prediction_model_fetch_timer_.NotifyModelFetchSuccess(); |
| prediction_model_fetch_timer_.Stop(); |
| prediction_model_fetch_timer_.SchedulePeriodicModelsFetch(); |
| } |
| |
| void PredictionManager::UpdateModelMetadata( |
| const proto::PredictionModel& model) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| // Model update is needed when download URL is set, which indicates the model |
| // has changed. |
| if (model.model().download_url().empty()) { |
| return; |
| } |
| if (!model.model_info().has_model_cache_key()) { |
| return; |
| } |
| if (optimization_guide_logger_->ShouldEnableDebugLogs()) { |
| OPTIMIZATION_GUIDE_LOGGER( |
| optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT, |
| optimization_guide_logger_) |
| << "Optimization Target: " << model.model_info().optimization_target() |
| << " for locale: " << application_locale_ |
| << " sharing models with locale: " |
| << model.model_info().model_cache_key().locale(); |
| } |
| prediction_model_store_->UpdateModelCacheKeyMapping( |
| model.model_info().optimization_target(), model_cache_key_, |
| model.model_info().model_cache_key()); |
| } |
| |
| bool PredictionManager::ShouldDownloadNewModel( |
| const proto::PredictionModel& model) const { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| // No download needed if URL is not set. |
| if (model.model().download_url().empty()) { |
| return false; |
| } |
| // Though the server set the download URL indicating the model is old or does |
| // not exist in client, the same version model could exist in the store, if |
| // the model is shared across different profile characteristics, based on |
| // ModelCacheKey. So, download only when the same version is not found in the |
| // store. |
| return !prediction_model_store_->HasModelWithVersion( |
| model.model_info().optimization_target(), model_cache_key_, |
| model.model_info().version()); |
| } |
| |
| void PredictionManager::StartModelDownload( |
| proto::OptimizationTarget optimization_target, |
| const GURL& download_url) { |
| // We should only be downloading models and updating the store for |
| // on-the-record profiles and after the store has been initialized. |
| DCHECK(prediction_model_download_manager_); |
| if (!prediction_model_download_manager_) { |
| return; |
| } |
| if (download_url.is_valid()) { |
| prediction_model_download_manager_->StartDownload(download_url, |
| optimization_target); |
| if (optimization_guide_logger_->ShouldEnableDebugLogs()) { |
| OPTIMIZATION_GUIDE_LOGGER( |
| optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT, |
| optimization_guide_logger_) |
| << "Model download started for Optimization Target: " |
| << optimization_target << " download URL: " << download_url; |
| } |
| } |
| RecordLifecycleState(optimization_target, |
| download_url.is_valid() |
| ? ModelDeliveryEvent::kDownloadServiceRequest |
| : ModelDeliveryEvent::kDownloadURLInvalid); |
| base::UmaHistogramBoolean( |
| "OptimizationGuide.PredictionManager.IsDownloadUrlValid." + |
| GetStringNameForOptimizationTarget(optimization_target), |
| download_url.is_valid()); |
| } |
| |
| void PredictionManager::MaybeDownloadOrUpdatePredictionModel( |
| proto::OptimizationTarget optimization_target, |
| const proto::PredictionModel& get_models_response_model, |
| std::unique_ptr<proto::PredictionModel> loaded_model) { |
| if (!loaded_model) { |
| // Model load failed, redownload the model. |
| RecordLifecycleState(optimization_target, |
| ModelDeliveryEvent::kModelLoadFailed); |
| DCHECK(!get_models_response_model.model().download_url().empty()); |
| StartModelDownload(optimization_target, |
| GURL(get_models_response_model.model().download_url())); |
| return; |
| } |
| prediction_model_store_->UpdateMetadataForExistingModel( |
| optimization_target, model_cache_key_, |
| get_models_response_model.model_info()); |
| OnLoadPredictionModel(optimization_target, |
| /*record_availability_metrics=*/false, |
| std::move(loaded_model)); |
| } |
| |
| void PredictionManager::UpdatePredictionModels( |
| const std::vector<proto::ModelInfo>& models_request_info, |
| const google::protobuf::RepeatedPtrField<proto::PredictionModel>& |
| prediction_models) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| std::set<proto::OptimizationTarget> received_optimization_targets; |
| for (const auto& model : prediction_models) { |
| auto optimization_target = model.model_info().optimization_target(); |
| received_optimization_targets.emplace(optimization_target); |
| if (!model.has_model()) { |
| // We already have this updated model, so don't update in store. |
| continue; |
| } |
| DCHECK(!model.model().download_url().empty()); |
| UpdateModelMetadata(model); |
| if (ShouldDownloadNewModel(model)) { |
| StartModelDownload(optimization_target, |
| GURL(model.model().download_url())); |
| // Skip over models that have a download URL since they will be updated |
| // once the download has completed successfully. |
| continue; |
| } |
| |
| RecordModelUpdateVersion(model.model_info()); |
| DCHECK(prediction_model_store_->HasModel(optimization_target, |
| model_cache_key_)); |
| // Load the model from the store to see whether it is valid or not. |
| prediction_model_store_->LoadModel( |
| optimization_target, model_cache_key_, |
| base::BindOnce(&PredictionManager::MaybeDownloadOrUpdatePredictionModel, |
| ui_weak_ptr_factory_.GetWeakPtr(), optimization_target, |
| model)); |
| if (optimization_guide_logger_->ShouldEnableDebugLogs()) { |
| OPTIMIZATION_GUIDE_LOGGER( |
| optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT, |
| optimization_guide_logger_) |
| << "Model Download Not Required for target: " << optimization_target |
| << "\nNew Version: " |
| << base::NumberToString(model.model_info().version()); |
| } |
| } |
| |
| for (const auto& model_info : models_request_info) { |
| if (received_optimization_targets.find(model_info.optimization_target()) == |
| received_optimization_targets.end()) { |
| RemoveModelFromStore( |
| model_info.optimization_target(), |
| PredictionModelStoreModelRemovalReason::kNoModelInGetModelsResponse); |
| } |
| } |
| } |
| |
| void PredictionManager::OnModelReady(const base::FilePath& base_model_dir, |
| const proto::PredictionModel& model) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| DCHECK(model.model_info().has_version() && |
| model.model_info().has_optimization_target()); |
| |
| auto overrides = PredictionModelOverrides::ParseFromCommandLine( |
| base::CommandLine::ForCurrentProcess()); |
| if (overrides.Get(model.model_info().optimization_target())) { |
| // Skip updating the model if override is present. |
| return; |
| } |
| |
| RecordModelUpdateVersion(model.model_info()); |
| RecordLifecycleState(model.model_info().optimization_target(), |
| ModelDeliveryEvent::kModelDownloaded); |
| if (optimization_guide_logger_->ShouldEnableDebugLogs()) { |
| OPTIMIZATION_GUIDE_LOGGER( |
| optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT, |
| optimization_guide_logger_) |
| << "Model Files Downloaded target: " |
| << model.model_info().optimization_target() |
| << "\nNew Version: " + |
| base::NumberToString(model.model_info().version()); |
| } |
| |
| // Store the received model in the store. |
| prediction_model_store_->UpdateModel( |
| model.model_info().optimization_target(), model_cache_key_, |
| model.model_info(), base_model_dir, |
| base::BindOnce(&PredictionManager::OnPredictionModelsStored, |
| ui_weak_ptr_factory_.GetWeakPtr())); |
| |
| if (model_registration_info_map_.contains( |
| model.model_info().optimization_target())) { |
| OnLoadPredictionModel(model.model_info().optimization_target(), |
| /*record_availability_metrics=*/false, |
| std::make_unique<proto::PredictionModel>(model)); |
| } |
| } |
| |
| void PredictionManager::OnModelDownloadStarted( |
| proto::OptimizationTarget optimization_target) { |
| RecordLifecycleState(optimization_target, |
| ModelDeliveryEvent::kModelDownloadStarted); |
| } |
| |
| void PredictionManager::OnModelDownloadFailed( |
| proto::OptimizationTarget optimization_target) { |
| RecordLifecycleState(optimization_target, |
| ModelDeliveryEvent::kModelDownloadFailure); |
| } |
| |
| std::vector<optimization_guide_internals::mojom::DownloadedModelInfoPtr> |
| PredictionManager::GetDownloadedModelsInfoForWebUI() const { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| std::vector<optimization_guide_internals::mojom::DownloadedModelInfoPtr> |
| downloaded_models_info; |
| downloaded_models_info.reserve(optimization_target_model_info_map_.size()); |
| for (const auto& it : optimization_target_model_info_map_) { |
| const std::string& optimization_target_name = |
| optimization_guide::proto::OptimizationTarget_Name(it.first); |
| const optimization_guide::ModelInfo* const model_info = it.second.get(); |
| auto downloaded_model_info_ptr = |
| optimization_guide_internals::mojom::DownloadedModelInfo::New( |
| optimization_target_name, model_info->GetVersion(), |
| model_info->GetModelFilePath().AsUTF8Unsafe()); |
| downloaded_models_info.push_back(std::move(downloaded_model_info_ptr)); |
| } |
| return downloaded_models_info; |
| } |
| |
| base::flat_map<std::string, bool> |
| PredictionManager::GetOnDeviceSupplementaryModelsInfoForWebUI() const { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| std::vector<proto::OptimizationTarget> supp_targets = { |
| proto::OptimizationTarget::OPTIMIZATION_TARGET_TEXT_SAFETY, |
| proto::OptimizationTarget::OPTIMIZATION_TARGET_LANGUAGE_DETECTION}; |
| base::flat_map<std::string, bool> supp_models_info; |
| for (const auto target : supp_targets) { |
| supp_models_info[optimization_guide::proto::OptimizationTarget_Name( |
| target)] = optimization_target_model_info_map_.contains(target); |
| } |
| |
| return supp_models_info; |
| } |
| |
| void PredictionManager::NotifyObserversOfNewModel( |
| proto::OptimizationTarget optimization_target, |
| base::optional_ref<const ModelInfo> model_info) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| auto registration_info_it = |
| model_registration_info_map_.find(optimization_target); |
| if (registration_info_it == model_registration_info_map_.end()) { |
| return; |
| } |
| RecordLifecycleState(optimization_target, |
| ModelDeliveryEvent::kModelDelivered); |
| for (auto& observer : registration_info_it->second.model_observers) { |
| observer.OnModelUpdated(optimization_target, model_info); |
| } |
| if (optimization_guide_logger_->ShouldEnableDebugLogs()) { |
| if (model_info.has_value()) { |
| OPTIMIZATION_GUIDE_LOGGER( |
| optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT, |
| optimization_guide_logger_) |
| << "OnModelFileUpdated for target: " << optimization_target |
| << "\nFile path: " << model_info->GetModelFilePath().AsUTF8Unsafe() |
| << "\nHas metadata: " |
| << (model_info->GetModelMetadata() ? "True" : "False"); |
| } else { |
| OPTIMIZATION_GUIDE_LOGGER( |
| optimization_guide_common::mojom::LogSource::MODEL_MANAGEMENT, |
| optimization_guide_logger_) |
| << "OnModelFileUpdated for target: " << optimization_target |
| << " for model removed"; |
| } |
| } |
| } |
| |
| void PredictionManager::OnPredictionModelsStored() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| LOCAL_HISTOGRAM_BOOLEAN( |
| "OptimizationGuide.PredictionManager.PredictionModelsStored", true); |
| } |
| |
| void PredictionManager::MaybeInitializeModelDownloads( |
| download::BackgroundDownloadService* background_download_service) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| init_time_ = base::TimeTicks::Now(); |
| |
| // Create the download manager here if we are allowed to. |
| if (features::IsModelDownloadingEnabled() && !off_the_record_ && |
| !prediction_model_download_manager_) { |
| prediction_model_download_manager_ = |
| std::make_unique<PredictionModelDownloadManager>( |
| background_download_service, |
| base::BindRepeating( |
| &PredictionManager::GetBaseModelDirForDownload, |
| // base::Unretained is safe here because the |
| // PredictionModelDownloadManager is owned by `this` |
| base::Unretained(this)), |
| base::ThreadPool::CreateSequencedTaskRunner( |
| {base::MayBlock(), base::TaskPriority::BEST_EFFORT, |
| base::TaskShutdownBehavior::CONTINUE_ON_SHUTDOWN})); |
| prediction_model_download_manager_->AddObserver(this); |
| } |
| |
| // Only load models if there are optimization targets registered. |
| if (!model_registration_info_map_.empty() && |
| ShouldFetchModels(off_the_record_, |
| component_updates_enabled_provider_.Run(), |
| should_check_google_api_key_configuration_)) { |
| prediction_model_fetch_timer_.MaybeScheduleFirstModelFetch(); |
| } |
| } |
| |
| void PredictionManager::OnPredictionModelOverrideLoaded( |
| proto::OptimizationTarget optimization_target, |
| std::unique_ptr<proto::PredictionModel> prediction_model) { |
| const bool is_available = prediction_model != nullptr; |
| VLOG(0) << "Loading override for " |
| << proto::OptimizationTarget_Name(optimization_target) |
| << (is_available ? "succeeded" : "failed"); |
| OnLoadPredictionModel(optimization_target, |
| /*record_availability_metrics=*/false, |
| std::move(prediction_model)); |
| RecordModelAvailableAtRegistration(optimization_target, is_available); |
| } |
| |
| void PredictionManager::LoadPredictionModels( |
| const base::flat_set<proto::OptimizationTarget>& optimization_targets) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| auto overrides = PredictionModelOverrides::ParseFromCommandLine( |
| base::CommandLine::ForCurrentProcess()); |
| for (proto::OptimizationTarget optimization_target : optimization_targets) { |
| // Give preference to any overrides given on the command line. |
| if (auto* entry = overrides.Get(optimization_target); entry) { |
| base::FilePath base_model_dir = |
| GetBaseModelDirForDownload(optimization_target); |
| entry->BuildModel( |
| base_model_dir, |
| base::BindOnce(&PredictionManager::OnPredictionModelOverrideLoaded, |
| ui_weak_ptr_factory_.GetWeakPtr(), |
| optimization_target)); |
| continue; |
| } |
| |
| if (!prediction_model_store_->HasModel(optimization_target, |
| model_cache_key_)) { |
| RecordModelAvailableAtRegistration(optimization_target, false); |
| continue; |
| } |
| prediction_model_store_->LoadModel( |
| optimization_target, model_cache_key_, |
| base::BindOnce(&PredictionManager::OnLoadPredictionModel, |
| ui_weak_ptr_factory_.GetWeakPtr(), optimization_target, |
| /*record_availability_metrics=*/true)); |
| } |
| } |
| |
| void PredictionManager::OnLoadPredictionModel( |
| proto::OptimizationTarget optimization_target, |
| bool record_availability_metrics, |
| std::unique_ptr<proto::PredictionModel> model) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| if (!model) { |
| if (record_availability_metrics) { |
| RecordModelAvailableAtRegistration(optimization_target, false); |
| } |
| return; |
| } |
| bool success = ProcessAndStoreLoadedModel(*model); |
| DCHECK_EQ(optimization_target, model->model_info().optimization_target()); |
| if (record_availability_metrics) { |
| RecordModelAvailableAtRegistration(optimization_target, success); |
| } |
| OnProcessLoadedModel(*model, success); |
| } |
| |
| void PredictionManager::OnProcessLoadedModel( |
| const proto::PredictionModel& model, |
| bool success) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| if (success) { |
| base::UmaHistogramSparse("OptimizationGuide.PredictionModelLoadedVersion." + |
| GetStringNameForOptimizationTarget( |
| model.model_info().optimization_target()), |
| model.model_info().version()); |
| return; |
| } |
| RemoveModelFromStore( |
| model.model_info().optimization_target(), |
| PredictionModelStoreModelRemovalReason::kModelLoadFailed); |
| } |
| |
| void PredictionManager::RemoveModelFromStore( |
| proto::OptimizationTarget optimization_target, |
| PredictionModelStoreModelRemovalReason model_removal_reason) { |
| if (prediction_model_store_->HasModel(optimization_target, |
| model_cache_key_)) { |
| prediction_model_store_->RemoveModel(optimization_target, model_cache_key_, |
| model_removal_reason); |
| NotifyObserversOfNewModel(optimization_target, std::nullopt); |
| } |
| } |
| |
| bool PredictionManager::ProcessAndStoreLoadedModel( |
| const proto::PredictionModel& model) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| if (!model.model_info().has_optimization_target()) { |
| return false; |
| } |
| if (!model.model_info().has_version()) { |
| return false; |
| } |
| if (!model.has_model()) { |
| return false; |
| } |
| if (!model_registration_info_map_.contains( |
| model.model_info().optimization_target())) { |
| return false; |
| } |
| |
| ScopedPredictionModelConstructionAndValidationRecorder |
| prediction_model_recorder(model.model_info().optimization_target()); |
| std::unique_ptr<ModelInfo> model_info = ModelInfo::Create(model); |
| if (!model_info) { |
| prediction_model_recorder.set_is_valid(false); |
| return false; |
| } |
| |
| proto::OptimizationTarget optimization_target = |
| model.model_info().optimization_target(); |
| |
| // See if we should update the loaded model. |
| if (!ShouldUpdateStoredModelForTarget(optimization_target, |
| model.model_info().version())) { |
| return true; |
| } |
| |
| // Update prediction model file if that is what we have loaded. |
| if (model_info) { |
| StoreLoadedModelInfo(optimization_target, std::move(model_info)); |
| } |
| |
| return true; |
| } |
| |
| bool PredictionManager::ShouldUpdateStoredModelForTarget( |
| proto::OptimizationTarget optimization_target, |
| int64_t new_version) const { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| auto model_meta_it = |
| optimization_target_model_info_map_.find(optimization_target); |
| if (model_meta_it != optimization_target_model_info_map_.end()) { |
| return model_meta_it->second->GetVersion() != new_version; |
| } |
| |
| return true; |
| } |
| |
| void PredictionManager::StoreLoadedModelInfo( |
| proto::OptimizationTarget optimization_target, |
| std::unique_ptr<ModelInfo> model_info) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| DCHECK(model_info); |
| |
| // Notify observers of new model file path. |
| base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, base::BindOnce(&PredictionManager::NotifyObserversOfNewModel, |
| ui_weak_ptr_factory_.GetWeakPtr(), |
| optimization_target, *model_info)); |
| |
| optimization_target_model_info_map_.insert_or_assign(optimization_target, |
| std::move(model_info)); |
| } |
| |
| base::FilePath PredictionManager::GetBaseModelDirForDownload( |
| proto::OptimizationTarget optimization_target) { |
| return prediction_model_store_->GetBaseModelDirForModelCacheKey( |
| optimization_target, model_cache_key_); |
| } |
| |
| void PredictionManager::OverrideTargetModelForTesting( |
| proto::OptimizationTarget optimization_target, |
| std::unique_ptr<ModelInfo> model_info) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| std::optional<ModelInfo> model_info_copy; |
| if (model_info) { |
| model_info_copy = *model_info; |
| } |
| optimization_target_model_info_map_.insert_or_assign(optimization_target, |
| std::move(model_info)); |
| |
| NotifyObserversOfNewModel(optimization_target, model_info_copy); |
| } |
| |
| } // namespace optimization_guide |