blob: b907e04d8e647db7370e63feb651ed1dfc7f0f7f [file] [log] [blame]
// Copyright 2023 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/segmentation_platform/segmentation_platform_service_factory.h"
#include <memory>
#include <string_view>
#include "base/functional/bind.h"
#include "base/memory/scoped_refptr.h"
#include "base/test/scoped_command_line.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/simple_test_clock.h"
#include "base/test/task_environment.h"
#include "chrome/browser/segmentation_platform/ukm_data_manager_test_utils.h"
#include "chrome/test/base/testing_profile.h"
#include "components/optimization_guide/core/optimization_guide_features.h"
#include "components/prefs/pref_change_registrar.h"
#include "components/prefs/pref_observer.h"
#include "components/prefs/pref_service.h"
#include "components/segmentation_platform/embedder/default_model/contextual_page_actions_model.h"
#include "components/segmentation_platform/embedder/default_model/most_visited_tiles_user.h"
#include "components/segmentation_platform/internal/constants.h"
#include "components/segmentation_platform/internal/database/client_result_prefs.h"
#include "components/segmentation_platform/internal/segmentation_ukm_helper.h"
#include "components/segmentation_platform/public/constants.h"
#include "components/segmentation_platform/public/features.h"
#include "components/segmentation_platform/public/prediction_options.h"
#include "components/segmentation_platform/public/result.h"
#include "components/segmentation_platform/public/segmentation_platform_service.h"
#include "components/segmentation_platform/public/service_proxy.h"
#include "components/ukm/test_ukm_recorder.h"
#include "content/public/test/browser_task_environment.h"
#include "services/metrics/public/cpp/ukm_builders.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace segmentation_platform {
namespace {
// Shorthand for the generated UKM builder of the Segmentation.ModelExecution
// event. Despite the "Recorder" suffix it names a builder type; this file only
// uses it for its entry-name and metric-name constants (kEntryName, kInput*Name,
// kActualResult*Name, kOptimizationTargetName).
using Segmentation_ModelExecutionUkmRecorder =
::ukm::builders::Segmentation_ModelExecution;
// ServiceProxy observer that runs `closure` the first time the service reports
// `initialized == true`. Any later notifications are ignored.
class WaitServiceInitializedObserver : public ServiceProxy::Observer {
 public:
  explicit WaitServiceInitializedObserver(base::OnceClosure closure)
      : closure_(std::move(closure)) {}

  // ServiceProxy::Observer:
  void OnServiceStatusChanged(bool initialized, int status_flags) override {
    // `closure_` is a OnceClosure and is consumed by the first Run(); a second
    // "initialized" notification must not try to run the now-null callback.
    if (initialized && !closure_.is_null()) {
      std::move(closure_).Run();
    }
  }

 private:
  base::OnceClosure closure_;
};
// Test fixture that creates a TestingProfile backed by the segmentation
// platform service and provides helpers to wait for service initialization,
// for client results cached in prefs, and for Segmentation.ModelExecution UKM
// entries.
class SegmentationPlatformServiceFactoryTest : public testing::Test {
 protected:
  SegmentationPlatformServiceFactoryTest()
      : task_environment_{base::test::TaskEnvironment::TimeSource::MOCK_TIME},
        test_utils_(std::make_unique<UkmDataManagerTestUtils>(&ukm_recorder_)) {
    // Append the segmentation platform's refresh-results and
    // disable-model-execution-delay switches for the whole test.
    scoped_command_line_.GetProcessCommandLine()->AppendSwitch(
        kSegmentationPlatformRefreshResultsSwitch);
    scoped_command_line_.GetProcessCommandLine()->AppendSwitch(
        kSegmentationPlatformDisableModelExecutionDelaySwitch);
  }

  ~SegmentationPlatformServiceFactoryTest() override = default;

  void SetUp() override { test_utils_->PreProfileInit({}); }

  void TearDown() override {
    task_environment_.RunUntilIdle();
    // Destroy the profile first: ProfileData's destructor calls into
    // `test_utils_`.
    profile_.reset();
    test_utils_.reset();
  }

