blob: def36f48704ded93aad6dc8f2ca27be472add7a7 [file] [log] [blame]
// Copyright 2022 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/selection/request_handler.h"
#include "base/memory/raw_ptr.h"
#include "base/metrics/user_metrics.h"
#include "base/run_loop.h"
#include "base/test/gmock_callback_support.h"
#include "base/test/simple_test_clock.h"
#include "base/test/task_environment.h"
#include "components/prefs/pref_registry_simple.h"
#include "components/prefs/testing_pref_service.h"
#include "components/segmentation_platform/internal/constants.h"
#include "components/segmentation_platform/internal/data_collection/training_data_collector.h"
#include "components/segmentation_platform/internal/database/signal_database.h"
#include "components/segmentation_platform/internal/database/signal_storage_config.h"
#include "components/segmentation_platform/internal/database/storage_service.h"
#include "components/segmentation_platform/internal/metadata/metadata_writer.h"
#include "components/segmentation_platform/internal/mock_ukm_data_manager.h"
#include "components/segmentation_platform/internal/post_processor/post_processing_test_utils.h"
#include "components/segmentation_platform/internal/selection/segment_result_provider.h"
#include "components/segmentation_platform/public/config.h"
#include "components/segmentation_platform/public/model_provider.h"
#include "components/segmentation_platform/public/prediction_options.h"
#include "components/segmentation_platform/public/proto/prediction_result.pb.h"
#include "components/segmentation_platform/public/result.h"
#include "components/segmentation_platform/public/trigger.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
using testing::_;
using testing::ElementsAre;
using testing::FloatNear;
using testing::Invoke;
using testing::Return;
namespace segmentation_platform {
namespace {
// Test Ids.
constexpr proto::SegmentId kSegmentId =
    proto::SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
// Client (segmentation) key used for every test config in this file. Kept as a
// char array rather than a std::string so the constant needs no dynamic
// initialization and no exit-time destructor.
constexpr char kTestClientKey[] = "test_client";
class MockResultProvider : public SegmentResultProvider {
public:
MOCK_METHOD1(GetSegmentResult,
void(std::unique_ptr<GetResultOptions> options));
};
class MockTrainingDataCollector : public TrainingDataCollector {
public:
MOCK_METHOD0(OnModelMetadataUpdated, void());
MOCK_METHOD0(OnServiceInitialized, void());
MOCK_METHOD0(ReportCollectedContinuousTrainingData, void());
MOCK_METHOD5(OnDecisionTime,
TrainingRequestId(proto::SegmentId id,
scoped_refptr<InputContext> input_context,
DecisionType type,
std::optional<ModelProvider::Request> inputs,
bool decision_result_update_trigger));
MOCK_METHOD4(CollectTrainingData,
void(SegmentId segment_id,
TrainingRequestId request_id,
const TrainingLabels& param,
SuccessCallback callback));
};
// Returns a PredictionResult containing the single score 0.8f, paired with a
// binary-classifier output config written by MetadataWriter using 0.5f and the
// labels "positive_label" / "negative_label".
proto::PredictionResult CreatePredictionResultWithBinaryClassifier() {
  proto::PredictionResult prediction;
  prediction.add_result(0.8f);

  proto::SegmentationModelMetadata metadata;
  MetadataWriter metadata_writer(&metadata);
  metadata_writer.AddOutputConfigForBinaryClassifier(
      0.5f, "positive_label", "negative_label");

  // Move the freshly written output config into the result without copying.
  auto* output_config = prediction.mutable_output_config();
  output_config->Swap(metadata.mutable_output_config());
  return prediction;
}
// Returns a PredictionResult with two scores (0.8f, 0.2f), paired with a
// generic-predictor output config for the labels "output1" and "output2".
proto::PredictionResult CreatePredictionResultWithGenericPredictor() {
  proto::SegmentationModelMetadata metadata;
  MetadataWriter metadata_writer(&metadata);
  metadata_writer.AddOutputConfigForGenericPredictor({"output1", "output2"});

  proto::PredictionResult generic_result;
  for (float score : {0.8f, 0.2f}) {
    generic_result.add_result(score);
  }
  // Transfer the output config built above into the returned result.
  auto* output_config = generic_result.mutable_output_config();
  output_config->Swap(metadata.mutable_output_config());
  return generic_result;
}
// Fixture that builds a RequestHandler on top of mocked dependencies so each
// test can script the SegmentResultProvider and TrainingDataCollector.
class RequestHandlerTest : public testing::Test {
 public:
  RequestHandlerTest() = default;
  ~RequestHandlerTest() override = default;

