| // Copyright 2024 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
#include "chrome/browser/ui/android/toolbar/adaptive_toolbar_bridge.h"

#include <memory>
#include <utility>
#include <vector>

#include "base/functional/bind.h"
#include "base/run_loop.h"
#include "base/task/single_thread_task_runner.h"
#include "chrome/browser/segmentation_platform/segmentation_platform_service_factory.h"
#include "chrome/test/base/chrome_test_utils.h"
#include "chrome/test/base/testing_profile.h"
#include "components/segmentation_platform/public/constants.h"
#include "components/segmentation_platform/public/result.h"
#include "components/segmentation_platform/public/testing/mock_segmentation_platform_service.h"
#include "content/public/test/browser_task_environment.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
| |
| using segmentation_platform::MockSegmentationPlatformService; |
| using testing::_; |
| |
| namespace adaptive_toolbar { |
| |
| class AdaptiveToolbarBridgeTest : public ::testing::Test { |
| public: |
| AdaptiveToolbarBridgeTest(const AdaptiveToolbarBridgeTest&) = delete; |
| AdaptiveToolbarBridgeTest& operator=(const AdaptiveToolbarBridgeTest&) = |
| delete; |
| |
| ~AdaptiveToolbarBridgeTest() override = default; |
| |
| void SetUp() override { |
| TestingProfile::Builder builder; |
| builder.AddTestingFactory( |
| segmentation_platform::SegmentationPlatformServiceFactory:: |
| GetInstance(), |
| base::BindRepeating([](content::BrowserContext* context) |
| -> std::unique_ptr<KeyedService> { |
| return std::make_unique<MockSegmentationPlatformService>(); |
| })); |
| |
| profile_ = builder.Build(); |
| } |
| |
| void TearDown() override { |
| // Clear default actions for safe teardown. |
| testing::Mock::VerifyAndClear(&GetSegmentationPlatformService()); |
| } |
| |
| AdaptiveToolbarBridgeTest() = default; |
| |
| MockSegmentationPlatformService& GetSegmentationPlatformService() { |
| return *static_cast<MockSegmentationPlatformService*>( |
| segmentation_platform::SegmentationPlatformServiceFactory:: |
| GetForProfile(profile_.get())); |
| } |
| |
| void BridgeCallback(bool is_ready, std::vector<int> ranked_buttons) { |
| callback_buttons_ = ranked_buttons; |
| callback_is_ready_ = is_ready; |
| run_loop_.Quit(); |
| } |
| |
| protected: |
| // Needed for TestingProfile::Builder. |
| content::BrowserTaskEnvironment task_environment_; |
| std::unique_ptr<TestingProfile> profile_; |
| |
| std::vector<int> callback_buttons_; |
| bool callback_is_ready_; |
| base::RunLoop run_loop_; |
| }; |
| |
| TEST_F(AdaptiveToolbarBridgeTest, GetRankedButtons) { |
| Profile* profile = profile_.get(); |
| |
| ON_CALL(GetSegmentationPlatformService(), GetClassificationResult(_, _, _, _)) |
| .WillByDefault(testing::WithArg<3>(testing::Invoke( |
| [](segmentation_platform::ClassificationResultCallback callback) { |
| auto result = segmentation_platform::ClassificationResult( |
| segmentation_platform::PredictionStatus::kSucceeded); |
| // Set segmentation to return a sorted list of labels. |
| result.ordered_labels = { |
| segmentation_platform::kAdaptiveToolbarModelLabelShare, |
| segmentation_platform::kAdaptiveToolbarModelLabelAddToBookmarks, |
| segmentation_platform::kAdaptiveToolbarModelLabelTranslate}; |
| base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, base::BindOnce(std::move(callback), result)); |
| }))); |
| |
| adaptive_toolbar::GetRankedSessionVariantButtons( |
| profile, /* use_raw_results= */ false, |
| base::BindOnce(&AdaptiveToolbarBridgeTest::BridgeCallback, |
| base::Unretained(this))); |
| run_loop_.Run(); |
| |
| EXPECT_TRUE(callback_is_ready_); |
| // The returned enum values should match the order of the segmentation result. |
| std::vector<int> expected_buttons = { |
| static_cast<int>(AdaptiveToolbarButtonVariant::kShare), |
| static_cast<int>(AdaptiveToolbarButtonVariant::kAddToBookmarks), |
| static_cast<int>(AdaptiveToolbarButtonVariant::kTranslate)}; |
| EXPECT_EQ(callback_buttons_, expected_buttons); |
| } |
| |
| TEST_F(AdaptiveToolbarBridgeTest, GetRankedButtons_NotReady) { |
| Profile* profile = profile_.get(); |
| |
| ON_CALL(GetSegmentationPlatformService(), GetClassificationResult(_, _, _, _)) |
| .WillByDefault(testing::WithArg<3>(testing::Invoke( |
| [](segmentation_platform::ClassificationResultCallback callback) { |
| // Set segmentation to return kNotReady. |
| auto result = segmentation_platform::ClassificationResult( |
| segmentation_platform::PredictionStatus::kNotReady); |
| base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, base::BindOnce(std::move(callback), result)); |
| }))); |
| |
| adaptive_toolbar::GetRankedSessionVariantButtons( |
| profile, /* use_raw_results= */ false, |
| base::BindOnce(&AdaptiveToolbarBridgeTest::BridgeCallback, |
| base::Unretained(this))); |
| run_loop_.Run(); |
| |
| EXPECT_FALSE(callback_is_ready_); |
| std::vector<int> expected_buttons = { |
| static_cast<int>(AdaptiveToolbarButtonVariant::kUnknown), |
| }; |
| EXPECT_EQ(callback_buttons_, expected_buttons); |
| } |
| |
| TEST_F(AdaptiveToolbarBridgeTest, GetRankedButtons_RawResults) { |
| Profile* profile = profile_.get(); |
| |
| ON_CALL(GetSegmentationPlatformService(), |
| GetAnnotatedNumericResult(_, _, _, _)) |
| .WillByDefault(testing::WithArg<3>(testing::Invoke( |
| [](segmentation_platform::AnnotatedNumericResultCallback callback) { |
| segmentation_platform::AnnotatedNumericResult result( |
| segmentation_platform::PredictionStatus::kSucceeded); |
| // Set segmentation to result an annotated numeric result, this |
| // includes a list of labels on the model's config and a list of |
| // scores. Both lists are parallel, the first score belongs to the |
| // first label. |
| auto* result_classifier = result.result.mutable_output_config() |
| ->mutable_predictor() |
| ->mutable_multi_class_classifier(); |
| |
| result_classifier->add_class_labels( |
| segmentation_platform::kAdaptiveToolbarModelLabelNewTab); |
| result_classifier->add_class_labels( |
| segmentation_platform::kAdaptiveToolbarModelLabelVoice); |
| result_classifier->add_class_labels( |
| segmentation_platform::kAdaptiveToolbarModelLabelShare); |
| result_classifier->add_class_labels( |
| segmentation_platform::kAdaptiveToolbarModelLabelTranslate); |
| result_classifier->add_class_labels( |
| segmentation_platform:: |
| kAdaptiveToolbarModelLabelAddToBookmarks); |
| |
| result.result.add_result(0.2); |
| result.result.add_result(0.3); |
| result.result.add_result(0.7); |
| result.result.add_result(0.9); |
| result.result.add_result(0.4); |
| |
| base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, base::BindOnce(std::move(callback), result)); |
| }))); |
| |
| adaptive_toolbar::GetRankedSessionVariantButtons( |
| profile, /* use_raw_results= */ true, |
| base::BindOnce(&AdaptiveToolbarBridgeTest::BridgeCallback, |
| base::Unretained(this))); |
| run_loop_.Run(); |
| |
| EXPECT_TRUE(callback_is_ready_); |
| // The returned enum values should be sorted by score. |
| std::vector<int> expected_buttons = { |
| static_cast<int>(AdaptiveToolbarButtonVariant::kTranslate), |
| static_cast<int>(AdaptiveToolbarButtonVariant::kShare), |
| static_cast<int>(AdaptiveToolbarButtonVariant::kAddToBookmarks), |
| static_cast<int>(AdaptiveToolbarButtonVariant::kVoice), |
| static_cast<int>(AdaptiveToolbarButtonVariant::kNewTab), |
| }; |
| EXPECT_EQ(callback_buttons_, expected_buttons); |
| } |
| |
| TEST_F(AdaptiveToolbarBridgeTest, GetRankedButtons_RawResults_NotReady) { |
| Profile* profile = profile_.get(); |
| |
| ON_CALL(GetSegmentationPlatformService(), |
| GetAnnotatedNumericResult(_, _, _, _)) |
| .WillByDefault(testing::WithArg<3>(testing::Invoke( |
| [](segmentation_platform::AnnotatedNumericResultCallback callback) { |
| segmentation_platform::AnnotatedNumericResult result( |
| segmentation_platform::PredictionStatus::kNotReady); |
| |
| base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask( |
| FROM_HERE, base::BindOnce(std::move(callback), result)); |
| }))); |
| |
| adaptive_toolbar::GetRankedSessionVariantButtons( |
| profile, /* use_raw_results= */ true, |
| base::BindOnce(&AdaptiveToolbarBridgeTest::BridgeCallback, |
| base::Unretained(this))); |
| run_loop_.Run(); |
| |
| EXPECT_FALSE(callback_is_ready_); |
| std::vector<int> expected_buttons = { |
| static_cast<int>(AdaptiveToolbarButtonVariant::kUnknown), |
| }; |
| EXPECT_EQ(callback_buttons_, expected_buttons); |
| } |
| |
| } // namespace adaptive_toolbar |