  // Runs a first "session" until a client result for `segmentation_key` is
  // written to prefs, then recreates the profile seeded with those prefs so
  // the second session can serve the cached result.
  void InitServiceAndCacheResults(const std::string& segmentation_key) {
    InitService();
    WaitForClientResultPrefUpdate(segmentation_key);
    // A new TestingProfile does not inherit prefs from the previous one, so
    // capture the serialized client results now and hand them over explicitly.
    const std::string output = profile_->profile->GetPrefs()->GetString(
        kSegmentationClientResultPrefs);
    // TODO(b/297091996): Remove this when leak is fixed.
    task_environment_.RunUntilIdle();
    profile_.reset();
    // Recreate the profile; ProfileData writes `output` into the new profile's
    // prefs before the service is fetched for it.
    profile_ = std::make_unique<ProfileData>(test_utils_.get(), output);
    WaitForServiceInit();
    // TODO(b/297091996): Remove this when leak is fixed.
    task_environment_.RunUntilIdle();
  }

  // Enables the features the tested models depend on, creates a profile with
  // an empty cached result pref and waits for the service to initialize.
  // Initializes `scoped_feature_list_`, so tests call it once.
  void InitService() {
    scoped_feature_list_.InitWithFeaturesAndParameters(
        {{optimization_guide::features::kOptimizationTargetPrediction, {}},
         {features::kSegmentationPlatformFeature, {}},
         {features::kSegmentationPlatformUkmEngine, {}},
         {features::kContextualPageActionShareModel, {}},
         {features::kSegmentationPlatformTimeDelaySampling,
          {{"SamplingRate", "1"}}},
         {features::kSegmentationPlatformTabResumptionRanker, {}},
         {features::kSegmentationPlatformAndroidHomeModuleRanker, {}}},
        {});
    // Creating profile and initialising segmentation service.
    profile_ = std::make_unique<ProfileData>(test_utils_.get(), "");
    WaitForServiceInit();
    clock_.SetNow(base::Time::Now());
    // TODO(b/297091996): Remove this when leak is fixed.
    task_environment_.RunUntilIdle();
  }

  // Requests a classification result for `segmentation_key` and blocks until
  // the callback runs. Checks the status and, when `expected_labels` is set,
  // the ordered labels.
  void ExpectGetClassificationResult(
      const std::string& segmentation_key,
      const PredictionOptions& prediction_options,
      scoped_refptr<InputContext> input_context,
      PredictionStatus expected_status,
      std::optional<std::vector<std::string>> expected_labels) {
    base::RunLoop loop;
    profile_->service->GetClassificationResult(
        segmentation_key, prediction_options, input_context,
        base::BindOnce(
            &SegmentationPlatformServiceFactoryTest::OnGetClassificationResult,
            base::Unretained(this), loop.QuitClosure(), expected_status,
            expected_labels));
    loop.Run();
  }

  void OnGetClassificationResult(
      base::RepeatingClosure closure,
      PredictionStatus expected_status,
      std::optional<std::vector<std::string>> expected_labels,
      const ClassificationResult& actual_result) {
    EXPECT_EQ(actual_result.status, expected_status);
    if (expected_labels.has_value()) {
      EXPECT_EQ(actual_result.ordered_labels, expected_labels.value());
    }
    std::move(closure).Run();
  }

  // Requests an annotated numeric result for `segmentation_key` and blocks
  // until the callback runs, checking only the status.
  void ExpectGetAnnotatedNumericResult(
      const std::string& segmentation_key,
      const PredictionOptions& prediction_options,
      scoped_refptr<InputContext> input_context,
      PredictionStatus expected_status) {
    base::RunLoop loop;
    profile_->service->GetAnnotatedNumericResult(
        segmentation_key, prediction_options, input_context,
        base::BindOnce(&SegmentationPlatformServiceFactoryTest::
                           OnGetAnnotatedNumericResult,
                       base::Unretained(this), loop.QuitClosure(),
                       expected_status));
    loop.Run();
  }

  void OnGetAnnotatedNumericResult(
      base::RepeatingClosure closure,
      PredictionStatus expected_status,
      const AnnotatedNumericResult& actual_result) {
    // EXPECT (not ASSERT): a fatal assertion would return before `closure`
    // runs and leave the caller's RunLoop waiting until the test times out.
    EXPECT_EQ(expected_status, actual_result.status);
    std::move(closure).Run();
  }