  // Wires up the object graph under test:
  // - `execution_service_` takes ownership of a MockTrainingDataCollector;
  // - a MockResultProvider is handed to RequestHandler::Create();
  // - a StorageService is created whose CachedResultWriter writes through
  //   ClientResultPrefs into `prefs_` using `clock_`.
  void SetUp() override {
    base::SetRecordActionTaskRunner(
        task_environment_.GetMainThreadTaskRunner());

    auto training_data_collector =
        std::make_unique<MockTrainingDataCollector>();
    training_data_collector_ = training_data_collector.get();
    execution_service_.set_training_data_collector_for_testing(
        std::move(training_data_collector));

    config_ = test_utils::CreateTestConfig(kTestClientKey, kSegmentId);
    auto provider = std::make_unique<MockResultProvider>();
    result_provider_ = provider.get();

    // The ConfigHolder gets its own config instance with
    // auto_execute_and_cache disabled; `config_` itself is left untouched.
    std::vector<std::unique_ptr<Config>> configs;
    configs.emplace_back(
        test_utils::CreateTestConfig(kTestClientKey, kSegmentId));
    configs.back()->auto_execute_and_cache = false;
    auto config_holder = std::make_unique<ConfigHolder>(std::move(configs));

    // The pref must be registered before ClientResultPrefs reads/writes it.
    prefs_.registry()->RegisterStringPref(kSegmentationClientResultPrefs,
                                          std::string());
    client_result_prefs_ = std::make_unique<ClientResultPrefs>(&prefs_);
    auto cached_result_writer = std::make_unique<CachedResultWriter>(
        client_result_prefs_.get(), &clock_);
    cached_result_writer_ = cached_result_writer.get();

    storage_service_ = std::make_unique<StorageService>(
        nullptr, nullptr, nullptr, nullptr, std::move(config_holder),
        &ukm_data_manager_);
    storage_service_->set_cached_result_writer_for_testing(
        std::move(cached_result_writer));

    request_handler_ =
        RequestHandler::Create(*(config_.get()), std::move(provider),
                               &execution_service_, storage_service_.get());
  }

  // Result callback shared by the tests: expects a successful prediction whose
  // first score is ~0.8 and whose request id is 15, then runs `closure`.
  // `closure` is always run, even on failure, so the caller's RunLoop quits.
  void OnGetPredictionResult(base::RepeatingClosure closure,
                             const RawResult& result) {
    EXPECT_EQ(result.status, PredictionStatus::kSucceeded);
    EXPECT_EQ(result.request_id, TrainingRequestId::FromUnsafeValue(15));
    // Check the size before indexing: result(0) on an empty repeated field is
    // out of bounds, which would not produce a clean test failure.
    EXPECT_GT(result.result.result_size(), 0);
    if (result.result.result_size() > 0) {
      EXPECT_NEAR(0.8, result.result.result(0), 0.001);
    }
    std::move(closure).Run();
  }

