blob: 31ded2ed7df3354e3c6bd93be9fd93de2f799e61 [file] [log] [blame]
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/selection/request_dispatcher.h"
#include <memory>
#include "base/memory/raw_ptr.h"
#include "base/metrics/user_metrics.h"
#include "base/run_loop.h"
#include "base/task/single_thread_task_executor.h"
#include "base/task/single_thread_task_runner.h"
#include "base/test/gmock_callback_support.h"
#include "base/test/simple_test_clock.h"
#include "base/test/task_environment.h"
#include "base/time/time.h"
#include "components/prefs/pref_registry_simple.h"
#include "components/prefs/testing_pref_service.h"
#include "components/segmentation_platform/internal/constants.h"
#include "components/segmentation_platform/internal/database/cached_result_writer.h"
#include "components/segmentation_platform/internal/database/config_holder.h"
#include "components/segmentation_platform/internal/database/signal_database.h"
#include "components/segmentation_platform/internal/database/signal_storage_config.h"
#include "components/segmentation_platform/internal/database/storage_service.h"
#include "components/segmentation_platform/internal/metadata/metadata_utils.h"
#include "components/segmentation_platform/internal/metadata/metadata_writer.h"
#include "components/segmentation_platform/internal/mock_ukm_data_manager.h"
#include "components/segmentation_platform/internal/post_processor/post_processing_test_utils.h"
#include "components/segmentation_platform/internal/selection/request_handler.h"
#include "components/segmentation_platform/internal/selection/segment_result_provider.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/prediction_options.h"
#include "components/segmentation_platform/public/proto/prediction_result.pb.h"
#include "components/segmentation_platform/public/result.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
using testing::_;
using testing::Invoke;
namespace segmentation_platform {
namespace {
// Test clients.
// Segmentation keys of the on-demand test clients registered by the fixture.
constexpr char kDeviceSwitcherClient[] = "device_switcher";
constexpr char kAdaptiveToolbarClient[] = "adaptive_toolbar";
// Labels produced by the binary-classifier prediction results in the tests.
constexpr char kTestLabel1[] = "test_label1";
constexpr char kTestLabel2[] = "test_label2";
// Builds a PredictionResult carrying one score (0.8) and an output config for
// a binary classifier created with AddOutputConfigForBinaryClassifier(0.5f,
// `label`, "unused"). 0.5 is presumably the decision threshold, so the 0.8
// score should classify as `label` -- confirm against MetadataWriter.
proto::PredictionResult CreatePredictionResultWithBinaryClassifier(
    const char* const label) {
  proto::PredictionResult binary_result;
  binary_result.add_result(0.8f);

  // MetadataWriter fills in a SegmentationModelMetadata, so the output config
  // is assembled there and then swapped into the result.
  proto::SegmentationModelMetadata scratch_metadata;
  MetadataWriter metadata_writer(&scratch_metadata);
  metadata_writer.AddOutputConfigForBinaryClassifier(0.5f, label, "unused");
  binary_result.mutable_output_config()->Swap(
      scratch_metadata.mutable_output_config());
  return binary_result;
}
// Builds a PredictionResult with two scores (0.8, 0.2) whose output config
// describes a generic predictor with outputs "output1" and "output2".
proto::PredictionResult CreatePredictionResultWithGenericPredictor() {
  proto::PredictionResult generic_result;
  generic_result.add_result(0.8f);
  generic_result.add_result(0.2f);

  // The output config is produced through MetadataWriter on a scratch
  // SegmentationModelMetadata and then swapped into the result.
  proto::SegmentationModelMetadata scratch_metadata;
  MetadataWriter metadata_writer(&scratch_metadata);
  metadata_writer.AddOutputConfigForGenericPredictor({"output1", "output2"});
  generic_result.mutable_output_config()->Swap(
      scratch_metadata.mutable_output_config());
  return generic_result;
}
class MockRequestHandler : public RequestHandler {
public:
MockRequestHandler() = default;
~MockRequestHandler() override = default;
MOCK_METHOD3(GetPredictionResult,
void(const PredictionOptions& prediction_options,
scoped_refptr<InputContext> input_context,
RawResultCallback callback));
};
class RequestDispatcherTest : public testing::Test {
public:
RequestDispatcherTest() = default;
~RequestDispatcherTest() override = default;
void SetUp() override {
base::SetRecordActionTaskRunner(
task_environment_.GetMainThreadTaskRunner());
std::vector<std::unique_ptr<Config>> configs;
configs.emplace_back(test_utils::CreateTestConfig(
kDeviceSwitcherClient,
SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_DEVICE_SWITCHER));
configs.back()->auto_execute_and_cache = false;
configs.emplace_back(test_utils::CreateTestConfig(
kAdaptiveToolbarClient,
SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_ADAPTIVE_TOOLBAR));
configs.back()->auto_execute_and_cache = false;
configs.emplace_back(test_utils::CreateTestConfig(
kShoppingUserSegmentationKey,
SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHOPPING_USER));
configs.back()->auto_execute_and_cache = true;
auto config_holder = std::make_unique<ConfigHolder>(std::move(configs));
prefs_.registry()->RegisterStringPref(kSegmentationClientResultPrefs,
std::string());
client_result_prefs_ = std::make_unique<ClientResultPrefs>(&prefs_);
auto cached_result_writer = std::make_unique<CachedResultWriter>(
client_result_prefs_.get(), &clock_);
storage_service_ = std::make_unique<StorageService>(
nullptr, nullptr, nullptr, nullptr, std::move(config_holder),
&ukm_data_manager_);
storage_service_->set_cached_result_writer_for_testing(
std::move(cached_result_writer));
request_dispatcher_ =
std::make_unique<RequestDispatcher>(storage_service_.get());
auto handler1 = std::make_unique<MockRequestHandler>();
request_handler1_ = handler1.get();
request_dispatcher_->set_request_handler_for_testing(kDeviceSwitcherClient,
std::move(handler1));
auto handler2 = std::make_unique<MockRequestHandler>();
request_handler2_ = handler2.get();
request_dispatcher_->set_request_handler_for_testing(kAdaptiveToolbarClient,
std::move(handler2));
auto handler3 = std::make_unique<MockRequestHandler>();
request_handler3_ = handler3.get();
request_dispatcher_->set_request_handler_for_testing(
kShoppingUserSegmentationKey, std::move(handler3));
}
void OnGetClassificationResult(base::RepeatingClosure closure,
const ClassificationResult& expected,
const ClassificationResult& actual) {
EXPECT_EQ(expected.ordered_labels, actual.ordered_labels);
EXPECT_EQ(expected.status, actual.status);
std::move(closure).Run();
}
void OnGetAnnotatedNumericResult(base::RepeatingClosure closure,
const AnnotatedNumericResult& expected,
const AnnotatedNumericResult& actual) {
EXPECT_EQ(expected.result.SerializeAsString(),
actual.result.SerializeAsString());
EXPECT_EQ(expected.status, actual.status);
std::move(closure).Run();
}
base::test::TaskEnvironment task_environment_{
base::test::TaskEnvironment::TimeSource::MOCK_TIME};
base::SimpleTestClock clock_;
TestingPrefServiceSimple prefs_;
std::unique_ptr<ClientResultPrefs> client_result_prefs_;
MockUkmDataManager ukm_data_manager_;
std::unique_ptr<StorageService> storage_service_;
raw_ptr<MockRequestHandler, DanglingUntriaged> request_handler1_ = nullptr;
raw_ptr<MockRequestHandler, DanglingUntriaged> request_handler2_ = nullptr;
raw_ptr<MockRequestHandler, DanglingUntriaged> request_handler3_ = nullptr;
std::unique_ptr<RequestDispatcher> request_dispatcher_;
};
TEST_F(RequestDispatcherTest, TestRequestQueuingWithInitFailure) {
  PredictionOptions on_demand_options;
  on_demand_options.on_demand_execution = true;
  EXPECT_EQ(0, request_dispatcher_->GetPendingActionCountForTesting());

  // Initialization is going to fail, so the handler must never be reached.
  EXPECT_CALL(*request_handler1_, GetPredictionResult(_, _, _)).Times(0);

  // Before the platform is initialized the request can only be queued.
  base::RunLoop run_loop;
  ClassificationResult expected_failure(PredictionStatus::kFailed);
  request_dispatcher_->GetClassificationResult(
      kDeviceSwitcherClient, on_demand_options, scoped_refptr<InputContext>(),
      base::BindOnce(&RequestDispatcherTest::OnGetClassificationResult,
                     base::Unretained(this), run_loop.QuitClosure(),
                     std::move(expected_failure)));
  EXPECT_EQ(1, request_dispatcher_->GetPendingActionCountForTesting());

  // A failed initialization flushes the queue; the queued callback is
  // expected to receive a kFailed result with no labels.
  std::map<std::string, std::unique_ptr<SegmentResultProvider>>
      no_result_providers;
  ExecutionService execution_service;
  request_dispatcher_->OnPlatformInitialized(false, &execution_service,
                                             std::move(no_result_providers));
  run_loop.Run();
  EXPECT_EQ(0, request_dispatcher_->GetPendingActionCountForTesting());
}
TEST_F(RequestDispatcherTest,
       TestRequestQueuingWithInitSuccessAndNoModelsLoading) {
  base::RunLoop device_switcher_loop;
  base::RunLoop adaptive_toolbar_loop;
  PredictionOptions on_demand_options;
  on_demand_options.on_demand_execution = true;
  EXPECT_EQ(0, request_dispatcher_->GetPendingActionCountForTesting());

  // Device switcher request; its handler replies with a kTestLabel1 result.
  RawResult device_switcher_raw(PredictionStatus::kSucceeded);
  device_switcher_raw.result =
      CreatePredictionResultWithBinaryClassifier(kTestLabel1);
  EXPECT_CALL(*request_handler1_, GetPredictionResult(_, _, _))
      .WillRepeatedly(Invoke([&device_switcher_raw](
                                 const PredictionOptions&,
                                 scoped_refptr<InputContext>,
                                 RawResultCallback callback) {
        std::move(callback).Run(device_switcher_raw);
      }));
  ClassificationResult device_switcher_expected(PredictionStatus::kSucceeded);
  device_switcher_expected.ordered_labels.emplace_back(kTestLabel1);
  request_dispatcher_->GetClassificationResult(
      kDeviceSwitcherClient, on_demand_options, scoped_refptr<InputContext>(),
      base::BindOnce(&RequestDispatcherTest::OnGetClassificationResult,
                     base::Unretained(this),
                     device_switcher_loop.QuitClosure(),
                     device_switcher_expected));
  EXPECT_EQ(1, request_dispatcher_->GetPendingActionCountForTesting());

  // Adaptive toolbar request; its handler replies with a kTestLabel2 result.
  RawResult adaptive_toolbar_raw(PredictionStatus::kSucceeded);
  adaptive_toolbar_raw.result =
      CreatePredictionResultWithBinaryClassifier(kTestLabel2);
  EXPECT_CALL(*request_handler2_, GetPredictionResult(_, _, _))
      .WillRepeatedly(Invoke([&adaptive_toolbar_raw](
                                 const PredictionOptions&,
                                 scoped_refptr<InputContext>,
                                 RawResultCallback callback) {
        std::move(callback).Run(adaptive_toolbar_raw);
      }));
  ClassificationResult adaptive_toolbar_expected(
      PredictionStatus::kSucceeded);
  adaptive_toolbar_expected.ordered_labels.emplace_back(kTestLabel2);
  request_dispatcher_->GetClassificationResult(
      kAdaptiveToolbarClient, on_demand_options, scoped_refptr<InputContext>(),
      base::BindOnce(&RequestDispatcherTest::OnGetClassificationResult,
                     base::Unretained(this),
                     adaptive_toolbar_loop.QuitClosure(),
                     adaptive_toolbar_expected));
  EXPECT_EQ(2, request_dispatcher_->GetPendingActionCountForTesting());

  // A successful init alone does not drain the queue: neither model has
  // reported as loaded yet.
  std::map<std::string, std::unique_ptr<SegmentResultProvider>>
      no_result_providers;
  ExecutionService execution_service;
  request_dispatcher_->OnPlatformInitialized(true, &execution_service,
                                             std::move(no_result_providers));
  EXPECT_EQ(2, request_dispatcher_->GetPendingActionCountForTesting());

  // Advancing mock time fires the dispatcher's timeout, which drains the
  // queue even though no model ever loaded.
  task_environment_.FastForwardBy(base::Seconds(1));
  EXPECT_EQ(0, request_dispatcher_->GetPendingActionCountForTesting());
  device_switcher_loop.Run();
  adaptive_toolbar_loop.Run();
}
// Requests queued before init stay queued after a successful init until the
// model of their segment is reported as updated; each model update releases
// only its own client's request.
TEST_F(RequestDispatcherTest,
       TestRequestQueuingWithInitSuccessAndAfterModelsLoaded) {
  base::RunLoop run_loop_1;
  base::RunLoop run_loop_2;
  PredictionOptions options;
  options.on_demand_execution = true;
  EXPECT_EQ(0, request_dispatcher_->GetPendingActionCountForTesting());
  // Request from client 1.
  RawResult raw_result1(PredictionStatus::kSucceeded);
  raw_result1.result = CreatePredictionResultWithBinaryClassifier(kTestLabel1);
  EXPECT_CALL(*request_handler1_, GetPredictionResult(_, _, _))
      .WillRepeatedly(
          Invoke([&raw_result1](const PredictionOptions& options,
                                scoped_refptr<InputContext> input_context,
                                RawResultCallback callback) {
            std::move(callback).Run(raw_result1);
          }));
  ClassificationResult result1(PredictionStatus::kSucceeded);
  result1.ordered_labels.emplace_back(kTestLabel1);
  request_dispatcher_->GetClassificationResult(
      kDeviceSwitcherClient, options, scoped_refptr<InputContext>(),
      base::BindOnce(&RequestDispatcherTest::OnGetClassificationResult,
                     base::Unretained(this), run_loop_1.QuitClosure(),
                     result1));
  EXPECT_EQ(1, request_dispatcher_->GetPendingActionCountForTesting());
  // Request from client 2.
  RawResult raw_result2(PredictionStatus::kSucceeded);
  raw_result2.result = CreatePredictionResultWithBinaryClassifier(kTestLabel2);
  EXPECT_CALL(*request_handler2_, GetPredictionResult(_, _, _))
      .WillRepeatedly(
          Invoke([&raw_result2](const PredictionOptions& options,
                                scoped_refptr<InputContext> input_context,
                                RawResultCallback callback) {
            std::move(callback).Run(raw_result2);
          }));
  ClassificationResult result2(PredictionStatus::kSucceeded);
  // Use the shared constant (not a duplicated literal) so the expectation
  // always matches the label the handler returns.
  result2.ordered_labels.emplace_back(kTestLabel2);
  request_dispatcher_->GetClassificationResult(
      kAdaptiveToolbarClient, options, scoped_refptr<InputContext>(),
      base::BindOnce(&RequestDispatcherTest::OnGetClassificationResult,
                     base::Unretained(this), run_loop_2.QuitClosure(),
                     result2));
  EXPECT_EQ(2, request_dispatcher_->GetPendingActionCountForTesting());
  // Finish platform initialization with success. The request queue is posted,
  // but no requests are dispatched because their models are still not yet
  // loaded.
  std::map<std::string, std::unique_ptr<SegmentResultProvider>>
      result_providers;
  ExecutionService execution_service;
  request_dispatcher_->OnPlatformInitialized(true, &execution_service,
                                             std::move(result_providers));
  // Initialize platform, no requests should be executed.
  EXPECT_EQ(2, request_dispatcher_->GetPendingActionCountForTesting());
  // Set the device switcher model as initialized. Its request should be
  // executed.
  request_dispatcher_->OnModelUpdated(
      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_DEVICE_SWITCHER);
  // The device switcher request should be dispatched and
  // the other one gets enqueued again.
  run_loop_1.Run();
  EXPECT_EQ(1, request_dispatcher_->GetPendingActionCountForTesting());
  // Set the adaptive toolbar model as initialized.
  request_dispatcher_->OnModelUpdated(
      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_ADAPTIVE_TOOLBAR);
  // The last request should be dispatched.
  run_loop_2.Run();
  EXPECT_EQ(0, request_dispatcher_->GetPendingActionCountForTesting());
}
// Once the platform is initialized and both models are loaded, requests are
// dispatched straight to the handlers without being queued.
TEST_F(RequestDispatcherTest, TestRequestAfterInitSuccessAndModelsLoaded) {
  // One RunLoop per request. With a single shared RunLoop, the first
  // callback's Quit() could end Run() before the second callback (and its
  // expectations on `result2`) ever ran, so client 2's result could go
  // unverified.
  base::RunLoop run_loop_1;
  base::RunLoop run_loop_2;
  PredictionOptions options;
  options.on_demand_execution = true;
  // Init platform.
  std::map<std::string, std::unique_ptr<SegmentResultProvider>>
      result_providers;
  ExecutionService execution_service;
  // Set platform as initialized.
  request_dispatcher_->OnPlatformInitialized(true, &execution_service,
                                             std::move(result_providers));
  // Set both models as initialized, now requests should be dispatched
  // immediately without queueing.
  request_dispatcher_->OnModelUpdated(
      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_DEVICE_SWITCHER);
  request_dispatcher_->OnModelUpdated(
      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_ADAPTIVE_TOOLBAR);
  // Request from client 1.
  RawResult raw_result1(PredictionStatus::kSucceeded);
  raw_result1.result = CreatePredictionResultWithBinaryClassifier(kTestLabel1);
  EXPECT_CALL(*request_handler1_, GetPredictionResult(_, _, _))
      .WillRepeatedly(
          Invoke([&raw_result1](const PredictionOptions& options,
                                scoped_refptr<InputContext> input_context,
                                RawResultCallback callback) {
            std::move(callback).Run(raw_result1);
          }));
  ClassificationResult result1(PredictionStatus::kSucceeded);
  result1.ordered_labels.emplace_back(kTestLabel1);
  request_dispatcher_->GetClassificationResult(
      kDeviceSwitcherClient, options, scoped_refptr<InputContext>(),
      base::BindOnce(&RequestDispatcherTest::OnGetClassificationResult,
                     base::Unretained(this), run_loop_1.QuitClosure(),
                     result1));
  EXPECT_EQ(0, request_dispatcher_->GetPendingActionCountForTesting());
  // Request from client 2.
  RawResult raw_result2(PredictionStatus::kSucceeded);
  raw_result2.result = CreatePredictionResultWithBinaryClassifier(kTestLabel2);
  EXPECT_CALL(*request_handler2_, GetPredictionResult(_, _, _))
      .WillRepeatedly(
          Invoke([&raw_result2](const PredictionOptions& options,
                                scoped_refptr<InputContext> input_context,
                                RawResultCallback callback) {
            std::move(callback).Run(raw_result2);
          }));
  ClassificationResult result2(PredictionStatus::kSucceeded);
  result2.ordered_labels.emplace_back(kTestLabel2);
  request_dispatcher_->GetClassificationResult(
      kAdaptiveToolbarClient, options, scoped_refptr<InputContext>(),
      base::BindOnce(&RequestDispatcherTest::OnGetClassificationResult,
                     base::Unretained(this), run_loop_2.QuitClosure(),
                     result2));
  EXPECT_EQ(0, request_dispatcher_->GetPendingActionCountForTesting());
  // Wait for both callbacks so both sets of expectations are checked.
  run_loop_1.Run();
  run_loop_2.Run();
  EXPECT_EQ(0, request_dispatcher_->GetPendingActionCountForTesting());
}
TEST_F(RequestDispatcherTest, TestAnnotatedNumericResultRequestWithWaiting) {
  base::RunLoop run_loop;
  PredictionOptions on_demand_options;
  on_demand_options.on_demand_execution = true;

  // The device switcher handler answers with a two-output generic predictor
  // result, which is also the exact value the caller expects back.
  RawResult generic_raw(PredictionStatus::kSucceeded);
  generic_raw.result = CreatePredictionResultWithGenericPredictor();
  EXPECT_CALL(*request_handler1_, GetPredictionResult(_, _, _))
      .WillRepeatedly(Invoke([&generic_raw](const PredictionOptions&,
                                            scoped_refptr<InputContext>,
                                            RawResultCallback callback) {
        std::move(callback).Run(generic_raw);
      }));
  request_dispatcher_->GetAnnotatedNumericResult(
      kDeviceSwitcherClient, on_demand_options, scoped_refptr<InputContext>(),
      base::BindOnce(&RequestDispatcherTest::OnGetAnnotatedNumericResult,
                     base::Unretained(this), run_loop.QuitClosure(),
                     generic_raw));
  // The platform is not initialized yet, so the request is queued.
  EXPECT_EQ(1, request_dispatcher_->GetPendingActionCountForTesting());

  // Initialize the platform without ever reporting a model update. The
  // request is presumably released by the dispatcher's wait timeout, which
  // Run() reaches because MOCK_TIME advances the clock when idle.
  std::map<std::string, std::unique_ptr<SegmentResultProvider>>
      no_result_providers;
  ExecutionService execution_service;
  request_dispatcher_->OnPlatformInitialized(true, &execution_service,
                                             std::move(no_result_providers));
  run_loop.Run();
  EXPECT_EQ(0, request_dispatcher_->GetPendingActionCountForTesting());
}
// An on-demand request whose execution fails falls back to the result cached
// in prefs for the same client.
TEST_F(RequestDispatcherTest, TestOnDemandWithFallback) {
  // Result available in client prefs. It is stored under the same key the
  // fixture registered the device switcher config with, so the cached result
  // provider (built from those configs) can find it.
  client_result_prefs_->SaveClientResultToPrefs(
      kDeviceSwitcherClient,
      metadata_utils::CreateClientResultFromPredResult(
          CreatePredictionResultWithBinaryClassifier(kTestLabel1),
          /*timestamp=*/base::Time::Now()));
  auto cached_result_provider = std::make_unique<CachedResultProvider>(
      client_result_prefs_.get(), storage_service_->config_holder()->configs());
  storage_service_->set_cached_result_provider_for_testing(
      std::move(cached_result_provider));
  base::RunLoop loop;
  PredictionOptions options = PredictionOptions::ForOnDemand(true);
  options.can_update_cache_for_future_requests = true;
  // Init platform.
  std::map<std::string, std::unique_ptr<SegmentResultProvider>>
      result_providers;
  ExecutionService execution_service;
  // Set platform as initialized.
  request_dispatcher_->OnPlatformInitialized(true, &execution_service,
                                             std::move(result_providers));
  // Mark the device switcher model as loaded so its requests are dispatched
  // immediately without queueing.
  request_dispatcher_->OnModelUpdated(
      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_DEVICE_SWITCHER);
  // On-demand execution fails, so the expected kTestLabel1 result must come
  // from the cached prefs entry saved above.
  RawResult raw_result(PredictionStatus::kFailed);
  EXPECT_CALL(*request_handler1_, GetPredictionResult(_, _, _))
      .WillOnce(Invoke([&raw_result](const PredictionOptions& options,
                                     scoped_refptr<InputContext> input_context,
                                     RawResultCallback callback) {
        std::move(callback).Run(raw_result);
      }));
  ClassificationResult result(PredictionStatus::kSucceeded);
  result.ordered_labels.emplace_back(kTestLabel1);
  request_dispatcher_->GetClassificationResult(
      kDeviceSwitcherClient, options, scoped_refptr<InputContext>(),
      base::BindOnce(&RequestDispatcherTest::OnGetClassificationResult,
                     base::Unretained(this), loop.QuitClosure(), result));
  loop.Run();
  EXPECT_EQ(0, request_dispatcher_->GetPendingActionCountForTesting());
}
// A cached-execution request for a client with a result already in prefs is
// answered from that cached result.
TEST_F(RequestDispatcherTest, TestCachedExecutionWithoutFallback) {
  // Result available in client prefs.
  client_result_prefs_->SaveClientResultToPrefs(
      kShoppingUserSegmentationKey,
      metadata_utils::CreateClientResultFromPredResult(
          CreatePredictionResultWithBinaryClassifier(kTestLabel1),
          /*timestamp=*/base::Time::Now()));
  auto cached_result_provider = std::make_unique<CachedResultProvider>(
      client_result_prefs_.get(), storage_service_->config_holder()->configs());
  storage_service_->set_cached_result_provider_for_testing(
      std::move(cached_result_provider));
  base::RunLoop loop;
  PredictionOptions options = PredictionOptions::ForCached(true);
  options.can_update_cache_for_future_requests = true;
  // Init platform.
  std::map<std::string, std::unique_ptr<SegmentResultProvider>>
      result_providers;
  ExecutionService execution_service;
  // Set platform as initialized.
  request_dispatcher_->OnPlatformInitialized(true, &execution_service,
                                             std::move(result_providers));
  // Mark the shopping user model as loaded so its requests are dispatched
  // immediately without queueing.
  request_dispatcher_->OnModelUpdated(
      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHOPPING_USER);
  // No expectation is set on request_handler3_: the cached kTestLabel1
  // result saved above is expected to satisfy the request.
  ClassificationResult result(PredictionStatus::kSucceeded);
  result.ordered_labels.emplace_back(kTestLabel1);
  request_dispatcher_->GetClassificationResult(
      kShoppingUserSegmentationKey, options, scoped_refptr<InputContext>(),
      base::BindOnce(&RequestDispatcherTest::OnGetClassificationResult,
                     base::Unretained(this), loop.QuitClosure(), result));
  loop.Run();
  EXPECT_EQ(0, request_dispatcher_->GetPendingActionCountForTesting());
}
// A cached-execution request for a client with nothing in prefs falls back to
// executing through the client's request handler.
TEST_F(RequestDispatcherTest, TestCachedExecutionWithFallback) {
  // Result not available in client prefs.
  auto cached_result_provider = std::make_unique<CachedResultProvider>(
      client_result_prefs_.get(), storage_service_->config_holder()->configs());
  storage_service_->set_cached_result_provider_for_testing(
      std::move(cached_result_provider));
  base::RunLoop loop;
  PredictionOptions options = PredictionOptions::ForCached(true);
  options.can_update_cache_for_future_requests = true;
  // Init platform.
  std::map<std::string, std::unique_ptr<SegmentResultProvider>>
      result_providers;
  ExecutionService execution_service;
  // Set platform as initialized.
  request_dispatcher_->OnPlatformInitialized(true, &execution_service,
                                             std::move(result_providers));
  // Mark the shopping user model as loaded so its requests are dispatched
  // immediately without queueing.
  request_dispatcher_->OnModelUpdated(
      SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_SHOPPING_USER);
  // With no cached result, the request is expected to be served by
  // request_handler3_, which replies with a kTestLabel1 result.
  RawResult raw_result(PredictionStatus::kSucceeded);
  raw_result.result = CreatePredictionResultWithBinaryClassifier(kTestLabel1);
  EXPECT_CALL(*request_handler3_, GetPredictionResult(_, _, _))
      .WillRepeatedly(
          Invoke([&raw_result](const PredictionOptions& options,
                               scoped_refptr<InputContext> input_context,
                               RawResultCallback callback) {
            std::move(callback).Run(raw_result);
          }));
  ClassificationResult result(PredictionStatus::kSucceeded);
  result.ordered_labels.emplace_back(kTestLabel1);
  request_dispatcher_->GetClassificationResult(
      kShoppingUserSegmentationKey, options, scoped_refptr<InputContext>(),
      base::BindOnce(&RequestDispatcherTest::OnGetClassificationResult,
                     base::Unretained(this), loop.QuitClosure(), result));
  loop.Run();
  EXPECT_EQ(0, request_dispatcher_->GetPendingActionCountForTesting());
}
} // namespace
} // namespace segmentation_platform