  // Blocks until the current profile's service reports it is initialized.
  void WaitForServiceInit() {
    base::RunLoop wait_for_init;
    WaitServiceInitializedObserver wait_observer(wait_for_init.QuitClosure());
    profile_->service->GetServiceProxy()->AddObserver(&wait_observer);
    wait_for_init.Run();
    // Unregister before asserting: a failing ASSERT returns early, and the
    // stack-allocated observer must not stay registered once destroyed.
    profile_->service->GetServiceProxy()->RemoveObserver(&wait_observer);
    ASSERT_TRUE(profile_->service->IsPlatformInitialized());
  }

  // Returns whether the current profile's prefs hold a client result for
  // `segmentation_key`.
  bool HasClientResultPref(const std::string& segmentation_key) {
    ClientResultPrefs result_prefs(profile_->profile->GetPrefs());
    return result_prefs.ReadClientResultFromPrefs(segmentation_key) != nullptr;
  }

  void OnClientResultPrefUpdated(const std::string& segmentation_key) {
    if (!wait_for_pref_callback_.is_null() &&
        HasClientResultPref(segmentation_key)) {
      std::move(wait_for_pref_callback_).Run();
    }
  }

  // Blocks until a client result for `segmentation_key` is present in prefs,
  // returning immediately if it already is.
  void WaitForClientResultPrefUpdate(const std::string& segmentation_key) {
    if (HasClientResultPref(segmentation_key)) {
      return;
    }
    base::RunLoop wait_for_pref;
    wait_for_pref_callback_ = wait_for_pref.QuitClosure();
    pref_registrar_.Init(profile_->profile->GetPrefs());
    pref_registrar_.Add(
        kSegmentationClientResultPrefs,
        base::BindRepeating(
            &SegmentationPlatformServiceFactoryTest::OnClientResultPrefUpdated,
            base::Unretained(this), segmentation_key));
    wait_for_pref.Run();
    pref_registrar_.RemoveAll();
  }

  // Expects exactly one Segmentation.ModelExecution entry whose metric
  // `metric_names[i]` equals `expected_values[i]` for every i.
  void ExpectUkm(std::vector<std::string_view> metric_names,
                 std::vector<int64_t> expected_values) {
    // The loop below indexes both vectors in lockstep.
    ASSERT_EQ(metric_names.size(), expected_values.size());
    const auto& entries = test_recorder_.GetEntriesByName(
        Segmentation_ModelExecutionUkmRecorder::kEntryName);
    ASSERT_EQ(1u, entries.size());
    for (size_t i = 0; i < metric_names.size(); ++i) {
      test_recorder_.ExpectEntryMetric(entries[0], metric_names[i],
                                       expected_values[i]);
    }
  }

  void ExpectUkmCount(size_t count) {
    const auto& entries = test_recorder_.GetEntriesByName(
        Segmentation_ModelExecutionUkmRecorder::kEntryName);
    ASSERT_EQ(count, entries.size());
  }

  // Blocks until a Segmentation.ModelExecution entry is added while it is the
  // only such entry and its optimization target equals `segment_id`. The check
  // is driven by the recorder's on-add-entry callback.
  void WaitForUkmRecord(proto::SegmentId segment_id) {
    base::RunLoop run_loop;
    test_recorder()->SetOnAddEntryCallback(
        Segmentation_ModelExecutionUkmRecorder::kEntryName,
        base::BindRepeating(
            [](proto::SegmentId id, ukm::TestAutoSetUkmRecorder* test_recorder,
               base::OnceClosure loop) {
              const auto& entries = test_recorder->GetEntriesByName(
                  Segmentation_ModelExecutionUkmRecorder::kEntryName);
              if (entries.size() == 1u) {
                const int64_t* metric = test_recorder->GetEntryMetric(
                    entries[0], Segmentation_ModelExecutionUkmRecorder::
                                    kOptimizationTargetName);
                if (metric && *metric == id) {
                  std::move(loop).Run();
                }
              }
            },
            segment_id, base::Unretained(test_recorder()),
            run_loop.QuitClosure()));
    run_loop.Run();
  }

