blob: 0c1263d53426e38388145895b97f49f7478acdda [file] [log] [blame]
// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/segmentation_platform/internal/execution/model_execution_manager_impl.h"
#include <memory>
#include "base/functional/bind.h"
#include "base/functional/callback_forward.h"
#include "base/functional/callback_helpers.h"
#include "base/memory/raw_ptr.h"
#include "base/metrics/metrics_hashes.h"
#include "base/run_loop.h"
#include "base/test/gmock_callback_support.h"
#include "base/test/gtest_util.h"
#include "base/test/mock_callback.h"
#include "base/test/simple_test_clock.h"
#include "base/test/task_environment.h"
#include "base/time/time.h"
#include "components/segmentation_platform/internal/database/mock_signal_database.h"
#include "components/segmentation_platform/internal/database/signal_database.h"
#include "components/segmentation_platform/internal/database/test_segment_info_database.h"
#include "components/segmentation_platform/internal/execution/mock_model_provider.h"
#include "components/segmentation_platform/internal/execution/model_execution_manager.h"
#include "components/segmentation_platform/internal/execution/model_execution_status.h"
#include "components/segmentation_platform/internal/execution/processing/feature_list_query_processor.h"
#include "components/segmentation_platform/internal/metadata/metadata_utils.h"
#include "components/segmentation_platform/public/model_provider.h"
#include "components/segmentation_platform/public/proto/aggregation.pb.h"
#include "components/segmentation_platform/public/proto/model_metadata.pb.h"
#include "components/segmentation_platform/public/proto/segmentation_platform.pb.h"
#include "components/segmentation_platform/public/proto/types.pb.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
using ::base::test::RunOnceCallback;
using testing::_;
using testing::Invoke;
using testing::Return;
using testing::SaveArg;
using testing::SetArgReferee;
namespace segmentation_platform {
namespace {
const int64_t kModelVersion = 123;
using Sample = SignalDatabase::Sample;
class MockSegmentInfoDatabase : public test::TestSegmentInfoDatabase {
public:
MOCK_METHOD(void, Initialize, (SuccessCallback callback), (override));
MOCK_METHOD(void,
GetSegmentInfoForSegments,
(const base::flat_set<SegmentId>& segment_ids,
MultipleSegmentInfoCallback callback),
(override));
MOCK_METHOD(void,
GetSegmentInfo,
(SegmentId segment_id,
proto::ModelSource model_source,
SegmentInfoCallback callback),
(override));
MOCK_METHOD(void,
UpdateSegment,
(SegmentId segment_id,
absl::optional<proto::SegmentInfo> segment_info,
SuccessCallback callback),
(override));
MOCK_METHOD(void,
SaveSegmentResult,
(SegmentId segment_id,
ModelSource model_source,
absl::optional<proto::PredictionResult> result,
SuccessCallback callback),
(override));
};
} // namespace
class ModelExecutionManagerTest : public testing::Test {
public:
ModelExecutionManagerTest()
: model_provider_factory_(&model_provider_data_) {}
~ModelExecutionManagerTest() override = default;
void SetUp() override {
segment_database_ = std::make_unique<test::TestSegmentInfoDatabase>();
signal_database_ = std::make_unique<MockSignalDatabase>();
clock_.SetNow(base::Time::Now());
}
void TearDown() override {
model_execution_manager_.reset();
// Allow for the SegmentationModelExecutor owned by ModelProvider
// to be destroyed.
RunUntilIdle();
}
void CreateModelExecutionManager(
const base::flat_set<SegmentId>& segment_ids,
const ModelExecutionManager::SegmentationModelUpdatedCallback& callback) {
model_execution_manager_ = std::make_unique<ModelExecutionManagerImpl>(
segment_ids, &model_provider_factory_, &clock_, segment_database_.get(),
callback);
}
void RunUntilIdle() { task_environment_.RunUntilIdle(); }
MockModelProvider& FindHandler(proto::SegmentId segment_id) {
return *(*model_provider_data_.model_providers.find(segment_id)).second;
}
protected:
base::test::TaskEnvironment task_environment_;
TestModelProviderFactory::Data model_provider_data_;
TestModelProviderFactory model_provider_factory_;
base::SimpleTestClock clock_;
std::unique_ptr<test::TestSegmentInfoDatabase> segment_database_;
std::unique_ptr<MockSignalDatabase> signal_database_;
std::unique_ptr<ModelExecutionManagerImpl> model_execution_manager_;
};
TEST_F(ModelExecutionManagerTest, OnSegmentationModelUpdatedInvalidMetadata) {
// Use a MockSegmentInfoDatabase for this test in particular, to verify that
// it is never used.
auto mock_segment_database = std::make_unique<MockSegmentInfoDatabase>();
auto* mock_segment_database_ptr = mock_segment_database.get();
segment_database_ = std::move(mock_segment_database);
// Construct the ModelExecutionManager.
base::MockCallback<ModelExecutionManager::SegmentationModelUpdatedCallback>
callback;
auto segment_id = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
CreateModelExecutionManager({segment_id}, callback.Get());
// Create invalid metadata, which should be ignored.
proto::SegmentInfo segment_info;
proto::SegmentationModelMetadata metadata;
metadata.set_time_unit(proto::UNKNOWN_TIME_UNIT);
// Verify that the ModelExecutionManager never invokes its
// SegmentInfoDatabase, nor invokes the callback.
EXPECT_CALL(*mock_segment_database_ptr, GetSegmentInfo(_, _, _)).Times(0);
EXPECT_CALL(callback, Run(_)).Times(0);
model_provider_data_.model_providers_callbacks[segment_id].Run(
segment_id, metadata, kModelVersion);
}
TEST_F(ModelExecutionManagerTest, OnSegmentationModelUpdatedNoOldMetadata) {
base::MockCallback<ModelExecutionManager::SegmentationModelUpdatedCallback>
callback;
auto segment_id = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
CreateModelExecutionManager({segment_id}, callback.Get());
proto::SegmentInfo segment_info;
proto::SegmentationModelMetadata metadata;
metadata.set_bucket_duration(42u);
metadata.set_time_unit(proto::TimeUnit::DAY);
EXPECT_CALL(callback, Run(_)).WillOnce(SaveArg<0>(&segment_info));
model_provider_data_.model_providers_callbacks[segment_id].Run(
segment_id, metadata, kModelVersion);
// Verify that the resulting callback was invoked correctly.
EXPECT_EQ(segment_id, segment_info.segment_id());
EXPECT_EQ(42u, segment_info.model_metadata().bucket_duration());
EXPECT_EQ(proto::ModelSource::SERVER_MODEL_SOURCE,
segment_info.model_source());
// Also verify that the database has been updated.
base::MockCallback<SegmentInfoDatabase::SegmentInfoCallback> db_callback;
absl::optional<proto::SegmentInfo> segment_info_from_db;
EXPECT_CALL(db_callback, Run(_)).WillOnce(SaveArg<0>(&segment_info_from_db));
// Fetch SegmentInfo from the database.
segment_database_->GetSegmentInfo(
segment_id, proto::ModelSource::SERVER_MODEL_SOURCE, db_callback.Get());
EXPECT_TRUE(segment_info_from_db.has_value());
EXPECT_EQ(segment_id, segment_info_from_db->segment_id());
// The metadata should have been stored.
EXPECT_EQ(42u, segment_info_from_db->model_metadata().bucket_duration());
// Model update time should be updated.
EXPECT_EQ(clock_.Now().ToDeltaSinceWindowsEpoch().InSeconds(),
segment_info_from_db->model_update_time_s());
}
TEST_F(ModelExecutionManagerTest,
OnSegmentationModelUpdatedWithPreviousMetadataAndPredictionResult) {
base::MockCallback<ModelExecutionManager::SegmentationModelUpdatedCallback>
callback;
auto segment_id = SegmentId::OPTIMIZATION_TARGET_SEGMENTATION_NEW_TAB;
CreateModelExecutionManager({segment_id}, callback.Get());
// Fill in old data in the SegmentInfo database.
segment_database_->SetBucketDuration(segment_id, 456, proto::TimeUnit::MONTH);
segment_database_->AddUserActionFeature(segment_id, "hello", 2, 2,
proto::Aggregation::BUCKETED_COUNT);
segment_database_->AddPredictionResult(segment_id, 2, clock_.Now());
base::MockCallback<SegmentInfoDatabase::SegmentInfoCallback> db_callback_1;
absl::optional<proto::SegmentInfo> segment_info_from_db_1;
EXPECT_CALL(db_callback_1, Run(_))
.WillOnce(SaveArg<0>(&segment_info_from_db_1));
segment_database_->GetSegmentInfo(
segment_id, proto::ModelSource::SERVER_MODEL_SOURCE, db_callback_1.Get());
EXPECT_TRUE(segment_info_from_db_1.has_value());
EXPECT_EQ(segment_id, segment_info_from_db_1->segment_id());
// Verify the old metadata and prediction result has been stored correctly.
EXPECT_EQ(456u, segment_info_from_db_1->model_metadata().bucket_duration());
EXPECT_THAT(segment_info_from_db_1->prediction_result().result(),
testing::ElementsAre(2));
EXPECT_FALSE(segment_info_from_db_1->has_model_source());
// Verify the metadata features have been stored correctly.
EXPECT_EQ(proto::SignalType::USER_ACTION,
segment_info_from_db_1->model_metadata()
.input_features(0)
.uma_feature()
.type());
EXPECT_EQ("hello", segment_info_from_db_1->model_metadata()
.input_features(0)
.uma_feature()
.name());
EXPECT_EQ(proto::Aggregation::BUCKETED_COUNT,
segment_info_from_db_1->model_metadata()
.input_features(0)
.uma_feature()
.aggregation());
// Create segment info that does not match.
proto::SegmentInfo segment_info;
proto::SegmentationModelMetadata metadata;
metadata.set_bucket_duration(42u);
metadata.set_time_unit(proto::TimeUnit::HOUR);
// Create one feature that does not match the stored feature.
auto* feature = metadata.add_features();
feature->set_type(proto::SignalType::HISTOGRAM_VALUE);
feature->set_name("other");
// Intentionally not set the name hash, as it should be set automatically.
feature->set_aggregation(proto::Aggregation::BUCKETED_SUM);
feature->set_bucket_count(3);
feature->set_tensor_length(3);
// Invoke the callback and store the resulting invocation of the outer
// callback for verification.
EXPECT_CALL(callback, Run(_)).WillOnce(SaveArg<0>(&segment_info));
model_provider_data_.model_providers_callbacks[segment_id].Run(
segment_id, metadata, kModelVersion);
// Should now have the metadata from the new proto.
EXPECT_EQ(segment_id, segment_info.segment_id());
EXPECT_EQ(42u, segment_info.model_metadata().bucket_duration());
EXPECT_EQ(proto::SignalType::HISTOGRAM_VALUE,
segment_info.model_metadata().features(0).type());
EXPECT_EQ("other", segment_info.model_metadata().features(0).name());
// The name_hash should have been set automatically.
EXPECT_EQ(base::HashMetricName("other"),
segment_info.model_metadata().features(0).name_hash());
EXPECT_EQ(proto::Aggregation::BUCKETED_SUM,
segment_info.model_metadata().features(0).aggregation());
EXPECT_THAT(segment_info.prediction_result().result(),
testing::ElementsAre(2));
EXPECT_EQ(clock_.Now().ToDeltaSinceWindowsEpoch().InMicroseconds(),
segment_info.prediction_result().timestamp_us());
// Also verify that the database has been updated.
base::MockCallback<SegmentInfoDatabase::SegmentInfoCallback> db_callback_2;
absl::optional<proto::SegmentInfo> segment_info_from_db_2;
EXPECT_CALL(db_callback_2, Run(_))
.WillOnce(SaveArg<0>(&segment_info_from_db_2));
segment_database_->GetSegmentInfo(
segment_id, proto::ModelSource::SERVER_MODEL_SOURCE, db_callback_2.Get());
EXPECT_TRUE(segment_info_from_db_2.has_value());
EXPECT_EQ(segment_id, segment_info_from_db_2->segment_id());
EXPECT_EQ(clock_.Now().ToDeltaSinceWindowsEpoch().InSeconds(),
segment_info_from_db_2->model_update_time_s());
// The metadata should have been updated.
EXPECT_EQ(42u, segment_info_from_db_2->model_metadata().bucket_duration());
// The metadata features should have been updated.
EXPECT_EQ(proto::SignalType::HISTOGRAM_VALUE,
segment_info_from_db_2->model_metadata().features(0).type());
EXPECT_EQ("other",
segment_info_from_db_2->model_metadata().features(0).name());
EXPECT_EQ(base::HashMetricName("other"),
segment_info_from_db_2->model_metadata().features(0).name_hash());
EXPECT_EQ(proto::Aggregation::BUCKETED_SUM,
segment_info_from_db_2->model_metadata().features(0).aggregation());
// We shuold have kept the prediction result.
EXPECT_THAT(segment_info.prediction_result().result(),
testing::ElementsAre(2));
}
} // namespace segmentation_platform