| // Copyright 2019 The Chromium Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "chrome/browser/optimization_guide/prediction/prediction_manager.h" |
| |
| #include <map> |
| #include <memory> |
| #include <string> |
| #include <utility> |
| |
| #include "base/base64.h" |
| #include "base/strings/string_number_conversions.h" |
| #include "base/strings/stringprintf.h" |
| #include "base/test/gtest_util.h" |
| #include "base/test/metrics/histogram_tester.h" |
| #include "base/test/scoped_feature_list.h" |
| #include "base/time/time.h" |
| #include "build/build_config.h" |
| #include "build/chromeos_buildflags.h" |
| #include "chrome/browser/browser_process.h" |
| #include "chrome/browser/optimization_guide/optimization_guide_web_contents_observer.h" |
| #include "chrome/browser/optimization_guide/prediction/prediction_model_download_manager.h" |
| #include "chrome/test/base/testing_profile.h" |
| #include "components/leveldb_proto/testing/fake_db.h" |
| #include "components/optimization_guide/core/optimization_guide_features.h" |
| #include "components/optimization_guide/core/optimization_guide_prefs.h" |
| #include "components/optimization_guide/core/optimization_guide_store.h" |
| #include "components/optimization_guide/core/optimization_guide_switches.h" |
| #include "components/optimization_guide/core/optimization_guide_test_util.h" |
| #include "components/optimization_guide/core/optimization_guide_util.h" |
| #include "components/optimization_guide/core/optimization_target_model_observer.h" |
| #include "components/optimization_guide/core/prediction_model.h" |
| #include "components/optimization_guide/core/prediction_model_fetcher.h" |
| #include "components/optimization_guide/core/proto_database_provider_test_base.h" |
| #include "components/optimization_guide/proto/hint_cache.pb.h" |
| #include "components/optimization_guide/proto/models.pb.h" |
| #include "components/prefs/testing_pref_service.h" |
| #include "components/variations/scoped_variations_ids_provider.h" |
| #include "content/public/test/browser_task_environment.h" |
| #include "content/public/test/mock_navigation_handle.h" |
| #include "content/public/test/test_web_contents_factory.h" |
| #include "content/public/test/web_contents_tester.h" |
| #include "services/network/public/cpp/shared_url_loader_factory.h" |
| #include "services/network/public/cpp/weak_wrapper_shared_url_loader_factory.h" |
| #include "services/network/test/test_network_connection_tracker.h" |
| #include "services/network/test/test_url_loader_factory.h" |
| #include "testing/gtest/include/gtest/gtest.h" |
| #include "ui/base/page_transition_types.h" |
| |
| using leveldb_proto::test::FakeDB; |
| |
namespace {
// Retry delay: 2 minutes plus 62 seconds of slack, so that the fetch retry
// delay and any additional random delay have fully elapsed.
constexpr int kTestFetchRetryDelaySecs = 60 * 2 + 62;
// 24 hours plus 62 seconds of slack for the random fetch delay.
constexpr int kUpdateFetchModelAndFeaturesTimeSecs = 24 * 60 * 60 + 62;

}  // namespace
| |
| namespace optimization_guide { |
| |
| proto::PredictionModel CreatePredictionModel( |
| bool output_model_as_download_url = false) { |
| proto::PredictionModel prediction_model; |
| |
| proto::ModelInfo* model_info = prediction_model.mutable_model_info(); |
| model_info->set_version(1); |
| model_info->add_supported_host_model_features("host_feat1"); |
| model_info->set_optimization_target( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD); |
| model_info->add_supported_model_types( |
| proto::ModelType::MODEL_TYPE_DECISION_TREE); |
| if (output_model_as_download_url) { |
| prediction_model.mutable_model()->set_download_url( |
| "https://example.com/model"); |
| } else { |
| prediction_model.mutable_model()->mutable_threshold()->set_value(5.0); |
| } |
| return prediction_model; |
| } |
| |
| std::unique_ptr<proto::GetModelsResponse> BuildGetModelsResponse( |
| bool output_model_as_download_url = false) { |
| std::unique_ptr<proto::GetModelsResponse> get_models_response = |
| std::make_unique<proto::GetModelsResponse>(); |
| |
| proto::PredictionModel prediction_model = |
| CreatePredictionModel(output_model_as_download_url); |
| prediction_model.mutable_model_info()->add_supported_host_model_features( |
| "host_feat1"); |
| prediction_model.mutable_model_info()->set_version(2); |
| *get_models_response->add_models() = std::move(prediction_model); |
| |
| return get_models_response; |
| } |
| |
| class TestPredictionModel : public PredictionModel { |
| public: |
| explicit TestPredictionModel(const proto::PredictionModel& prediction_model) |
| : PredictionModel(prediction_model) {} |
| ~TestPredictionModel() override = default; |
| |
| OptimizationTargetDecision Predict( |
| const base::flat_map<std::string, float>& model_features, |
| double* prediction_score) override { |
| *prediction_score = 0.0; |
| // Check to make sure the all model_features were provided. |
| for (const auto& model_feature : GetModelFeatures()) { |
| if (!model_features.contains(model_feature)) |
| return OptimizationTargetDecision::kUnknown; |
| } |
| *prediction_score = 0.6; |
| model_evaluated_ = true; |
| last_evaluated_features_ = |
| base::flat_map<std::string, float>(model_features); |
| return OptimizationTargetDecision::kPageLoadMatches; |
| } |
| |
| bool WasModelEvaluated() { return model_evaluated_; } |
| |
| void ResetModelEvaluationState() { model_evaluated_ = false; } |
| |
| base::flat_map<std::string, float> last_evaluated_features() { |
| return last_evaluated_features_; |
| } |
| |
| private: |
| bool ValidatePredictionModel() const override { return true; } |
| |
| bool model_evaluated_ = false; |
| base::flat_map<std::string, float> last_evaluated_features_; |
| }; |
| |
| class FakeOptimizationTargetModelObserver |
| : public OptimizationTargetModelObserver { |
| public: |
| void OnModelUpdated(proto::OptimizationTarget optimization_target, |
| const ModelInfo& model_info) override { |
| last_received_models_.insert_or_assign(optimization_target, model_info); |
| } |
| |
| absl::optional<ModelInfo> last_received_model_for_target( |
| proto::OptimizationTarget optimization_target) const { |
| auto model_it = last_received_models_.find(optimization_target); |
| if (model_it == last_received_models_.end()) |
| return absl::nullopt; |
| return model_it->second; |
| } |
| |
| // Resets the state of the observer. |
| void Reset() { last_received_models_.clear(); } |
| |
| private: |
| base::flat_map<proto::OptimizationTarget, ModelInfo> last_received_models_; |
| }; |
| |
| class FakePredictionModelDownloadManager |
| : public PredictionModelDownloadManager { |
| public: |
| FakePredictionModelDownloadManager( |
| scoped_refptr<base::SequencedTaskRunner> task_runner) |
| : PredictionModelDownloadManager(/*download_service=*/nullptr, |
| task_runner) {} |
| ~FakePredictionModelDownloadManager() override = default; |
| |
| void StartDownload(const GURL& url) override { |
| last_requested_download_ = url; |
| } |
| |
| GURL last_requested_download() const { return last_requested_download_; } |
| |
| void CancelAllPendingDownloads() override { cancel_downloads_called_ = true; } |
| bool cancel_downloads_called() const { return cancel_downloads_called_; } |
| |
| bool IsAvailableForDownloads() const override { return is_available_; } |
| void SetAvailableForDownloads(bool is_available) { |
| is_available_ = is_available; |
| } |
| |
| private: |
| GURL last_requested_download_; |
| bool cancel_downloads_called_ = false; |
| bool is_available_ = true; |
| }; |
| |
// Chooses the canned outcome TestPredictionModelFetcher reports for a fetch.
// Enumerators take the implicit values 0, 1, 2, 3 in declaration order.
enum class PredictionModelFetcherEndState {
  // The models-fetched callback receives no response.
  kFetchFailed,
  // The callback receives BuildGetModelsResponse() with an inline model.
  kFetchSuccessWithModels,
  // The callback receives an empty GetModelsResponse.
  kFetchSuccessWithEmptyResponse,
  // The callback receives BuildGetModelsResponse() with a download URL.
  kFetchSuccessWithModelDownloadUrls,
};
| |
| void RunGetModelsCallback( |
| ModelsFetchedCallback callback, |
| std::unique_ptr<proto::GetModelsResponse> get_models_response) { |
| if (get_models_response) { |
| std::move(callback).Run(std::move(get_models_response)); |
| return; |
| } |
| std::move(callback).Run(absl::nullopt); |
| } |
| |
| // A mock class implementation of PredictionModelFetcher. |
| class TestPredictionModelFetcher : public PredictionModelFetcher { |
| public: |
| TestPredictionModelFetcher( |
| scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory, |
| const GURL& optimization_guide_service_get_models_url, |
| network::NetworkConnectionTracker* network_connection_tracker, |
| PredictionModelFetcherEndState fetch_state) |
| : PredictionModelFetcher(url_loader_factory, |
| optimization_guide_service_get_models_url, |
| network_connection_tracker), |
| fetch_state_(fetch_state) {} |
| |
| bool FetchOptimizationGuideServiceModels( |
| const std::vector<proto::ModelInfo>& models_request_info, |
| const std::vector<proto::FieldTrial>& active_field_trials, |
| proto::RequestContext request_context, |
| const std::string& locale, |
| ModelsFetchedCallback models_fetched_callback) override { |
| if (!ValidateModelsInfoForFetch(models_request_info)) { |
| std::move(models_fetched_callback).Run(absl::nullopt); |
| return false; |
| } |
| |
| std::unique_ptr<proto::GetModelsResponse> get_models_response; |
| locale_requested_ = locale; |
| switch (fetch_state_) { |
| case PredictionModelFetcherEndState::kFetchFailed: |
| get_models_response = nullptr; |
| break; |
| case PredictionModelFetcherEndState::kFetchSuccessWithModels: |
| models_fetched_ = true; |
| get_models_response = BuildGetModelsResponse(); |
| break; |
| case PredictionModelFetcherEndState::kFetchSuccessWithEmptyResponse: |
| models_fetched_ = true; |
| get_models_response = std::make_unique<proto::GetModelsResponse>(); |
| break; |
| case PredictionModelFetcherEndState::kFetchSuccessWithModelDownloadUrls: |
| models_fetched_ = true; |
| get_models_response = |
| BuildGetModelsResponse(/*output_model_as_download_url=*/true); |
| break; |
| } |
| base::ThreadTaskRunnerHandle::Get()->PostTask( |
| FROM_HERE, base::BindOnce(&RunGetModelsCallback, |
| std::move(models_fetched_callback), |
| std::move(get_models_response))); |
| return true; |
| } |
| |
| bool ValidateModelsInfoForFetch( |
| const std::vector<proto::ModelInfo>& models_request_info) { |
| for (const auto& model_info : models_request_info) { |
| if (model_info.supported_model_types_size() == 0 || |
| !proto::ModelType_IsValid(model_info.supported_model_types(0))) { |
| return false; |
| } |
| if (!model_info.has_optimization_target() || |
| !proto::OptimizationTarget_IsValid( |
| model_info.optimization_target())) { |
| return false; |
| } |
| |
| if (check_expected_version_) { |
| auto version_it = |
| expected_version_.find(model_info.optimization_target()); |
| if (model_info.has_version() != |
| (version_it != expected_version_.end())) { |
| return false; |
| } |
| if (model_info.has_version() && |
| model_info.version() != version_it->second) { |
| return false; |
| } |
| } |
| |
| auto it = expected_metadata_.find(model_info.optimization_target()); |
| if (model_info.has_model_metadata() != (it != expected_metadata_.end())) |
| return false; |
| if (model_info.has_model_metadata()) { |
| proto::Any expected_metadata = it->second; |
| if (model_info.model_metadata().type_url() != |
| expected_metadata.type_url()) { |
| return false; |
| } |
| if (model_info.model_metadata().value() != expected_metadata.value()) |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| void SetExpectedModelMetadataForOptimizationTarget( |
| proto::OptimizationTarget optimization_target, |
| const proto::Any& model_metadata) { |
| expected_metadata_[optimization_target] = model_metadata; |
| } |
| |
| void SetExpectedVersionForOptimizationTarget( |
| proto::OptimizationTarget optimization_target, |
| int64_t version) { |
| expected_version_[optimization_target] = version; |
| } |
| |
| void SetCheckExpectedVersion() { check_expected_version_ = true; } |
| |
| void Reset() { models_fetched_ = false; } |
| |
| bool models_fetched() const { return models_fetched_; } |
| |
| std::string locale_requested() const { return locale_requested_; } |
| |
| private: |
| bool models_fetched_ = false; |
| bool check_expected_version_ = false; |
| std::string locale_requested_; |
| // The desired behavior of the TestPredictionModelFetcher. |
| PredictionModelFetcherEndState fetch_state_; |
| base::flat_map<proto::OptimizationTarget, proto::Any> expected_metadata_; |
| base::flat_map<proto::OptimizationTarget, int64_t> expected_version_; |
| }; |
| |
| class TestOptimizationGuideStore : public OptimizationGuideStore { |
| public: |
| TestOptimizationGuideStore( |
| std::unique_ptr<StoreEntryProtoDatabase> database, |
| scoped_refptr<base::SequencedTaskRunner> store_task_runner) |
| : OptimizationGuideStore(std::move(database), store_task_runner) {} |
| |
| ~TestOptimizationGuideStore() override = default; |
| |
| void Initialize(bool purge_existing_data, |
| base::OnceClosure callback) override { |
| init_callback_ = std::move(callback); |
| status_ = Status::kAvailable; |
| } |
| |
| void RunInitCallback(bool load_models = true, |
| bool load_host_model_features = true, |
| bool have_models_in_store = true) { |
| load_models_ = load_models; |
| load_host_model_features_ = load_host_model_features; |
| have_models_in_store_ = have_models_in_store; |
| std::move(init_callback_).Run(); |
| } |
| |
| void RunUpdateHostModelFeaturesCallback() { |
| std::move(update_host_models_callback_).Run(); |
| } |
| |
| void LoadPredictionModel(const EntryKey& prediction_model_entry_key, |
| PredictionModelLoadedCallback callback) override { |
| model_loaded_ = true; |
| if (load_models_) { |
| std::move(callback).Run( |
| std::make_unique<proto::PredictionModel>(CreatePredictionModel())); |
| } else { |
| std::move(callback).Run(nullptr); |
| } |
| } |
| |
| void LoadAllHostModelFeatures( |
| AllHostModelFeaturesLoadedCallback callback) override { |
| host_model_features_loaded_ = true; |
| if (load_host_model_features_) { |
| proto::HostModelFeatures host_model_features; |
| host_model_features.set_host("foo.com"); |
| proto::ModelFeature* model_feature = |
| host_model_features.add_model_features(); |
| model_feature->set_feature_name("host_feat1"); |
| model_feature->set_double_value(2.0); |
| std::unique_ptr<std::vector<proto::HostModelFeatures>> |
| all_host_model_features = |
| std::make_unique<std::vector<proto::HostModelFeatures>>(); |
| all_host_model_features->emplace_back(host_model_features); |
| std::move(callback).Run(std::move(all_host_model_features)); |
| } else { |
| std::move(callback).Run(nullptr); |
| } |
| } |
| |
| bool FindPredictionModelEntryKey( |
| proto::OptimizationTarget optimization_target, |
| OptimizationGuideStore::EntryKey* out_prediction_model_entry_key) |
| override { |
| if (optimization_target == |
| proto::OptimizationTarget::OPTIMIZATION_TARGET_UNKNOWN) { |
| return false; |
| } |
| if (have_models_in_store_) { |
| *out_prediction_model_entry_key = |
| "4_" + base::NumberToString(static_cast<int>(optimization_target)); |
| return true; |
| } |
| return false; |
| } |
| |
| void UpdateHostModelFeatures( |
| std::unique_ptr<StoreUpdateData> host_model_features_update_data, |
| base::OnceClosure callback) override { |
| host_model_features_update_time_ = |
| *host_model_features_update_data->update_time(); |
| update_host_models_callback_ = std::move(callback); |
| } |
| |
| void UpdatePredictionModels( |
| std::unique_ptr<StoreUpdateData> prediction_models_update_data, |
| base::OnceClosure callback) override { |
| std::move(callback).Run(); |
| } |
| |
| bool WasModelLoaded() const { return model_loaded_; } |
| bool WasHostModelFeaturesLoaded() const { |
| return host_model_features_loaded_; |
| } |
| |
| private: |
| base::OnceClosure init_callback_; |
| base::OnceClosure update_host_models_callback_; |
| bool model_loaded_ = false; |
| bool host_model_features_loaded_ = false; |
| bool load_models_ = true; |
| bool load_host_model_features_ = true; |
| bool have_models_in_store_ = true; |
| }; |
| |
| class TestPredictionManager : public PredictionManager { |
| public: |
| TestPredictionManager( |
| OptimizationGuideStore* model_and_features_store, |
| scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory, |
| PrefService* pref_service, |
| Profile* profile) |
| : PredictionManager(model_and_features_store, |
| url_loader_factory, |
| pref_service, |
| profile) {} |
| |
| ~TestPredictionManager() override = default; |
| |
| std::unique_ptr<PredictionModel> CreatePredictionModel( |
| const proto::PredictionModel& model) const override { |
| if (!create_valid_prediction_model_) |
| return nullptr; |
| return std::make_unique<TestPredictionModel>(model); |
| } |
| |
| void set_create_valid_prediction_model(bool create_valid_prediction_model) { |
| create_valid_prediction_model_ = create_valid_prediction_model; |
| } |
| |
| using PredictionManager::GetHostModelFeaturesForHost; |
| using PredictionManager::GetHostModelFeaturesForTesting; |
| using PredictionManager::GetPredictionModelForTesting; |
| |
| void UpdateHostModelFeaturesForTesting( |
| proto::GetModelsResponse* get_models_response) { |
| UpdateHostModelFeatures(get_models_response->host_model_features()); |
| } |
| |
| void UpdatePredictionModelsForTesting( |
| proto::GetModelsResponse* get_models_response) { |
| UpdatePredictionModels(get_models_response->models()); |
| } |
| |
| private: |
| bool create_valid_prediction_model_ = true; |
| }; |
| |
| class PredictionManagerTestBase : public ProtoDatabaseProviderTestBase { |
| public: |
| using StoreEntry = proto::StoreEntry; |
| using StoreEntryMap = std::map<OptimizationGuideStore::EntryKey, StoreEntry>; |
| PredictionManagerTestBase() = default; |
| ~PredictionManagerTestBase() override = default; |
| |
| PredictionManagerTestBase(const PredictionManagerTestBase&) = delete; |
| PredictionManagerTestBase& operator=(const PredictionManagerTestBase&) = |
| delete; |
| |
| void SetUp() override { |
| ProtoDatabaseProviderTestBase::SetUp(); |
| web_contents_factory_ = std::make_unique<content::TestWebContentsFactory>(); |
| |
| pref_service_ = std::make_unique<TestingPrefServiceSimple>(); |
| prefs::RegisterProfilePrefs(pref_service_->registry()); |
| |
| url_loader_factory_ = |
| base::MakeRefCounted<network::WeakWrapperSharedURLLoaderFactory>( |
| &test_url_loader_factory_); |
| base::CommandLine::ForCurrentProcess()->AppendSwitch( |
| switches::kDisableCheckingUserPermissionsForTesting); |
| base::CommandLine::ForCurrentProcess()->AppendSwitch( |
| switches::kFetchModelsAndHostModelFeaturesOverrideTimer); |
| } |
| |
| void CreatePredictionManager() { |
| if (prediction_manager_) { |
| db_store_.clear(); |
| model_and_features_store_.reset(); |
| prediction_manager_.reset(); |
| } |
| |
| model_and_features_store_ = CreateModelAndHostModelFeaturesStore(); |
| prediction_manager_ = std::make_unique<TestPredictionManager>( |
| model_and_features_store_.get(), url_loader_factory_, |
| pref_service_.get(), &testing_profile_); |
| prediction_manager_->SetClockForTesting(task_environment_.GetMockClock()); |
| } |
| |
| std::unique_ptr<TestOptimizationGuideStore> |
| CreateModelAndHostModelFeaturesStore() { |
| // Setup the fake db and the class under test. |
| auto db = std::make_unique<FakeDB<StoreEntry>>(&db_store_); |
| |
| return std::make_unique<TestOptimizationGuideStore>( |
| std::move(db), task_environment_.GetMainThreadTaskRunner()); |
| } |
| |
| void RegisterOptimizationTargets( |
| const std::vector< |
| std::pair<proto::OptimizationTarget, absl::optional<proto::Any>>>& |
| optimization_targets_and_metadata) { |
| prediction_manager_->RegisterOptimizationTargets( |
| optimization_targets_and_metadata); |
| } |
| |
| TestPredictionManager* prediction_manager() const { |
| return prediction_manager_.get(); |
| } |
| |
| // Creates a navigation handle with the OptimizationGuideWebContentsObserver |
| // attached. |
| std::unique_ptr<content::MockNavigationHandle> |
| CreateMockNavigationHandleWithOptimizationGuideWebContentsObserver( |
| const GURL& url) { |
| content::WebContents* web_contents = |
| web_contents_factory_->CreateWebContents(&testing_profile_); |
| OptimizationGuideWebContentsObserver::CreateForWebContents(web_contents); |
| std::unique_ptr<content::MockNavigationHandle> navigation_handle = |
| std::make_unique<content::MockNavigationHandle>(web_contents); |
| navigation_handle->set_url(url); |
| return navigation_handle; |
| } |
| |
| void TearDown() override { ProtoDatabaseProviderTestBase::TearDown(); } |
| |
| std::unique_ptr<TestPredictionModelFetcher> BuildTestPredictionModelFetcher( |
| PredictionModelFetcherEndState end_state) { |
| std::unique_ptr<TestPredictionModelFetcher> prediction_model_fetcher = |
| std::make_unique<TestPredictionModelFetcher>( |
| url_loader_factory_, GURL("https://hintsserver.com"), |
| network::TestNetworkConnectionTracker::GetInstance(), end_state); |
| return prediction_model_fetcher; |
| } |
| |
| void SetStoreInitialized(bool load_models = true, |
| bool load_host_model_features = true, |
| bool have_models_in_store = true) { |
| models_and_features_store()->RunInitCallback( |
| load_models, load_host_model_features, have_models_in_store); |
| RunUntilIdle(); |
| // Move clock forward for any short delays added for the fetcher. |
| MoveClockForwardBy(base::Seconds(2)); |
| } |
| |
| void MoveClockForwardBy(base::TimeDelta time_delta) { |
| task_environment_.FastForwardBy(time_delta); |
| RunUntilIdle(); |
| } |
| |
| TestPredictionModelFetcher* prediction_model_fetcher() const { |
| return static_cast<TestPredictionModelFetcher*>( |
| prediction_manager()->prediction_model_fetcher()); |
| } |
| |
| FakePredictionModelDownloadManager* prediction_model_download_manager() |
| const { |
| return static_cast<FakePredictionModelDownloadManager*>( |
| prediction_manager()->prediction_model_download_manager()); |
| } |
| |
| TestOptimizationGuideStore* models_and_features_store() const { |
| return static_cast<TestOptimizationGuideStore*>( |
| prediction_manager()->model_and_features_store()); |
| } |
| |
| base::FilePath temp_dir() const { return temp_dir_.GetPath(); } |
| |
| TestingPrefServiceSimple* pref_service() const { return pref_service_.get(); } |
| |
| TestingProfile* profile() { return &testing_profile_; } |
| |
| void RunUntilIdle() { |
| task_environment_.RunUntilIdle(); |
| base::RunLoop().RunUntilIdle(); |
| } |
| |
| content::BrowserTaskEnvironment* task_environment() { |
| return &task_environment_; |
| } |
| |
| protected: |
| // |feature_list_| needs to be destroyed after |task_environment_|, to avoid |
| // tsan flakes caused by other tasks running while |feature_list_| is |
| // destroyed. |
| base::test::ScopedFeatureList feature_list_; |
| |
| private: |
| content::BrowserTaskEnvironment task_environment_{ |
| base::test::TaskEnvironment::MainThreadType::UI, |
| base::test::TaskEnvironment::TimeSource::MOCK_TIME}; |
| StoreEntryMap db_store_; |
| std::unique_ptr<TestOptimizationGuideStore> model_and_features_store_; |
| std::unique_ptr<TestPredictionManager> prediction_manager_; |
| scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory_; |
| network::TestURLLoaderFactory test_url_loader_factory_; |
| TestingProfile testing_profile_; |
| std::unique_ptr<TestingPrefServiceSimple> pref_service_; |
| std::unique_ptr<content::TestWebContentsFactory> web_contents_factory_; |
| }; |
| |
| class PredictionManagerRemoteFetchingDisabledTest |
| : public PredictionManagerTestBase { |
| public: |
| PredictionManagerRemoteFetchingDisabledTest() { |
| // This needs to be done before any tasks are run that might check if a |
| // feature is enabled, to avoid tsan errors. |
| feature_list_.InitAndDisableFeature( |
| features::kRemoteOptimizationGuideFetching); |
| } |
| }; |
| |
// With kRemoteOptimizationGuideFetching disabled, initializing the store must
// not trigger a model fetch, even though a target is registered.
TEST_F(PredictionManagerRemoteFetchingDisabledTest, RemoteFetchingDisabled) {
  CreatePredictionManager();

  // Install a fetcher that would report success if it were ever used.
  prediction_manager()->SetPredictionModelFetcherForTesting(
      BuildTestPredictionModelFetcher(
          PredictionModelFetcherEndState::kFetchSuccessWithModels));

  RegisterOptimizationTargets(
      {{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt}});
  SetStoreInitialized();

  EXPECT_FALSE(prediction_model_fetcher()->models_fetched());
}
| |
| class PredictionManagerTest : public PredictionManagerTestBase { |
| public: |
| PredictionManagerTest() { |
| // This needs to be done before any tasks are run that might check if a |
| // feature is enabled, to avoid tsan errors. |
| feature_list_.InitAndEnableFeature( |
| features::kRemoteOptimizationGuideFetching); |
| } |
| |
| private: |
| variations::ScopedVariationsIdsProvider scoped_variations_ids_provider_{ |
| variations::VariationsIdsProvider::Mode::kUseSignedInState}; |
| }; |
| |
// Asking about a target that was never registered (UNKNOWN) returns kUnknown.
// No evaluation-latency samples are recorded for it or for the registered
// PAINFUL_PAGE_LOAD target.
TEST_F(PredictionManagerTest, OptimizationTargetNotRegisteredForNavigation) {
  base::HistogramTester histogram_tester;
  std::unique_ptr<content::MockNavigationHandle> navigation_handle =
      CreateMockNavigationHandleWithOptimizationGuideWebContentsObserver(
          GURL("https://foo.com"));

  CreatePredictionManager();

  prediction_manager()->SetPredictionModelFetcherForTesting(
      BuildTestPredictionModelFetcher(
          PredictionModelFetcherEndState::kFetchSuccessWithModels));

  // Only PAINFUL_PAGE_LOAD is registered; UNKNOWN is queried below.
  RegisterOptimizationTargets(
      {{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt}});
  SetStoreInitialized();

  EXPECT_TRUE(prediction_model_fetcher()->models_fetched());

  EXPECT_EQ(OptimizationTargetDecision::kUnknown,
            prediction_manager()->ShouldTargetNavigation(
                navigation_handle.get(), proto::OPTIMIZATION_TARGET_UNKNOWN));
  histogram_tester.ExpectTotalCount(
      "OptimizationGuide.PredictionModelEvaluationLatency." +
          GetStringNameForOptimizationTarget(
              proto::OPTIMIZATION_TARGET_UNKNOWN),
      0);
  histogram_tester.ExpectTotalCount(
      "OptimizationGuide.PredictionModelEvaluationLatency." +
          GetStringNameForOptimizationTarget(
              proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD),
      0);
}
| |
| TEST_F(PredictionManagerTest, AddObserverForOptimizationTargetModel) { |
| base::HistogramTester histogram_tester; |
| std::unique_ptr<content::MockNavigationHandle> navigation_handle = |
| CreateMockNavigationHandleWithOptimizationGuideWebContentsObserver( |
| GURL("https://foo.com")); |
| |
| CreatePredictionManager(); |
| |
| prediction_manager()->SetPredictionModelFetcherForTesting( |
| BuildTestPredictionModelFetcher( |
| PredictionModelFetcherEndState::kFetchSuccessWithEmptyResponse)); |
| proto::Any model_metadata; |
| model_metadata.set_type_url("whatever"); |
| prediction_model_fetcher()->SetExpectedModelMetadataForOptimizationTarget( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, model_metadata); |
| |
| FakeOptimizationTargetModelObserver observer; |
| prediction_manager()->AddObserverForOptimizationTargetModel( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, model_metadata, &observer); |
| SetStoreInitialized(/* load_models= */ false, |
| /* load_host_model_features= */ false, |
| /* have_models_in_store= */ false); |
| |
| EXPECT_TRUE(prediction_model_fetcher()->models_fetched()); |
| |
| EXPECT_EQ(OptimizationTargetDecision::kModelNotAvailableOnClient, |
| prediction_manager()->ShouldTargetNavigation( |
| navigation_handle.get(), |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)); |
| |
| histogram_tester.ExpectTotalCount( |
| "OptimizationGuide.PredictionModelEvaluationLatency." + |
| GetStringNameForOptimizationTarget( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD), |
| 0); |
| |
| EXPECT_TRUE(prediction_manager()->GetRegisteredOptimizationTargets().contains( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)); |
| EXPECT_FALSE(observer |
| .last_received_model_for_target( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD) |
| .has_value()); |
| |
| base::FilePath additional_file_path = |
| temp_dir().AppendASCII("whatever").AppendASCII("additional_file.txt"); |
| proto::ModelInfo model_info; |
| model_info.set_optimization_target( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD); |
| model_info.set_version(1); |
| model_info.mutable_model_metadata()->set_type_url("sometypeurl"); |
| model_info.add_additional_files()->set_file_path( |
| FilePathToString(additional_file_path)); |
| // An empty file path should be be ignored. |
| model_info.add_additional_files()->set_file_path(""); |
| |
| // Ensure observer is hooked up. |
| proto::PredictionModel model1; |
| *model1.mutable_model_info() = model_info; |
| model1.mutable_model()->set_download_url( |
| FilePathToString(temp_dir().AppendASCII("whatever"))); |
| prediction_manager()->OnModelReady(model1); |
| RunUntilIdle(); |
| |
| absl::optional<ModelInfo> received_model = |
| observer.last_received_model_for_target( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD); |
| EXPECT_EQ(received_model->GetModelMetadata()->type_url(), "sometypeurl"); |
| EXPECT_EQ(received_model->GetModelFilePath().BaseName().value(), |
| FILE_PATH_LITERAL("whatever")); |
| EXPECT_EQ(received_model->GetAdditionalFiles(), |
| base::flat_set<base::FilePath>{additional_file_path}); |
| |
| // Reset fetcher and make sure version is sent in the new request and not |
| // counted as re-loaded or updated. |
| { |
| base::HistogramTester histogram_tester2; |
| |
| prediction_model_fetcher()->Reset(); |
| prediction_model_fetcher()->SetCheckExpectedVersion(); |
| prediction_model_fetcher()->SetExpectedVersionForOptimizationTarget( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, 1); |
| MoveClockForwardBy(base::Seconds(kUpdateFetchModelAndFeaturesTimeSecs)); |
| EXPECT_TRUE(prediction_model_fetcher()->models_fetched()); |
| histogram_tester2.ExpectTotalCount( |
| "OptimizationGuide.PredictionModelUpdateVersion.PainfulPageLoad", 0); |
| histogram_tester2.ExpectTotalCount( |
| "OptimizationGuide.PredictionModelLoadedVersion.PainfulPageLoad", 0); |
| histogram_tester2.ExpectTotalCount( |
| "OptimizationGuide.PredictionModelRemoved.PainfulPageLoad", 0); |
| } |
| |
| // Now remove and reset observer. |
| prediction_manager()->RemoveObserverForOptimizationTargetModel( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, &observer); |
| observer.Reset(); |
| proto::PredictionModel model2; |
| *model2.mutable_model_info() = model_info; |
| model2.mutable_model_info()->set_version(2); |
| model2.mutable_model()->set_download_url( |
| FilePathToString(temp_dir().AppendASCII("whatever2"))); |
| prediction_manager()->OnModelReady(model2); |
| RunUntilIdle(); |
| |
| // Last received path should not have been updated since the observer was |
| // removed. |
| EXPECT_FALSE(observer |
| .last_received_model_for_target( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD) |
| .has_value()); |
| } |
| |
| TEST_F(PredictionManagerTest, |
| AddObserverForOptimizationTargetModelAddAnotherObserverForSameTarget) { |
| // Fails under "threadsafe" mode. |
| testing::GTEST_FLAG(death_test_style) = "fast"; |
| |
| CreatePredictionManager(); |
| |
| FakeOptimizationTargetModelObserver observer1; |
| prediction_manager()->AddObserverForOptimizationTargetModel( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, |
| /*model_metadata=*/absl::nullopt, &observer1); |
| SetStoreInitialized(/* load_models= */ false, |
| /* load_host_model_features= */ false, |
| /* have_models_in_store= */ false); |
| |
| proto::ModelInfo model_info; |
| model_info.set_optimization_target( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD); |
| model_info.set_version(1); |
| |
| // Ensure observer is hooked up. |
| proto::PredictionModel model1; |
| *model1.mutable_model_info() = model_info; |
| model1.mutable_model()->set_download_url( |
| FilePathToString(temp_dir().AppendASCII("whatever"))); |
| prediction_manager()->OnModelReady(model1); |
| RunUntilIdle(); |
| |
| EXPECT_EQ(observer1 |
| .last_received_model_for_target( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD) |
| ->GetModelFilePath() |
| .BaseName() |
| .value(), |
| FILE_PATH_LITERAL("whatever")); |
| |
| #if !defined(OS_WIN) |
| // Do not run the DCHECK death test on Windows since there's some weird |
| // behavior there. |
| |
| // Now, register a new observer - it should die. |
| FakeOptimizationTargetModelObserver observer2; |
| EXPECT_DCHECK_DEATH( |
| prediction_manager()->AddObserverForOptimizationTargetModel( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, |
| /*model_metadata=*/absl::nullopt, &observer2)); |
| RunUntilIdle(); |
| #endif |
| } |
| |
| // See crbug/1227996. |
| #if !defined(OS_WIN) |
| TEST_F(PredictionManagerTest, |
| AddObserverForOptimizationTargetModelCommandLineOverride) { |
| optimization_guide::proto::Any metadata; |
| metadata.set_type_url("sometypeurl"); |
| std::string encoded_metadata; |
| metadata.SerializeToString(&encoded_metadata); |
| base::Base64Encode(encoded_metadata, &encoded_metadata); |
| base::CommandLine::ForCurrentProcess()->AppendSwitchASCII( |
| switches::kModelOverride, |
| base::StringPrintf("OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD:%s:%s", |
| kTestAbsoluteFilePath, encoded_metadata.c_str())); |
| |
| CreatePredictionManager(); |
| |
| prediction_manager()->SetPredictionModelFetcherForTesting( |
| BuildTestPredictionModelFetcher( |
| PredictionModelFetcherEndState::kFetchSuccessWithEmptyResponse)); |
| proto::Any model_metadata; |
| model_metadata.set_type_url("whatever"); |
| prediction_model_fetcher()->SetExpectedModelMetadataForOptimizationTarget( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, model_metadata); |
| |
| FakeOptimizationTargetModelObserver observer; |
| prediction_manager()->AddObserverForOptimizationTargetModel( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, model_metadata, &observer); |
| SetStoreInitialized(/* load_models= */ false, |
| /* load_host_model_features= */ false, |
| /* have_models_in_store= */ false); |
| |
| // Make sure no models are fetched. |
| EXPECT_FALSE(prediction_model_fetcher()->models_fetched()); |
| |
| EXPECT_TRUE(prediction_manager()->GetRegisteredOptimizationTargets().contains( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)); |
| EXPECT_EQ(observer |
| .last_received_model_for_target( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD) |
| ->GetModelMetadata() |
| ->type_url(), |
| "sometypeurl"); |
| EXPECT_EQ(observer |
| .last_received_model_for_target( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD) |
| ->GetModelFilePath() |
| .value(), |
| FILE_PATH_LITERAL(kTestAbsoluteFilePath)); |
| |
| // Now reset observer. New model downloads should not update the observer. |
| observer.Reset(); |
| proto::PredictionModel model; |
| model.mutable_model_info()->set_optimization_target( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD); |
| model.mutable_model_info()->set_version(1); |
| model.mutable_model()->set_download_url( |
| FilePathToString(temp_dir().AppendASCII("whatever2"))); |
| prediction_manager()->OnModelReady(model); |
| RunUntilIdle(); |
| |
| // Last received path should not have been updated since the observer was |
| // reset and override is in place. |
| EXPECT_FALSE(observer |
| .last_received_model_for_target( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD) |
| .has_value()); |
| } |
| #endif |
| |
| TEST_F(PredictionManagerTest, |
| NoPredictionModelForRegisteredOptimizationTarget) { |
| base::HistogramTester histogram_tester; |
| std::unique_ptr<content::MockNavigationHandle> navigation_handle = |
| CreateMockNavigationHandleWithOptimizationGuideWebContentsObserver( |
| GURL("https://foo.com")); |
| |
| CreatePredictionManager(); |
| RegisterOptimizationTargets( |
| {{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt}}); |
| |
| EXPECT_EQ(OptimizationTargetDecision::kModelNotAvailableOnClient, |
| prediction_manager()->ShouldTargetNavigation( |
| navigation_handle.get(), |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)); |
| |
| histogram_tester.ExpectTotalCount( |
| "OptimizationGuide.PredictionModelEvaluationLatency." + |
| GetStringNameForOptimizationTarget( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD), |
| 0); |
| } |
| |
| TEST_F(PredictionManagerTest, EvaluatePredictionModel) { |
| base::HistogramTester histogram_tester; |
| std::unique_ptr<content::MockNavigationHandle> navigation_handle = |
| CreateMockNavigationHandleWithOptimizationGuideWebContentsObserver( |
| GURL("https://foo.com")); |
| |
| CreatePredictionManager(); |
| // The model will be loaded from the store. |
| prediction_manager()->SetPredictionModelFetcherForTesting( |
| BuildTestPredictionModelFetcher( |
| PredictionModelFetcherEndState::kFetchSuccessWithEmptyResponse)); |
| |
| RegisterOptimizationTargets( |
| {{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt}}); |
| SetStoreInitialized(); |
| EXPECT_TRUE(prediction_model_fetcher()->models_fetched()); |
| |
| EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches, |
| prediction_manager()->ShouldTargetNavigation( |
| navigation_handle.get(), |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)); |
| TestPredictionModel* test_prediction_model = |
| static_cast<TestPredictionModel*>( |
| prediction_manager()->GetPredictionModelForTesting( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)); |
| EXPECT_TRUE(test_prediction_model); |
| EXPECT_TRUE(test_prediction_model->WasModelEvaluated()); |
| |
| histogram_tester.ExpectTotalCount( |
| "OptimizationGuide.PredictionModelEvaluationLatency." + |
| GetStringNameForOptimizationTarget( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD), |
| 1); |
| |
| histogram_tester.ExpectUniqueSample( |
| "OptimizationGuide.IsPredictionModelValid." + |
| GetStringNameForOptimizationTarget( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD), |
| true, 1); |
| |
| histogram_tester.ExpectUniqueSample( |
| "OptimizationGuide.IsPredictionModelValid", true, 1); |
| |
| histogram_tester.ExpectTotalCount( |
| "OptimizationGuide.PredictionModelValidationLatency." + |
| GetStringNameForOptimizationTarget( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD), |
| 1); |
| histogram_tester.ExpectTotalCount( |
| "OptimizationGuide.PredictionModelValidationLatency", 1); |
| } |
| |
| TEST_F(PredictionManagerTest, UpdatePredictionModelsWithInvalidModel) { |
| base::HistogramTester histogram_tester; |
| CreatePredictionManager(); |
| prediction_manager()->SetPredictionModelFetcherForTesting( |
| BuildTestPredictionModelFetcher( |
| PredictionModelFetcherEndState::kFetchFailed)); |
| |
| RegisterOptimizationTargets( |
| {{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt}}); |
| |
| std::unique_ptr<proto::GetModelsResponse> get_models_response = |
| BuildGetModelsResponse(); |
| |
| // Override the manager so that any prediction model updates will be seen as |
| // invalid. |
| prediction_manager()->set_create_valid_prediction_model(false); |
| prediction_manager()->UpdatePredictionModelsForTesting( |
| get_models_response.get()); |
| |
| histogram_tester.ExpectBucketCount("OptimizationGuide.IsPredictionModelValid", |
| false, 1); |
| |
| histogram_tester.ExpectTotalCount( |
| "OptimizationGuide.PredictionModelValidationLatency", 0); |
| histogram_tester.ExpectTotalCount( |
| "OptimizationGuide.PredictionModelUpdateVersion.PainfulPageLoad", 1); |
| histogram_tester.ExpectTotalCount( |
| "OptimizationGuide.PredictionModelLoadedVersion.PainfulPageLoad", 0); |
| histogram_tester.ExpectUniqueSample( |
| "OptimizationGuide.PredictionModelRemoved.PainfulPageLoad", true, 1); |
| } |
| |
| TEST_F(PredictionManagerTest, UpdateModelWithSameVersion) { |
| base::HistogramTester histogram_tester; |
| CreatePredictionManager(); |
| prediction_manager()->SetPredictionModelFetcherForTesting( |
| BuildTestPredictionModelFetcher( |
| PredictionModelFetcherEndState::kFetchFailed)); |
| |
| RegisterOptimizationTargets( |
| {{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt}}); |
| |
| // Seed the PredictionManager with a prediction model with a higher version |
| // to try to be updated. |
| std::unique_ptr<proto::GetModelsResponse> get_models_response = |
| BuildGetModelsResponse(); |
| get_models_response->mutable_models(0)->mutable_model_info()->set_version(3); |
| |
| prediction_manager()->UpdatePredictionModelsForTesting( |
| get_models_response.get()); |
| histogram_tester.ExpectUniqueSample( |
| "OptimizationGuide.PredictionManager.PredictionModelsStored", true, 1); |
| histogram_tester.ExpectUniqueSample( |
| "OptimizationGuide.PredictionModelUpdateVersion.PainfulPageLoad", 3, 1); |
| histogram_tester.ExpectUniqueSample( |
| "OptimizationGuide.PredictionManager.ModelTypeChanged.PainfulPageLoad", |
| false, 1); |
| |
| get_models_response = BuildGetModelsResponse(); |
| |
| get_models_response->mutable_models(0)->mutable_model_info()->set_version(3); |
| prediction_manager()->UpdatePredictionModelsForTesting( |
| get_models_response.get()); |
| |
| TestPredictionModel* stored_prediction_model = |
| static_cast<TestPredictionModel*>( |
| prediction_manager()->GetPredictionModelForTesting( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)); |
| EXPECT_TRUE(stored_prediction_model); |
| EXPECT_EQ(3, stored_prediction_model->GetVersion()); |
| histogram_tester.ExpectBucketCount("OptimizationGuide.IsPredictionModelValid", |
| true, 2); |
| } |
| |
| TEST_F(PredictionManagerTest, UpdateModelFileWithSameVersion) { |
| base::HistogramTester histogram_tester; |
| |
| CreatePredictionManager(); |
| |
| FakeOptimizationTargetModelObserver observer; |
| prediction_manager()->AddObserverForOptimizationTargetModel( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, |
| /*model_metadata=*/absl::nullopt, &observer); |
| |
| proto::PredictionModel model; |
| model.mutable_model_info()->set_optimization_target( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD); |
| model.mutable_model_info()->set_version(3); |
| model.mutable_model()->set_download_url( |
| FilePathToString(temp_dir().AppendASCII("whatever2"))); |
| prediction_manager()->OnModelReady(model); |
| RunUntilIdle(); |
| |
| EXPECT_TRUE(observer |
| .last_received_model_for_target( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD) |
| .has_value()); |
| |
| // Now reset the observer state. |
| observer.Reset(); |
| |
| // Send the same model again. |
| prediction_manager()->OnModelReady(model); |
| |
| // The observer should not have received an update. |
| EXPECT_FALSE(observer |
| .last_received_model_for_target( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD) |
| .has_value()); |
| |
| histogram_tester.ExpectUniqueSample( |
| "OptimizationGuide.PredictionManager.ModelTypeChanged.PainfulPageLoad", |
| false, 1); |
| } |
| |
| TEST_F(PredictionManagerTest, DownloadManagerUnavailableShouldNotFetch) { |
| base::HistogramTester histogram_tester; |
| std::unique_ptr<content::MockNavigationHandle> navigation_handle = |
| CreateMockNavigationHandleWithOptimizationGuideWebContentsObserver( |
| GURL("https://foo.com")); |
| |
| CreatePredictionManager(); |
| prediction_manager()->SetPredictionModelFetcherForTesting( |
| BuildTestPredictionModelFetcher( |
| PredictionModelFetcherEndState::kFetchSuccessWithModelDownloadUrls)); |
| prediction_manager()->SetPredictionModelDownloadManagerForTesting( |
| std::make_unique<FakePredictionModelDownloadManager>( |
| task_environment()->GetMainThreadTaskRunner())); |
| prediction_model_download_manager()->SetAvailableForDownloads(false); |
| |
| RegisterOptimizationTargets( |
| {{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt}}); |
| |
| SetStoreInitialized(); |
| EXPECT_FALSE(prediction_model_fetcher()->models_fetched()); |
| |
| histogram_tester.ExpectUniqueSample( |
| "OptimizationGuide.PredictionManager." |
| "DownloadServiceAvailabilityBlockedFetch", |
| true, 1); |
| } |
| |
| TEST_F(PredictionManagerTest, UpdateModelWithDownloadUrl) { |
| base::HistogramTester histogram_tester; |
| std::unique_ptr<content::MockNavigationHandle> navigation_handle = |
| CreateMockNavigationHandleWithOptimizationGuideWebContentsObserver( |
| GURL("https://foo.com")); |
| |
| CreatePredictionManager(); |
| prediction_manager()->SetPredictionModelFetcherForTesting( |
| BuildTestPredictionModelFetcher( |
| PredictionModelFetcherEndState::kFetchSuccessWithModelDownloadUrls)); |
| prediction_manager()->SetPredictionModelDownloadManagerForTesting( |
| std::make_unique<FakePredictionModelDownloadManager>( |
| task_environment()->GetMainThreadTaskRunner())); |
| |
| RegisterOptimizationTargets( |
| {{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt}}); |
| |
| SetStoreInitialized(); |
| EXPECT_TRUE(prediction_model_fetcher()->models_fetched()); |
| EXPECT_TRUE(prediction_model_download_manager()->cancel_downloads_called()); |
| |
| models_and_features_store()->RunUpdateHostModelFeaturesCallback(); |
| histogram_tester.ExpectUniqueSample( |
| "OptimizationGuide.PredictionManager.HostModelFeaturesStored", true, 1); |
| histogram_tester.ExpectTotalCount( |
| "OptimizationGuide.PredictionManager.PredictionModelsStored", 0); |
| histogram_tester.ExpectUniqueSample( |
| "OptimizationGuide.PredictionManager." |
| "DownloadServiceAvailabilityBlockedFetch", |
| false, 1); |
| histogram_tester.ExpectUniqueSample( |
| "OptimizationGuide.PredictionManager.IsDownloadUrlValid", true, 1); |
| |
| EXPECT_EQ(prediction_model_download_manager()->last_requested_download(), |
| GURL("https://example.com/model")); |
| } |
| |
| TEST_F(PredictionManagerTest, ShouldTargetNavigationStoreAvailableNoModel) { |
| base::HistogramTester histogram_tester; |
| std::unique_ptr<content::MockNavigationHandle> navigation_handle = |
| CreateMockNavigationHandleWithOptimizationGuideWebContentsObserver( |
| GURL("https://foo.com")); |
| |
| CreatePredictionManager(); |
| prediction_manager()->SetPredictionModelFetcherForTesting( |
| BuildTestPredictionModelFetcher( |
| PredictionModelFetcherEndState::kFetchSuccessWithEmptyResponse)); |
| |
| RegisterOptimizationTargets( |
| {{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt}}); |
| |
| SetStoreInitialized(/* load_models= */ false, |
| /* load_host_model_features= */ true, |
| /* have_models_in_store= */ false); |
| |
| EXPECT_EQ(OptimizationTargetDecision::kModelNotAvailableOnClient, |
| prediction_manager()->ShouldTargetNavigation( |
| navigation_handle.get(), |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)); |
| histogram_tester.ExpectBucketCount( |
| "OptimizationGuide.ShouldTargetNavigation.PredictionModelStatus", |
| PredictionManagerModelStatus::kStoreAvailableNoModelForTarget, 1); |
| |
| histogram_tester.ExpectBucketCount( |
| "OptimizationGuide.ShouldTargetNavigation.PredictionModelStatus." + |
| GetStringNameForOptimizationTarget( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD), |
| PredictionManagerModelStatus::kStoreAvailableNoModelForTarget, 1); |
| } |
| |
| TEST_F(PredictionManagerTest, |
| ShouldTargetNavigationStoreAvailableModelNotLoaded) { |
| base::HistogramTester histogram_tester; |
| std::unique_ptr<content::MockNavigationHandle> navigation_handle = |
| CreateMockNavigationHandleWithOptimizationGuideWebContentsObserver( |
| GURL("https://foo.com")); |
| |
| CreatePredictionManager(); |
| prediction_manager()->SetPredictionModelFetcherForTesting( |
| BuildTestPredictionModelFetcher( |
| PredictionModelFetcherEndState::kFetchSuccessWithEmptyResponse)); |
| |
| RegisterOptimizationTargets( |
| {{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt}}); |
| |
| SetStoreInitialized(/* load_models= */ false, |
| /* load_host_model_features= */ true, |
| /* have_models_in_store= */ true); |
| |
| EXPECT_EQ(OptimizationTargetDecision::kModelNotAvailableOnClient, |
| prediction_manager()->ShouldTargetNavigation( |
| navigation_handle.get(), |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)); |
| histogram_tester.ExpectBucketCount( |
| "OptimizationGuide.ShouldTargetNavigation.PredictionModelStatus", |
| PredictionManagerModelStatus::kStoreAvailableModelNotLoaded, 1); |
| |
| histogram_tester.ExpectBucketCount( |
| "OptimizationGuide.ShouldTargetNavigation.PredictionModelStatus." + |
| GetStringNameForOptimizationTarget( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD), |
| PredictionManagerModelStatus::kStoreAvailableModelNotLoaded, 1); |
| |
| histogram_tester.ExpectTotalCount( |
| "OptimizationGuide.PredictionModelLoadedVersion.PainfulPageLoad", 0); |
| } |
| |
| TEST_F(PredictionManagerTest, |
| ShouldTargetNavigationStoreUnavailableModelUnknown) { |
| base::HistogramTester histogram_tester; |
| std::unique_ptr<content::MockNavigationHandle> navigation_handle = |
| CreateMockNavigationHandleWithOptimizationGuideWebContentsObserver( |
| GURL("https://foo.com")); |
| |
| CreatePredictionManager(); |
| prediction_manager()->SetPredictionModelFetcherForTesting( |
| BuildTestPredictionModelFetcher( |
| PredictionModelFetcherEndState::kFetchSuccessWithEmptyResponse)); |
| |
| RegisterOptimizationTargets( |
| {{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt}}); |
| |
| EXPECT_EQ(OptimizationTargetDecision::kModelNotAvailableOnClient, |
| prediction_manager()->ShouldTargetNavigation( |
| navigation_handle.get(), |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)); |
| |
| histogram_tester.ExpectBucketCount( |
| "OptimizationGuide.ShouldTargetNavigation.PredictionModelStatus", |
| PredictionManagerModelStatus::kStoreUnavailableModelUnknown, 1); |
| |
| histogram_tester.ExpectBucketCount( |
| "OptimizationGuide.ShouldTargetNavigation.PredictionModelStatus." + |
| GetStringNameForOptimizationTarget( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD), |
| PredictionManagerModelStatus::kStoreUnavailableModelUnknown, 1); |
| } |
| |
| TEST_F(PredictionManagerTest, UpdateModelForUnregisteredTarget) { |
| base::HistogramTester histogram_tester; |
| CreatePredictionManager(); |
| prediction_manager()->SetPredictionModelFetcherForTesting( |
| BuildTestPredictionModelFetcher( |
| PredictionModelFetcherEndState::kFetchSuccessWithModels)); |
| |
| RegisterOptimizationTargets({}); |
| SetStoreInitialized(); |
| |
| EXPECT_FALSE(prediction_model_fetcher()->models_fetched()); |
| |
| std::unique_ptr<proto::GetModelsResponse> get_models_response = |
| BuildGetModelsResponse(); |
| |
| prediction_manager()->UpdatePredictionModelsForTesting( |
| get_models_response.get()); |
| |
| TestPredictionModel* test_prediction_model = |
| static_cast<TestPredictionModel*>( |
| prediction_manager()->GetPredictionModelForTesting( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)); |
| EXPECT_FALSE(test_prediction_model); |
| |
| histogram_tester.ExpectTotalCount( |
| "OptimizationGuide.PredictionManager.PredictionModelsStored", 1); |
| histogram_tester.ExpectTotalCount( |
| "OptimizationGuide.PredictionManager.HostModelFeaturesStored", 0); |
| histogram_tester.ExpectTotalCount( |
| "OptimizationGuide.PredictionModelLoadedVersion.PainfulPageLoad", 0); |
| } |
| |
| TEST_F(PredictionManagerTest, UpdateModelForUnregisteredTargetOnModelReady) { |
| base::HistogramTester histogram_tester; |
| CreatePredictionManager(); |
| |
| RegisterOptimizationTargets({}); |
| SetStoreInitialized(); |
| |
| proto::PredictionModel model; |
| model.mutable_model_info()->set_optimization_target( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD); |
| model.mutable_model_info()->set_version(3); |
| model.mutable_model()->set_download_url( |
| FilePathToString(temp_dir().AppendASCII("whatever"))); |
| prediction_manager()->OnModelReady(model); |
| |
| histogram_tester.ExpectTotalCount( |
| "OptimizationGuide.PredictionManager.PredictionModelsStored", 1); |
| histogram_tester.ExpectTotalCount( |
| "OptimizationGuide.PredictionManager.HostModelFeaturesStored", 0); |
| histogram_tester.ExpectTotalCount( |
| "OptimizationGuide.PredictionModelLoadedVersion.PainfulPageLoad", 0); |
| } |
| |
| TEST_F(PredictionManagerTest, UpdateModelForRegisteredTargetButNowFile) { |
| std::unique_ptr<content::MockNavigationHandle> navigation_handle = |
| CreateMockNavigationHandleWithOptimizationGuideWebContentsObserver( |
| GURL("https://foo.com")); |
| |
| base::HistogramTester histogram_tester; |
| CreatePredictionManager(); |
| |
| RegisterOptimizationTargets( |
| {{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt}}); |
| SetStoreInitialized(); |
| histogram_tester.ExpectUniqueSample( |
| "OptimizationGuide.PredictionModelLoadedVersion.PainfulPageLoad", 1, 1); |
| |
| EXPECT_EQ(OptimizationTargetDecision::kPageLoadMatches, |
| prediction_manager()->ShouldTargetNavigation( |
| navigation_handle.get(), |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)); |
| |
| // Now, update the model to be a file. |
| proto::PredictionModel model; |
| model.mutable_model_info()->set_optimization_target( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD); |
| model.mutable_model_info()->set_version(3); |
| model.mutable_model()->set_download_url( |
| FilePathToString(temp_dir().AppendASCII("whatever"))); |
| prediction_manager()->OnModelReady(model); |
| |
| histogram_tester.ExpectTotalCount( |
| "OptimizationGuide.PredictionManager.PredictionModelsStored", 1); |
| histogram_tester.ExpectTotalCount( |
| "OptimizationGuide.PredictionManager.HostModelFeaturesStored", 0); |
| histogram_tester.ExpectBucketCount( |
| "OptimizationGuide.PredictionModelLoadedVersion.PainfulPageLoad", 3, 1); |
| histogram_tester.ExpectBucketCount( |
| "OptimizationGuide.PredictionManager.ModelTypeChanged.PainfulPageLoad", |
| true, 1); |
| |
| // Expect that the old decision tree should not be used. |
| EXPECT_EQ(OptimizationTargetDecision::kModelNotAvailableOnClient, |
| prediction_manager()->ShouldTargetNavigation( |
| navigation_handle.get(), |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)); |
| } |
| |
| TEST_F(PredictionManagerTest, UpdateModelWithUnsupportedOptimizationTarget) { |
| std::unique_ptr<content::MockNavigationHandle> navigation_handle = |
| CreateMockNavigationHandleWithOptimizationGuideWebContentsObserver( |
| GURL("https://foo.com")); |
| |
| CreatePredictionManager(); |
| prediction_manager()->SetPredictionModelFetcherForTesting( |
| BuildTestPredictionModelFetcher( |
| PredictionModelFetcherEndState::kFetchFailed)); |
| |
| RegisterOptimizationTargets( |
| {{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt}}); |
| |
| EXPECT_FALSE(prediction_model_fetcher()->models_fetched()); |
| EXPECT_FALSE(models_and_features_store()->WasModelLoaded()); |
| |
| std::unique_ptr<proto::GetModelsResponse> get_models_response = |
| BuildGetModelsResponse(); |
| get_models_response->mutable_models(0) |
| ->mutable_model_info() |
| ->clear_optimization_target(); |
| prediction_manager()->UpdatePredictionModelsForTesting( |
| get_models_response.get()); |
| |
| EXPECT_EQ(OptimizationTargetDecision::kModelNotAvailableOnClient, |
| prediction_manager()->ShouldTargetNavigation( |
| navigation_handle.get(), |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)); |
| |
| TestPredictionModel* test_prediction_model = |
| static_cast<TestPredictionModel*>( |
| prediction_manager()->GetPredictionModelForTesting( |
| proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD)); |
| EXPECT_FALSE(test_prediction_model); |
| EXPECT_FALSE(models_and_features_store()->WasModelLoaded()); |
| } |
| |
| TEST_F(PredictionManagerTest, |
| StoreInitializedAfterOptimizationTargetRegistered) { |
| base::HistogramTester histogram_tester; |
| CreatePredictionManager(); |
| // Ensure that the fetch does not cause any models or features to load. |
| prediction_manager()->SetPredictionModelFetcherForTesting( |
| BuildTestPredictionModelFetcher( |
| PredictionModelFetcherEndState::kFetchFailed)); |
| RegisterOptimizationTargets( |
| {{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt}}); |
| EXPECT_FALSE(models_and_features_store()->WasHostModelFeaturesLoaded()); |
| EXPECT_FALSE(models_and_features_store()->WasModelLoaded()); |
| EXPECT_FALSE(prediction_manager()->GetHostModelFeaturesForHost("foo.com")); |
| |
| SetStoreInitialized(); |
| EXPECT_TRUE(models_and_features_store()->WasHostModelFeaturesLoaded()); |
| EXPECT_TRUE(models_and_features_store()->WasModelLoaded()); |
| EXPECT_TRUE(prediction_manager()->GetHostModelFeaturesForHost("foo.com")); |
| |
| EXPECT_FALSE(prediction_model_fetcher()->models_fetched()); |
| histogram_tester.ExpectUniqueSample( |
| "OptimizationGuide.PredictionModelLoadedVersion.PainfulPageLoad", 1, 1); |
| } |
| |
| TEST_F(PredictionManagerTest, |
| StoreInitializedBeforeOptimizationTargetRegistered) { |
| base::HistogramTester histogram_tester; |
| CreatePredictionManager(); |
| // Ensure that the fetch does not cause any models or features to load. |
| prediction_manager()->SetPredictionModelFetcherForTesting( |
| BuildTestPredictionModelFetcher( |
| PredictionModelFetcherEndState::kFetchFailed)); |
| SetStoreInitialized(); |
| |
| EXPECT_FALSE(models_and_features_store()->WasHostModelFeaturesLoaded()); |
| EXPECT_FALSE(models_and_features_store()->WasModelLoaded()); |
| EXPECT_FALSE(prediction_manager()->GetHostModelFeaturesForHost("foo.com")); |
| RegisterOptimizationTargets( |
| {{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt}}); |
| RunUntilIdle(); |
| |
| EXPECT_TRUE(models_and_features_store()->WasHostModelFeaturesLoaded()); |
| EXPECT_TRUE(models_and_features_store()->WasModelLoaded()); |
| EXPECT_TRUE(prediction_manager()->GetHostModelFeaturesForHost("foo.com")); |
| |
| EXPECT_FALSE(prediction_model_fetcher()->models_fetched()); |
| histogram_tester.ExpectUniqueSample( |
| "OptimizationGuide.PredictionModelLoadedVersion.PainfulPageLoad", 1, 1); |
| } |
| |
| TEST_F(PredictionManagerTest, ModelFetcherTimerRetryDelay) { |
| base::CommandLine::ForCurrentProcess()->RemoveSwitch( |
| switches::kFetchModelsAndHostModelFeaturesOverrideTimer); |
| |
| CreatePredictionManager(); |
| prediction_manager()->SetPredictionModelFetcherForTesting( |
| BuildTestPredictionModelFetcher( |
| PredictionModelFetcherEndState::kFetchFailed)); |
| |
| RegisterOptimizationTargets( |
| {{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt}}); |
| |
| SetStoreInitialized(); |
| EXPECT_FALSE(prediction_model_fetcher()->models_fetched()); |
| |
| MoveClockForwardBy(base::Seconds(kTestFetchRetryDelaySecs)); |
| EXPECT_FALSE(prediction_model_fetcher()->models_fetched()); |
| |
| prediction_manager()->SetPredictionModelFetcherForTesting( |
| BuildTestPredictionModelFetcher( |
| PredictionModelFetcherEndState::kFetchSuccessWithModels)); |
| |
| MoveClockForwardBy(base::Seconds(kTestFetchRetryDelaySecs)); |
| EXPECT_TRUE(prediction_model_fetcher()->models_fetched()); |
| } |
| |
| TEST_F(PredictionManagerTest, ModelFetcherTimerFetchSucceeds) { |
| base::CommandLine::ForCurrentProcess()->RemoveSwitch( |
| switches::kFetchModelsAndHostModelFeaturesOverrideTimer); |
| |
| CreatePredictionManager(); |
| prediction_manager()->SetPredictionModelFetcherForTesting( |
| BuildTestPredictionModelFetcher( |
| PredictionModelFetcherEndState::kFetchSuccessWithModels)); |
| |
| g_browser_process->SetApplicationLocale("en-US"); |
| |
| RegisterOptimizationTargets( |
| {{proto::OPTIMIZATION_TARGET_PAINFUL_PAGE_LOAD, absl::nullopt}}); |
| |
| SetStoreInitialized(); |
| EXPECT_FALSE(prediction_model_fetcher()->models_fetched()); |
| MoveClockForwardBy(base::Seconds(kTestFetchRetryDelaySecs)); |
| EXPECT_TRUE(prediction_model_fetcher()->models_fetched()); |
| EXPECT_EQ("en-US", prediction_model_fetcher()->locale_requested()); |
| |
| // Reset the prediction model fetcher to detect when the next fetch occurs. |
| prediction_manager()->SetPredictionModelFetcherForTesting( |
| BuildTestPredictionModelFetcher( |
| PredictionModelFetcherEndState::kFetchSuccessWithModels)); |
| MoveClockForwardBy(base::Seconds(kTestFetchRetryDelaySecs)); |
| EXPECT_FALSE(prediction_model_fetcher()->models_fetched()); |
| MoveClockForwardBy(base::Seconds(kUpdateFetchModelAndFeaturesTimeSecs)); |
| EXPECT_TRUE(prediction_model_fetcher()->models_fetched()); |
| } |
| |
| } // namespace optimization_guide |