  // This only checks for training data ukm records. Model execution UKM
  // records are not collected for default model.
  // Waits for the record of `segment_id`, then checks its leading Input*
  // metrics against `inputs` and ActualResult* metrics against `outputs`.
  void WaitAndCheckUkmRecord(proto::SegmentId segment_id,
                             std::vector<int64_t> inputs,
                             std::vector<int64_t> outputs) {
    WaitForUkmRecord(segment_id);
    ExpectUkm({Segmentation_ModelExecutionUkmRecorder::kOptimizationTargetName},
              {segment_id});

    // Input metric names, in order. Append more if required.
    const std::vector<std::string_view> inputs_ukm_metrics = {
        Segmentation_ModelExecutionUkmRecorder::kInput0Name,
        Segmentation_ModelExecutionUkmRecorder::kInput1Name,
        Segmentation_ModelExecutionUkmRecorder::kInput2Name,
        Segmentation_ModelExecutionUkmRecorder::kInput3Name,
        Segmentation_ModelExecutionUkmRecorder::kInput4Name,
        Segmentation_ModelExecutionUkmRecorder::kInput5Name,
        Segmentation_ModelExecutionUkmRecorder::kInput6Name,
        Segmentation_ModelExecutionUkmRecorder::kInput7Name,
        Segmentation_ModelExecutionUkmRecorder::kInput8Name,
        Segmentation_ModelExecutionUkmRecorder::kInput9Name,
        Segmentation_ModelExecutionUkmRecorder::kInput10Name,
    };
    // Slicing past the end of the name table would be undefined behavior.
    ASSERT_LE(inputs.size(), inputs_ukm_metrics.size());
    if (!inputs.empty()) {
      std::vector<std::string_view> input_metric_names(
          inputs_ukm_metrics.begin(),
          inputs_ukm_metrics.begin() + inputs.size());
      ExpectUkm(std::move(input_metric_names), std::move(inputs));
    }

    // Output metric names, in order. Append more if required.
    const std::vector<std::string_view> outputs_ukm_metrics = {
        Segmentation_ModelExecutionUkmRecorder::kActualResultName,
        Segmentation_ModelExecutionUkmRecorder::kActualResult2Name,
        Segmentation_ModelExecutionUkmRecorder::kActualResult3Name,
        Segmentation_ModelExecutionUkmRecorder::kActualResult4Name,
        Segmentation_ModelExecutionUkmRecorder::kActualResult5Name,
        Segmentation_ModelExecutionUkmRecorder::kActualResult6Name};
    ASSERT_LE(outputs.size(), outputs_ukm_metrics.size());
    if (!outputs.empty()) {
      std::vector<std::string_view> output_metric_names(
          outputs_ukm_metrics.begin(),
          outputs_ukm_metrics.begin() + outputs.size());
      ExpectUkm(std::move(output_metric_names), std::move(outputs));
    }
  }

  // Owns a TestingProfile whose client-result pref is pre-seeded with
  // `result_pref`, plus the segmentation service fetched for that profile.
  struct ProfileData {
    explicit ProfileData(UkmDataManagerTestUtils* test_utils,
                         const std::string& result_pref)
        : test_utils(test_utils), profile(TestingProfile::Builder().Build()) {
      // Seed the pref before the service is created for this profile.
      profile->GetPrefs()->SetString(kSegmentationClientResultPrefs,
                                     result_pref);
      test_utils->SetupForProfile(profile.get());
      service =
          SegmentationPlatformServiceFactory::GetForProfile(profile.get());
    }
    ~ProfileData() { test_utils->WillDestroyProfile(profile.get()); }

    ProfileData(const ProfileData&) = delete;
    ProfileData& operator=(const ProfileData&) = delete;

    const raw_ptr<UkmDataManagerTestUtils> test_utils;
    std::unique_ptr<TestingProfile> profile;
    // Declared after `profile` so it is destroyed first and never dangles.
    raw_ptr<SegmentationPlatformService> service;
  };

  base::SimpleTestClock* clock() { return &clock_; }
  ukm::TestAutoSetUkmRecorder* test_recorder() { return &test_recorder_; }

