|  | // Copyright 2023 The Chromium Authors | 
|  | // Use of this source code is governed by a BSD-style license that can be | 
|  | // found in the LICENSE file. | 
|  |  | 
#include "components/omnibox/browser/on_device_tail_model_service.h"

#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "base/containers/flat_set.h"
#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/functional/bind.h"
#include "base/path_service.h"
#include "base/strings/string_util.h"
#include "base/test/task_environment.h"
#include "components/omnibox/browser/on_device_tail_model_executor.h"
#include "components/optimization_guide/core/test_model_info_builder.h"
#include "components/optimization_guide/core/test_optimization_guide_model_provider.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"

using ::testing::ElementsAreArray;
|  |  | 
namespace {

// File names of the checked-in test model and its vocabulary; the fixture
// resolves them relative to components/test/data/omnibox.
// (Unnamed-namespace members already have internal linkage, so no `static`.)
constexpr char kTailModelFilename[] = "test_tail_model.tflite";
constexpr char kVocabFilename[] = "vocab_test.txt";

// LSTM parameters written into the test model's metadata.
// NOTE(review): presumably these must match the checked-in test model — confirm.
constexpr int kNumLayer = 1;
constexpr int kStateSize = 512;
constexpr int kEmbeddingDim = 64;

}  // namespace
|  |  | 
|  | class OnDeviceTailModelServiceTest : public ::testing::Test { | 
|  | protected: | 
|  | void SetUp() override { | 
|  | test_model_provider_ = std::make_unique< | 
|  | optimization_guide::TestOptimizationGuideModelProvider>(); | 
|  | service_ = | 
|  | std::make_unique<OnDeviceTailModelService>(test_model_provider_.get()); | 
|  |  | 
|  | optimization_guide::proto::Any any_metadata; | 
|  | any_metadata.set_type_url( | 
|  | "type.googleapis.com/com.foo.OnDeviceTailSuggestModelMetadata"); | 
|  |  | 
|  | base::FilePath test_data_dir; | 
|  | base::PathService::Get(base::DIR_SRC_TEST_DATA_ROOT, &test_data_dir); | 
|  | test_data_dir = test_data_dir.AppendASCII("components/test/data/omnibox"); | 
|  |  | 
|  | base::flat_set<base::FilePath> additional_files; | 
|  | additional_files.insert(test_data_dir.AppendASCII(kVocabFilename)); | 
|  |  | 
|  | optimization_guide::proto::OnDeviceTailSuggestModelMetadata metadata; | 
|  | metadata.mutable_lstm_model_params()->set_num_layer(kNumLayer); | 
|  | metadata.mutable_lstm_model_params()->set_state_size(kStateSize); | 
|  | metadata.mutable_lstm_model_params()->set_embedding_dimension( | 
|  | kEmbeddingDim); | 
|  | metadata.SerializeToString(any_metadata.mutable_value()); | 
|  |  | 
|  | model_info_ = | 
|  | optimization_guide::TestModelInfoBuilder() | 
|  | .SetModelFilePath(test_data_dir.AppendASCII(kTailModelFilename)) | 
|  | .SetAdditionalFiles(additional_files) | 
|  | .SetVersion(123) | 
|  | .SetModelMetadata(any_metadata) | 
|  | .Build(); | 
|  |  | 
|  | task_environment_.RunUntilIdle(); | 
|  | } | 
|  |  | 
|  | void TearDown() override { | 
|  | service_ = nullptr; | 
|  | task_environment_.RunUntilIdle(); | 
|  | } | 
|  |  | 
|  | bool IsExecutorReady() const { | 
|  | return service_->tail_model_executor_->IsReady(); | 
|  | } | 
|  |  | 
|  | base::test::TaskEnvironment task_environment_; | 
|  | std::unique_ptr<OnDeviceTailModelService> service_; | 
|  | std::unique_ptr<optimization_guide::TestOptimizationGuideModelProvider> | 
|  | test_model_provider_; | 
|  | std::unique_ptr<optimization_guide::ModelInfo> model_info_; | 
|  | }; | 
|  |  | 
|  | TEST_F(OnDeviceTailModelServiceTest, OnModelUpdated) { | 
|  | service_->OnModelUpdated( | 
|  | optimization_guide::proto::OptimizationTarget:: | 
|  | OPTIMIZATION_TARGET_OMNIBOX_ON_DEVICE_TAIL_SUGGEST, | 
|  | *model_info_); | 
|  | task_environment_.RunUntilIdle(); | 
|  |  | 
|  | EXPECT_TRUE(IsExecutorReady()); | 
|  | } | 
|  |  | 
|  | TEST_F(OnDeviceTailModelServiceTest, GetPredictionsForInput) { | 
|  | std::vector<OnDeviceTailModelExecutor::Prediction> results; | 
|  |  | 
|  | OnDeviceTailModelExecutor::ModelInput input("faceb", "", 5, 20, 0.05); | 
|  | OnDeviceTailModelService::ResultCallback callback = base::BindOnce( | 
|  | [](std::vector<OnDeviceTailModelExecutor::Prediction>* results, | 
|  | std::vector<OnDeviceTailModelExecutor::Prediction> predictions) { | 
|  | *results = std::move(predictions); | 
|  | }, | 
|  | &results); | 
|  |  | 
|  | service_->OnModelUpdated( | 
|  | optimization_guide::proto::OptimizationTarget:: | 
|  | OPTIMIZATION_TARGET_OMNIBOX_ON_DEVICE_TAIL_SUGGEST, | 
|  | *model_info_); | 
|  | service_->GetPredictionsForInput(input, std::move(callback)); | 
|  |  | 
|  | task_environment_.RunUntilIdle(); | 
|  |  | 
|  | EXPECT_FALSE(results.empty()); | 
|  | EXPECT_TRUE(base::StartsWith(results[0].suggestion, "facebook", | 
|  | base::CompareCase::SENSITIVE)); | 
|  | } | 
|  |  | 
|  | TEST_F(OnDeviceTailModelServiceTest, NullModelUpdate) { | 
|  | service_->OnModelUpdated( | 
|  | optimization_guide::proto::OptimizationTarget:: | 
|  | OPTIMIZATION_TARGET_OMNIBOX_ON_DEVICE_TAIL_SUGGEST, | 
|  | *model_info_); | 
|  | task_environment_.RunUntilIdle(); | 
|  | EXPECT_TRUE(IsExecutorReady()); | 
|  |  | 
|  | // Null model update shoud disable the executor. | 
|  | service_->OnModelUpdated( | 
|  | optimization_guide::proto::OptimizationTarget:: | 
|  | OPTIMIZATION_TARGET_OMNIBOX_ON_DEVICE_TAIL_SUGGEST, | 
|  | std::nullopt); | 
|  | task_environment_.RunUntilIdle(); | 
|  | EXPECT_FALSE(IsExecutorReady()); | 
|  | } |