| // Copyright 2024 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "services/passage_embeddings/passage_embeddings_service.h" |
| |
| #include "base/path_service.h" |
| #include "base/test/bind.h" |
| #include "base/test/metrics/histogram_tester.h" |
| #include "base/test/task_environment.h" |
| #include "base/test/test_future.h" |
| #include "mojo/public/cpp/bindings/remote.h" |
| #include "services/passage_embeddings/passage_embedder.h" |
| #include "testing/gtest/include/gtest/gtest.h" |
| |
| namespace passage_embeddings { |
| namespace { |
| |
// Input window size passed to LoadModels by every test except the one that
// deliberately passes 0.
constexpr uint32_t kInputWindowSize{256};
// Number of values each generated embedding is expected to contain.
constexpr size_t kEmbeddingsOutputSize{768};
| |
| class PassageEmbeddingsServiceTest : public testing::Test { |
| public: |
| PassageEmbeddingsServiceTest() |
| : service_impl_(service_.BindNewPipeAndPassReceiver()) {} |
| |
| mojo::Remote<mojom::PassageEmbeddingsService>& service() { return service_; } |
| |
| void SetUp() override { |
| base::FilePath test_data_dir; |
| base::PathService::Get(base::DIR_SRC_TEST_DATA_ROOT, &test_data_dir); |
| test_data_dir = test_data_dir.AppendASCII("services") |
| .AppendASCII("test") |
| .AppendASCII("data") |
| .AppendASCII("passage_embeddings"); |
| embeddings_path_ = |
| test_data_dir.AppendASCII("dummy_embeddings_model.tflite"); |
| sp_path_ = test_data_dir.AppendASCII("sentencepiece.model"); |
| } |
| |
| mojom::PassageEmbeddingsLoadModelsParamsPtr MakeModelParams( |
| base::FilePath embeddings_path, |
| base::FilePath sp_path, |
| uint32_t input_window_size) { |
| auto params = mojom::PassageEmbeddingsLoadModelsParams::New(); |
| params->embeddings_model = base::File( |
| embeddings_path, base::File::FLAG_OPEN | base::File::FLAG_READ); |
| params->sp_model = |
| base::File(sp_path, base::File::FLAG_OPEN | base::File::FLAG_READ); |
| params->input_window_size = input_window_size; |
| return params; |
| } |
| |
| mojom::PassageEmbedderParamsPtr MakeEmbedderParams() { |
| auto params = mojom::PassageEmbedderParams::New(); |
| params->user_initiated_priority_num_threads = 4; |
| params->passive_priority_num_threads = 1; |
| params->embedder_cache_size = 1000; |
| return params; |
| } |
| |
| protected: |
| base::FilePath embeddings_path_; |
| base::FilePath sp_path_; |
| base::HistogramTester histogram_tester_; |
| |
| private: |
| base::test::TaskEnvironment task_environment_; |
| mojo::Remote<mojom::PassageEmbeddingsService> service_; |
| PassageEmbeddingsService service_impl_; |
| }; |
| |
| TEST_F(PassageEmbeddingsServiceTest, LoadValidModels) { |
| mojo::Remote<mojom::PassageEmbedder> embedder_remote; |
| base::test::TestFuture<bool> future; |
| service()->LoadModels( |
| MakeModelParams(embeddings_path_, sp_path_, kInputWindowSize), |
| MakeEmbedderParams(), embedder_remote.BindNewPipeAndPassReceiver(), |
| future.GetCallback()); |
| bool load_models_success = future.Get(); |
| EXPECT_TRUE(load_models_success); |
| } |
| |
| TEST_F(PassageEmbeddingsServiceTest, LoadModelsWithInvalidEmbeddingsModel) { |
| mojo::Remote<mojom::PassageEmbedder> embedder_remote; |
| base::test::TestFuture<bool> load_models_future; |
| service()->LoadModels(MakeModelParams(sp_path_, sp_path_, kInputWindowSize), |
| MakeEmbedderParams(), |
| embedder_remote.BindNewPipeAndPassReceiver(), |
| load_models_future.GetCallback()); |
| bool load_models_success = load_models_future.Get(); |
| // LoadModels succeeds since the model file can still be read. |
| EXPECT_TRUE(load_models_success); |
| |
| base::test::TestFuture<std::vector<mojom::PassageEmbeddingsResultPtr>> |
| execute_future; |
| embedder_remote->GenerateEmbeddings({"foo"}, |
| mojom::PassagePriority::kUserInitiated, |
| execute_future.GetCallback()); |
| std::vector<mojom::PassageEmbeddingsResultPtr> results = |
| execute_future.Take(); |
| // Execution fails since the embeddings model is invalid. |
| EXPECT_EQ(results.size(), 0u); |
| } |
| |
| TEST_F(PassageEmbeddingsServiceTest, LoadModelsWithInvalidSpModel) { |
| mojo::Remote<mojom::PassageEmbedder> embedder_remote; |
| base::test::TestFuture<bool> future; |
| service()->LoadModels( |
| MakeModelParams(embeddings_path_, embeddings_path_, kInputWindowSize), |
| MakeEmbedderParams(), embedder_remote.BindNewPipeAndPassReceiver(), |
| future.GetCallback()); |
| bool load_models_success = future.Get(); |
| EXPECT_FALSE(load_models_success); |
| } |
| |
| TEST_F(PassageEmbeddingsServiceTest, LoadModelsWithInvalidInputWindowSize) { |
| mojo::Remote<mojom::PassageEmbedder> embedder_remote; |
| base::test::TestFuture<bool> future; |
| service()->LoadModels( |
| MakeModelParams(embeddings_path_, sp_path_, 0u), MakeEmbedderParams(), |
| embedder_remote.BindNewPipeAndPassReceiver(), future.GetCallback()); |
| bool load_models_success = future.Get(); |
| EXPECT_FALSE(load_models_success); |
| } |
| |
| TEST_F(PassageEmbeddingsServiceTest, RespondsWithEmbeddings) { |
| mojo::Remote<mojom::PassageEmbedder> embedder_remote; |
| base::test::TestFuture<bool> load_models_future; |
| service()->LoadModels( |
| MakeModelParams(embeddings_path_, sp_path_, kInputWindowSize), |
| MakeEmbedderParams(), embedder_remote.BindNewPipeAndPassReceiver(), |
| load_models_future.GetCallback()); |
| bool load_models_success = load_models_future.Get(); |
| EXPECT_TRUE(load_models_success); |
| |
| base::test::TestFuture<std::vector<mojom::PassageEmbeddingsResultPtr>> |
| execute_future; |
| embedder_remote->GenerateEmbeddings({"hello", "world", ""}, |
| mojom::PassagePriority::kUserInitiated, |
| execute_future.GetCallback()); |
| auto results = execute_future.Take(); |
| EXPECT_EQ(results.size(), 3u); |
| for (const auto& result : results) { |
| EXPECT_EQ(result->embeddings.size(), kEmbeddingsOutputSize); |
| } |
| |
| histogram_tester_.ExpectUniqueSample(kCacheHitMetricName, false, 3); |
| } |
| |
| TEST_F(PassageEmbeddingsServiceTest, CacheHits) { |
| mojo::Remote<mojom::PassageEmbedder> embedder_remote; |
| base::test::TestFuture<bool> load_models_future; |
| service()->LoadModels( |
| MakeModelParams(embeddings_path_, sp_path_, kInputWindowSize), |
| MakeEmbedderParams(), embedder_remote.BindNewPipeAndPassReceiver(), |
| load_models_future.GetCallback()); |
| bool load_models_success = load_models_future.Get(); |
| EXPECT_TRUE(load_models_success); |
| |
| base::test::TestFuture<std::vector<mojom::PassageEmbeddingsResultPtr>> |
| execute_future; |
| embedder_remote->GenerateEmbeddings( |
| {"hello", "world", "hello", "world", "foo", ""}, |
| mojom::PassagePriority::kUserInitiated, execute_future.GetCallback()); |
| auto results = execute_future.Take(); |
| |
| EXPECT_EQ(results.size(), 6u); |
| |
| EXPECT_EQ(results[0]->embeddings, results[2]->embeddings); |
| EXPECT_EQ(results[1]->embeddings, results[3]->embeddings); |
| |
| for (const auto& result : results) { |
| EXPECT_EQ(result->embeddings.size(), kEmbeddingsOutputSize); |
| } |
| |
| histogram_tester_.ExpectTotalCount(kCacheHitMetricName, 6); |
| histogram_tester_.ExpectBucketCount(kCacheHitMetricName, true, 2); |
| histogram_tester_.ExpectBucketCount(kCacheHitMetricName, false, 4); |
| } |
| |
| TEST_F(PassageEmbeddingsServiceTest, RecordsDurationHistogramsWithPriority) { |
| mojo::Remote<mojom::PassageEmbedder> embedder_remote; |
| base::test::TestFuture<bool> load_models_future; |
| service()->LoadModels( |
| MakeModelParams(embeddings_path_, sp_path_, kInputWindowSize), |
| MakeEmbedderParams(), embedder_remote.BindNewPipeAndPassReceiver(), |
| load_models_future.GetCallback()); |
| std::ignore = load_models_future.Take(); |
| |
| base::test::TestFuture<std::vector<mojom::PassageEmbeddingsResultPtr>> |
| execute_future; |
| embedder_remote->GenerateEmbeddings({"hello", "world"}, |
| mojom::PassagePriority::kPassive, |
| execute_future.GetCallback()); |
| std::ignore = execute_future.Take(); |
| |
| embedder_remote->GenerateEmbeddings({"foo"}, |
| mojom::PassagePriority::kUserInitiated, |
| execute_future.GetCallback()); |
| std::ignore = execute_future.Take(); |
| |
| histogram_tester_.ExpectTotalCount( |
| "History.Embeddings.Embedder.PassageEmbeddingsGenerationDuration", 2); |
| histogram_tester_.ExpectTotalCount( |
| "History.Embeddings.Embedder.PassageEmbeddingsGenerationThreadDuration", |
| 2); |
| histogram_tester_.ExpectTotalCount( |
| "History.Embeddings.Embedder.QueryEmbeddingsGenerationDuration", 1); |
| histogram_tester_.ExpectTotalCount( |
| "History.Embeddings.Embedder.QueryEmbeddingsGenerationThreadDuration", 1); |
| } |
| |
| } // namespace |
| } // namespace passage_embeddings |