  // NOTE(review): `clock_` is not passed to the service anywhere in this file,
  // so advancing it may not affect the service; confirm whether tests should
  // advance `task_environment_` time instead.
  base::SimpleTestClock clock_;
  content::BrowserTaskEnvironment task_environment_;
  ukm::TestAutoSetUkmRecorder test_recorder_;
  base::test::ScopedFeatureList scoped_feature_list_;
  base::test::ScopedCommandLine scoped_command_line_;
  ukm::TestUkmRecorder ukm_recorder_;
  std::unique_ptr<UkmDataManagerTestUtils> test_utils_;
  PrefChangeRegistrar pref_registrar_;
  base::OnceClosure wait_for_pref_callback_;
  std::unique_ptr<ProfileData> profile_;
};
TEST_F(SegmentationPlatformServiceFactoryTest, TestPasswordManagerUserSegment) {
  // Cache a result in one session, then expect the next session to serve it.
  InitServiceAndCacheResults(kPasswordManagerUserKey);

  PredictionOptions options;
  const std::vector<std::string> expected_labels{"Not_PasswordManagerUser"};
  ExpectGetClassificationResult(kPasswordManagerUserKey, options,
                                /*input_context=*/nullptr,
                                PredictionStatus::kSucceeded, expected_labels);
}
// Segmentation Ukm Engine is disabled on CrOS.
#if !BUILDFLAG(IS_CHROMEOS)
TEST_F(SegmentationPlatformServiceFactoryTest, TestSearchUserModel) {
  // Cache a result in one session, then expect the next session to serve it
  // with the "none" search-user label.
  InitServiceAndCacheResults(kSearchUserKey);

  PredictionOptions options;
  const std::vector<std::string> expected_labels{kSearchUserModelLabelNone};
  ExpectGetClassificationResult(kSearchUserKey, options,
                                /*input_context=*/nullptr,
                                PredictionStatus::kSucceeded, expected_labels);
}
#endif //! BUILDFLAG(IS_CHROMEOS)
TEST_F(SegmentationPlatformServiceFactoryTest, TestShoppingUserModel) {
  // Cache a result in one session, then expect the next session to serve the
  // legacy negative label.
  InitServiceAndCacheResults(kShoppingUserSegmentationKey);

  PredictionOptions options;
  const std::vector<std::string> expected_labels{kLegacyNegativeLabel};
  ExpectGetClassificationResult(kShoppingUserSegmentationKey, options,
                                /*input_context=*/nullptr,
                                PredictionStatus::kSucceeded, expected_labels);
}
TEST_F(SegmentationPlatformServiceFactoryTest, TestResumeHeavyUserModel) {
  // Cache a result in one session, then expect the next session to serve the
  // legacy negative label.
  InitServiceAndCacheResults(kResumeHeavyUserKey);

  PredictionOptions options;
  const std::vector<std::string> expected_labels{kLegacyNegativeLabel};
  ExpectGetClassificationResult(kResumeHeavyUserKey, options,
                                /*input_context=*/nullptr,
                                PredictionStatus::kSucceeded, expected_labels);
}
TEST_F(SegmentationPlatformServiceFactoryTest, TestLowUserEngagementModel) {
  // Cache a result in one session, then expect the next session to serve the
  // low-engagement label.
  InitServiceAndCacheResults(kChromeLowUserEngagementSegmentationKey);

  PredictionOptions options;
  const std::vector<std::string> expected_labels{
      kChromeLowUserEngagementUmaName};
  ExpectGetClassificationResult(kChromeLowUserEngagementSegmentationKey,
                                options, /*input_context=*/nullptr,
                                PredictionStatus::kSucceeded, expected_labels);
}
TEST_F(SegmentationPlatformServiceFactoryTest, TestCrossDeviceModel) {
  // Cache a result in one session, then expect the next session to report no
  // cross-device usage. (Already inside namespace segmentation_platform, so
  // the names below need no qualification.)
  InitServiceAndCacheResults(kCrossDeviceUserKey);

  PredictionOptions options;
  const std::vector<std::string> expected_labels{kNoCrossDeviceUsage};
  ExpectGetClassificationResult(kCrossDeviceUserKey, options,
                                /*input_context=*/nullptr,
                                PredictionStatus::kSucceeded, expected_labels);
}
TEST_F(SegmentationPlatformServiceFactoryTest, TestDeviceSwitcherModel) {
  InitService();

  // On-demand request whose input sets the device-info wait to 0 seconds.
  scoped_refptr<InputContext> context = base::MakeRefCounted<InputContext>();
  context->metadata_args.emplace("wait_for_device_info_in_seconds",
                                 processing::ProcessedValue(0));
  PredictionOptions options;
  options.on_demand_execution = true;

  ExpectGetClassificationResult(kDeviceSwitcherKey, options, context,
                                PredictionStatus::kSucceeded,
                                std::vector<std::string>{"NotSynced"});
}
// Segmentation Ukm Engine is disabled on CrOS.
#if !BUILDFLAG(IS_CHROMEOS)
TEST_F(SegmentationPlatformServiceFactoryTest, TabResupmtionRanker) {
  InitService();

  // Describe a single tab: empty session tag, tab id 1, google.com origin.
  auto context = base::MakeRefCounted<InputContext>();
  context->metadata_args.emplace("session_tag",
                                 processing::ProcessedValue(std::string("")));
  context->metadata_args.emplace("tab_id", processing::ProcessedValue(1));
  context->metadata_args.emplace(
      "origin", processing::ProcessedValue(GURL("https://www.google.com")));

  PredictionOptions options;
  options.on_demand_execution = true;
  ExpectGetAnnotatedNumericResult(kTabResumptionClassifierKey, options, context,
                                  PredictionStatus::kSucceeded);
}
#endif //! BUILDFLAG(IS_CHROMEOS)
#if BUILDFLAG(IS_ANDROID)
// Tests for models in android platform.
TEST_F(SegmentationPlatformServiceFactoryTest, TestDeviceTierSegment) {
  // Cache a result in one session, then expect the next session to serve it.
  InitServiceAndCacheResults(kDeviceTierKey);

  PredictionOptions options;
  // Only the status is checked; no expected labels are supplied.
  const std::optional<std::vector<std::string>> skip_label_check;
  ExpectGetClassificationResult(kDeviceTierKey, options,
                                /*input_context=*/nullptr,
                                PredictionStatus::kSucceeded, skip_label_check);
}
TEST_F(SegmentationPlatformServiceFactoryTest,
       TestTabletProductivityUserModel) {
  // Cache a result in one session, then expect the next session to serve the
  // "none" tablet-productivity label.
  InitServiceAndCacheResults(kTabletProductivityUserKey);

  PredictionOptions options;
  const std::vector<std::string> expected_labels{
      kTabletProductivityUserModelLabelNone};
  ExpectGetClassificationResult(kTabletProductivityUserKey, options,
                                /*input_context=*/nullptr,
                                PredictionStatus::kSucceeded, expected_labels);
}
TEST_F(SegmentationPlatformServiceFactoryTest, TestContextualPageActionsShare) {
  InitService();

  PredictionOptions options;
  options.on_demand_execution = true;

  // Inputs: price-tracking signal set to 1, reader-mode signal set to 0.
  scoped_refptr<InputContext> context = base::MakeRefCounted<InputContext>();
  context->metadata_args.emplace(kContextualPageActionModelInputPriceTracking,
                                 processing::ProcessedValue::FromFloat(1));
  context->metadata_args.emplace(kContextualPageActionModelInputReaderMode,
                                 processing::ProcessedValue::FromFloat(0));

  ExpectGetClassificationResult(
      kContextualPageActionsKey, options, context, PredictionStatus::kSucceeded,
      std::vector<std::string>{kContextualPageActionModelLabelPriceTracking});

  // NOTE(review): the fixture's SimpleTestClock is not visibly injected into
  // the service, so this advance may not be what triggers output collection;
  // confirm whether mock task-environment time should be advanced instead.
  clock()->Advance(base::Seconds(
      ContextualPageActionsModel::kShareOutputCollectionDelayInSec));

  // Training-data UKM: first input is the price-tracking value, the rest and
  // all six outputs are zero.
  const std::vector<int64_t> expected_inputs = {
      SegmentationUkmHelper::FloatToInt64(1.f), 0, 0, 0, 0, 0, 0, 0};
  const std::vector<int64_t> expected_outputs(6, 0);
  WaitAndCheckUkmRecord(
      proto::OPTIMIZATION_TARGET_CONTEXTUAL_PAGE_ACTION_PRICE_TRACKING,
      expected_inputs, expected_outputs);
}
TEST_F(SegmentationPlatformServiceFactoryTest, TestFrequentFeatureModel) {
  // Cache a result in one session, then expect the next session to serve the
  // legacy negative label.
  InitServiceAndCacheResults(kFrequentFeatureUserKey);

  PredictionOptions options;
  const std::vector<std::string> expected_labels(1, kLegacyNegativeLabel);
  ExpectGetClassificationResult(kFrequentFeatureUserKey, options,
                                /*input_context=*/nullptr,
                                PredictionStatus::kSucceeded, expected_labels);
}
TEST_F(SegmentationPlatformServiceFactoryTest, TestIntentionalUserModel) {
  // Cache a result in one session, then expect the next session to serve the
  // legacy negative label.
  InitServiceAndCacheResults(kIntentionalUserKey);

  PredictionOptions options;
  const std::vector<std::string> expected_labels{kLegacyNegativeLabel};
  ExpectGetClassificationResult(kIntentionalUserKey, options,
                                /*input_context=*/nullptr,
                                PredictionStatus::kSucceeded, expected_labels);
}
TEST_F(SegmentationPlatformServiceFactoryTest, TestPowerUserSegment) {
  // Cache a result in one session, then expect the next session to serve the
  // "None" label.
  InitServiceAndCacheResults(kPowerUserKey);

  PredictionOptions options;
  const std::vector<std::string> expected_labels(1, "None");
  ExpectGetClassificationResult(kPowerUserKey, options,
                                /*input_context=*/nullptr,
                                PredictionStatus::kSucceeded, expected_labels);
}
TEST_F(SegmentationPlatformServiceFactoryTest, MostVisitedTilesUser) {
  const std::string key =
      segmentation_platform::MostVisitedTilesUser::kMostVisitedTilesUserKey;

  // Cache a result in one session, then expect the next session to serve the
  // "None" label.
  InitServiceAndCacheResults(key);

  PredictionOptions options;
  const std::vector<std::string> expected_labels{"None"};
  ExpectGetClassificationResult(key, options, /*input_context=*/nullptr,
                                PredictionStatus::kSucceeded, expected_labels);
}
TEST_F(SegmentationPlatformServiceFactoryTest, TestFeedUserModel) {
  // Cache a result in one session, then expect the next session to serve the
  // legacy negative label.
  InitServiceAndCacheResults(kFeedUserSegmentationKey);

  PredictionOptions options;
  const std::vector<std::string> expected_labels{kLegacyNegativeLabel};
  ExpectGetClassificationResult(kFeedUserSegmentationKey, options,
                                /*input_context=*/nullptr,
                                PredictionStatus::kSucceeded, expected_labels);
}
TEST_F(SegmentationPlatformServiceFactoryTest, TestAndroidHomeModuleRanker) {
  InitService();

  // Every module freshness input is set to -1.
  scoped_refptr<InputContext> context = base::MakeRefCounted<InputContext>();
  context->metadata_args.emplace(kSingleTabFreshness,
                                 processing::ProcessedValue::FromFloat(-1));
  context->metadata_args.emplace(kPriceChangeFreshness,
                                 processing::ProcessedValue::FromFloat(-1));
  context->metadata_args.emplace(kTabResumptionForAndroidHomeFreshness,
                                 processing::ProcessedValue::FromFloat(-1));

  PredictionOptions options;
  options.on_demand_execution = true;

  const std::vector<std::string> expected_order{kPriceChange, kSingleTab,
                                                kTabResumptionForAndroidHome};
  ExpectGetClassificationResult(kAndroidHomeModuleRankerKey, options, context,
                                PredictionStatus::kSucceeded, expected_order);
}
#endif // BUILDFLAG(IS_ANDROID)
} // namespace
} // namespace segmentation_platform