  // Member order matters for teardown: each raw_ptr observer below is declared
  // after the member that ends up owning its pointee, so the observer is
  // destroyed first.
  base::test::TaskEnvironment task_environment_{
      base::test::TaskEnvironment::TimeSource::MOCK_TIME};
  std::unique_ptr<Config> config_;
  base::SimpleTestClock clock_;
  TestingPrefServiceSimple prefs_;
  std::unique_ptr<ClientResultPrefs> client_result_prefs_;
  ExecutionService execution_service_;
  // Owned by `execution_service_`.
  raw_ptr<MockTrainingDataCollector> training_data_collector_ = nullptr;
  MockUkmDataManager ukm_data_manager_;
  std::unique_ptr<StorageService> storage_service_;
  // Owned by `storage_service_`.
  raw_ptr<CachedResultWriter> cached_result_writer_ = nullptr;
  std::unique_ptr<RequestHandler> request_handler_;
  // Moved into RequestHandler::Create(); presumably owned by
  // `request_handler_` — NOTE(review): confirm against RequestHandler.
  raw_ptr<MockResultProvider> result_provider_ = nullptr;
};
// On-demand execution: the provider is asked to ignore DB scores, the training
// data collector is notified with the model inputs, and, because
// `can_update_cache_for_future_requests` is set, the result is written to
// prefs.
TEST_F(RequestHandlerTest, GetPredictionResult) {
  PredictionOptions options;
  options.on_demand_execution = true;
  options.can_update_cache_for_future_requests = true;
  EXPECT_CALL(
      *training_data_collector_,
      OnDecisionTime(
          kSegmentId, _, proto::TrainingOutputs::TriggerConfig::ONDEMAND,
          std::make_optional(ModelProvider::Request{1, 2, 3}), false))
      .WillOnce(Return(TrainingRequestId::FromUnsafeValue(15)));
  EXPECT_CALL(*result_provider_, GetSegmentResult(_))
      .WillOnce(Invoke(
          [](std::unique_ptr<SegmentResultProvider::GetResultOptions> options) {
            EXPECT_TRUE(options->ignore_db_scores);
            EXPECT_EQ(options->segment_id, kSegmentId);
            auto result =
                std::make_unique<SegmentResultProvider::SegmentResult>(
                    SegmentResultProvider::ResultState::
                        kServerModelExecutionScoreUsed,
                    CreatePredictionResultWithBinaryClassifier(),
                    /*rank=*/2);
            result->model_inputs = {1, 2, 3};
            std::move(options->callback).Run(std::move(result));
          }));
  base::RunLoop loop;
  request_handler_->GetPredictionResult(
      options, scoped_refptr<InputContext>(),
      base::BindOnce(&RequestHandlerTest::OnGetPredictionResult,
                     base::Unretained(this), loop.QuitClosure()));
  loop.Run();

  // Check prefs is updated if `can_update_cache_for_future_requests` is set to
  // true.
  const proto::ClientResult* result_from_pref =
      client_result_prefs_->ReadClientResultFromPrefs(
          config_->segmentation_key);
  // Fail cleanly instead of dereferencing null if nothing was written.
  ASSERT_NE(result_from_pref, nullptr);
  EXPECT_EQ(CreatePredictionResultWithBinaryClassifier().SerializeAsString(),
            result_from_pref->client_result().SerializeAsString());
}
// With on-demand execution off but fallback allowed, the expectations below
// require the provider to be queried with DB scores ignored and the decision
// to be reported to the training data collector as ONDEMAND.
TEST_F(RequestHandlerTest, ExecuteOndemandAsFallbackCase) {
  PredictionOptions prediction_options;
  prediction_options.on_demand_execution = false;
  prediction_options.fallback_allowed = true;

  const ModelProvider::Request expected_inputs = {1, 2, 3};
  EXPECT_CALL(*training_data_collector_,
              OnDecisionTime(kSegmentId, _,
                             proto::TrainingOutputs::TriggerConfig::ONDEMAND,
                             std::make_optional(expected_inputs), false))
      .WillOnce(Return(TrainingRequestId::FromUnsafeValue(15)));

  // Answers the provider request with a binary-classifier result computed
  // from inputs {1, 2, 3}.
  auto serve_binary_classifier_result =
      [](std::unique_ptr<SegmentResultProvider::GetResultOptions> request) {
        EXPECT_TRUE(request->ignore_db_scores);
        EXPECT_EQ(request->segment_id, kSegmentId);
        auto segment_result =
            std::make_unique<SegmentResultProvider::SegmentResult>(
                SegmentResultProvider::ResultState::
                    kServerModelExecutionScoreUsed,
                CreatePredictionResultWithBinaryClassifier(),
                /*rank=*/2);
        segment_result->model_inputs = {1, 2, 3};
        std::move(request->callback).Run(std::move(segment_result));
      };
  EXPECT_CALL(*result_provider_, GetSegmentResult(_))
      .WillOnce(Invoke(serve_binary_classifier_result));

  base::RunLoop run_loop;
  auto on_result =
      base::BindOnce(&RequestHandlerTest::OnGetPredictionResult,
                     base::Unretained(this), run_loop.QuitClosure());
  request_handler_->GetPredictionResult(
      prediction_options, scoped_refptr<InputContext>(), std::move(on_result));
  run_loop.Run();
}
// On-demand execution of a generic-predictor model with cache updates
// disabled: the first reported score (0.8) and request id are checked by
// OnGetPredictionResult().
TEST_F(RequestHandlerTest, GetGenericPredictionResult) {
  PredictionOptions prediction_options;
  prediction_options.can_update_cache_for_future_requests = false;
  prediction_options.on_demand_execution = true;

  const ModelProvider::Request expected_inputs = {1};
  EXPECT_CALL(*training_data_collector_,
              OnDecisionTime(kSegmentId, _,
                             proto::TrainingOutputs::TriggerConfig::ONDEMAND,
                             std::make_optional(expected_inputs), false))
      .WillOnce(Return(TrainingRequestId::FromUnsafeValue(15)));

  // Answers the provider request with a generic-predictor result computed
  // from the single input {1}.
  auto serve_generic_predictor_result =
      [](std::unique_ptr<SegmentResultProvider::GetResultOptions> request) {
        EXPECT_TRUE(request->ignore_db_scores);
        EXPECT_EQ(request->segment_id, kSegmentId);
        auto segment_result =
            std::make_unique<SegmentResultProvider::SegmentResult>(
                SegmentResultProvider::ResultState::
                    kServerModelExecutionScoreUsed,
                CreatePredictionResultWithGenericPredictor(),
                /*rank=*/2);
        segment_result->model_inputs = {1};
        std::move(request->callback).Run(std::move(segment_result));
      };
  EXPECT_CALL(*result_provider_, GetSegmentResult(_))
      .WillOnce(Invoke(serve_generic_predictor_result));

  base::RunLoop run_loop;
  auto on_result =
      base::BindOnce(&RequestHandlerTest::OnGetPredictionResult,
                     base::Unretained(this), run_loop.QuitClosure());
  request_handler_->GetPredictionResult(
      prediction_options, scoped_refptr<InputContext>(), std::move(on_result));
  run_loop.Run();
}
} // namespace
} // namespace segmentation_platform