blob: 489e1a47601cc7f11ef55840a6ce4cd802d37bf2 [file] [log] [blame]
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/passage_embeddings/internal/scheduling_embedder.h"
#include <initializer_list>
#include <memory>
#include <optional>
#include <tuple>
#include "base/functional/bind.h"
#include "base/logging.h"
#include "base/task/sequenced_task_runner.h"
#include "base/test/bind.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "components/passage_embeddings/passage_embeddings_service_controller.h"
#include "components/passage_embeddings/passage_embeddings_test_util.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace passage_embeddings {
namespace {
using testing::ElementsAre;
using ComputePassagesEmbeddingsFuture =
base::test::TestFuture<std::vector<std::string>,
std::vector<Embedding>,
SchedulingEmbedder::TaskId,
ComputeEmbeddingsStatus>;
std::vector<mojom::PassageEmbeddingsResultPtr> GenerateExpectedServiceOutput(
std::initializer_list<std::string> passages) {
std::vector<mojom::PassageEmbeddingsResultPtr> results;
for (size_t i = 0; i < passages.size(); ++i) {
results.push_back(
mojom::PassageEmbeddingsResult::New(std::vector<float>{1.0f}));
}
return results;
}
void IgnoreResults(std::vector<std::string>,
std::vector<Embedding>,
SchedulingEmbedder::TaskId,
ComputeEmbeddingsStatus) {}
} // namespace
class GetEmbeddingsStub {
public:
MOCK_METHOD(void,
GetEmbeddings,
(std::vector<std::string> passages,
PassagePriority priority,
SchedulingEmbedder::GetEmbeddingsResultCallback callback));
};
class SchedulingEmbedderTest : public testing::Test {
public:
void SetUp() override {
embedder_metadata_provider_ =
std::make_unique<TestEmbedderMetadataProvider>();
}
void TearDown() override { embedder_metadata_provider_.reset(); }
protected:
base::test::TaskEnvironment task_environment_{
base::test::TaskEnvironment::TimeSource::MOCK_TIME};
std::unique_ptr<EmbedderMetadataProvider> embedder_metadata_provider_;
GetEmbeddingsStub get_embeddings_stub_;
};
TEST_F(SchedulingEmbedderTest, InvokesService) {
auto embedder = std::make_unique<SchedulingEmbedder>(
embedder_metadata_provider_.get(),
base::BindRepeating(&GetEmbeddingsStub::GetEmbeddings,
base::Unretained(&get_embeddings_stub_)),
/*max_jobs=*/4u,
/*max_batch_size=*/1u,
/*use_performance_scenario=*/false);
std::vector<std::string> requested_passages;
std::optional<PassagePriority> passage_priority;
SchedulingEmbedder::GetEmbeddingsResultCallback result_callback;
EXPECT_CALL(get_embeddings_stub_, GetEmbeddings)
.WillOnce([&](std::vector<std::string> passages, PassagePriority priority,
SchedulingEmbedder::GetEmbeddingsResultCallback callback) {
requested_passages = std::move(passages);
passage_priority = priority;
result_callback = std::move(callback);
});
embedder->ComputePassagesEmbeddings(PassagePriority::kPassive,
{"test passage 1"},
base::BindOnce(&IgnoreResults));
EXPECT_THAT(requested_passages, ElementsAre("test passage 1"));
EXPECT_EQ(passage_priority, PassagePriority::kPassive);
ASSERT_FALSE(result_callback.is_null());
// Run the callback to notify the SchedulingEmbedder of the processed output.
std::move(result_callback)
.Run(GenerateExpectedServiceOutput({"test passage 1"}),
ComputeEmbeddingsStatus::kSuccess);
}
TEST_F(SchedulingEmbedderTest, TranslatesServiceOutput) {
auto embedder = std::make_unique<SchedulingEmbedder>(
embedder_metadata_provider_.get(),
base::BindRepeating(&GetEmbeddingsStub::GetEmbeddings,
base::Unretained(&get_embeddings_stub_)),
/*max_jobs=*/4u,
/*max_batch_size=*/2u,
/*use_performance_scenario=*/false);
EXPECT_CALL(get_embeddings_stub_, GetEmbeddings)
.WillOnce(
[](std::vector<std::string> passages, PassagePriority priority,
SchedulingEmbedder::GetEmbeddingsResultCallback callback) {
std::vector<mojom::PassageEmbeddingsResultPtr> results;
results.push_back(mojom::PassageEmbeddingsResult::New(
std::vector<float>{1.0f, 0.0f}));
results.push_back(mojom::PassageEmbeddingsResult::New(
std::vector<float>{0.0f, 1.0f}));
std::move(callback).Run(std::move(results),
ComputeEmbeddingsStatus::kSuccess);
});
ComputePassagesEmbeddingsFuture future;
Embedder::TaskId task_id = embedder->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"test passage 1", "test passage 2"},
future.GetCallback());
auto [passages, embeddings, received_task_id, status] = future.Take();
EXPECT_THAT(passages, ElementsAre("test passage 1", "test passage 2"));
ASSERT_EQ(embeddings.size(), 2u);
EXPECT_THAT(embeddings[0].GetData(), ElementsAre(1.0f, 0.0f));
EXPECT_THAT(embeddings[1].GetData(), ElementsAre(0.0f, 1.0f));
EXPECT_EQ(task_id, received_task_id);
EXPECT_EQ(status, ComputeEmbeddingsStatus::kSuccess);
}
TEST_F(SchedulingEmbedderTest, UserInitiatedJobTakesPriority) {
auto embedder = std::make_unique<SchedulingEmbedder>(
embedder_metadata_provider_.get(),
base::BindRepeating(&GetEmbeddingsStub::GetEmbeddings,
base::Unretained(&get_embeddings_stub_)),
/*max_jobs=*/4u,
/*max_batch_size=*/1u,
/*use_performance_scenario=*/false);
struct Call {
std::vector<std::string> passages;
PassagePriority priority;
SchedulingEmbedder::GetEmbeddingsResultCallback callback;
};
std::vector<Call> calls;
const auto save_call_parameters =
[&calls](std::vector<std::string> passages, PassagePriority priority,
SchedulingEmbedder::GetEmbeddingsResultCallback callback) {
calls.emplace_back(std::move(passages), priority, std::move(callback));
};
EXPECT_CALL(get_embeddings_stub_, GetEmbeddings)
.WillOnce(save_call_parameters)
.WillOnce(save_call_parameters)
.WillOnce(save_call_parameters);
// Submit a passive priority task.
embedder->ComputePassagesEmbeddings(PassagePriority::kPassive,
{"test passage 1", "test passage 2"},
base::BindOnce(&IgnoreResults));
// Submit a user-initiated priority task. This will suspend the partially
// completed passive priority task.
embedder->ComputePassagesEmbeddings(PassagePriority::kUserInitiated,
{"query"},
base::BindOnce(&IgnoreResults));
ASSERT_EQ(calls.size(), 1u);
EXPECT_THAT(calls.back().passages, ElementsAre("test passage 1"));
EXPECT_EQ(calls.back().priority, PassagePriority::kPassive);
// Running the callback should kick off the next round of processing.
ASSERT_FALSE(calls.back().callback.is_null());
std::move(calls.back().callback)
.Run(GenerateExpectedServiceOutput({"test passage 1"}),
ComputeEmbeddingsStatus::kSuccess);
ASSERT_EQ(calls.size(), 2u);
EXPECT_THAT(calls.back().passages, ElementsAre("query"));
EXPECT_EQ(calls.back().priority, PassagePriority::kUserInitiated);
ASSERT_FALSE(calls.back().callback.is_null());
std::move(calls.back().callback)
.Run(GenerateExpectedServiceOutput({"query"}),
ComputeEmbeddingsStatus::kSuccess);
ASSERT_EQ(calls.size(), 3u);
EXPECT_THAT(calls.back().passages, ElementsAre("test passage 2"));
EXPECT_EQ(calls.back().priority, PassagePriority::kPassive);
ASSERT_FALSE(calls.back().callback.is_null());
std::move(calls.back().callback)
.Run(GenerateExpectedServiceOutput({"test passage 2"}),
ComputeEmbeddingsStatus::kSuccess);
}
TEST_F(SchedulingEmbedderTest, TryCancel) {
auto embedder = std::make_unique<SchedulingEmbedder>(
embedder_metadata_provider_.get(),
base::BindRepeating(&GetEmbeddingsStub::GetEmbeddings,
base::Unretained(&get_embeddings_stub_)),
/*max_jobs=*/4u,
/*max_batch_size=*/1u,
/*use_performance_scenario=*/false);
std::vector<std::string> requested_passages;
SchedulingEmbedder::GetEmbeddingsResultCallback result_callback;
EXPECT_CALL(get_embeddings_stub_, GetEmbeddings)
.WillOnce([&requested_passages, &result_callback](
std::vector<std::string> passages, PassagePriority priority,
SchedulingEmbedder::GetEmbeddingsResultCallback callback) {
requested_passages = std::move(passages);
result_callback = std::move(callback);
});
embedder->ComputePassagesEmbeddings(PassagePriority::kPassive,
{"test passage 1"},
base::BindOnce(&IgnoreResults));
Embedder::TaskId second_task_id = embedder->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"test passage 2"},
base::BindOnce(&IgnoreResults));
embedder->TryCancel(second_task_id);
EXPECT_THAT(requested_passages, ElementsAre("test passage 1"));
// Running the callback should not kick off any further processing. This is
// validated by the WillOnce expectation above.
ASSERT_FALSE(result_callback.is_null());
std::move(result_callback)
.Run(GenerateExpectedServiceOutput({"test passage 1"}),
ComputeEmbeddingsStatus::kSuccess);
}
TEST_F(SchedulingEmbedderTest, RecordsHistograms) {
base::HistogramTester histogram_tester;
auto embedder = std::make_unique<SchedulingEmbedder>(
embedder_metadata_provider_.get(),
base::BindRepeating(&GetEmbeddingsStub::GetEmbeddings,
base::Unretained(&get_embeddings_stub_)),
/*max_jobs=*/4u,
/*max_batch_size=*/1u,
/*use_performance_scenario=*/false);
std::vector<SchedulingEmbedder::GetEmbeddingsResultCallback> callbacks;
const auto record_callback =
[&callbacks](std::vector<std::string> passages, PassagePriority priority,
SchedulingEmbedder::GetEmbeddingsResultCallback callback) {
callbacks.push_back(std::move(callback));
};
EXPECT_CALL(get_embeddings_stub_, GetEmbeddings)
.WillOnce(record_callback)
.WillOnce(record_callback);
ComputePassagesEmbeddingsFuture future1;
embedder->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"test passage 1"}, future1.GetCallback());
ComputePassagesEmbeddingsFuture future2;
Embedder::TaskId task_id = embedder->ComputePassagesEmbeddings(
PassagePriority::kUserInitiated, {"test passage 2a", "test passage 2b"},
future2.GetCallback());
ComputePassagesEmbeddingsFuture future3;
embedder->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"test passage 3"}, future3.GetCallback());
embedder->TryCancel(task_id);
ASSERT_EQ(callbacks.size(), 1u);
ASSERT_FALSE(callbacks.back().is_null());
std::move(callbacks.back())
.Run(GenerateExpectedServiceOutput({"test passage 1"}),
ComputeEmbeddingsStatus::kSuccess);
ASSERT_EQ(callbacks.size(), 2u);
ASSERT_FALSE(callbacks.back().is_null());
std::move(callbacks.back())
.Run(GenerateExpectedServiceOutput({"test passage 3"}),
ComputeEmbeddingsStatus::kSuccess);
ASSERT_TRUE(future1.Wait());
ASSERT_TRUE(future2.Wait());
ASSERT_TRUE(future3.Wait());
// Only the two "passive priority" jobs successfully completed; the "user
// initiate priority" one was canceled. So only two duration histogram samples
// are logged, but three counts histograms samples and three status histogram
// samples are logged as the all jobs were enqueued and completed in some way.
histogram_tester.ExpectTotalCount("History.Embeddings.ScheduledJobDuration",
2);
histogram_tester.ExpectTotalCount(
"History.Embeddings.ScheduledJobDuration.Passive", 2);
histogram_tester.ExpectTotalCount("History.Embeddings.ScheduledJobStatus", 3);
histogram_tester.ExpectTotalCount(
"History.Embeddings.ScheduledJobStatus.Passive", 2);
histogram_tester.ExpectBucketCount(
"History.Embeddings.ScheduledJobStatus.Passive",
ComputeEmbeddingsStatus::kSuccess, 2);
histogram_tester.ExpectTotalCount(
"History.Embeddings.ScheduledJobStatus.UserInitiated", 1);
histogram_tester.ExpectBucketCount(
"History.Embeddings.ScheduledJobStatus.UserInitiated",
ComputeEmbeddingsStatus::kCanceled, 1);
histogram_tester.ExpectTotalCount("History.Embeddings.ScheduledJobCount", 3);
histogram_tester.ExpectBucketCount("History.Embeddings.ScheduledJobCount", 0,
1);
histogram_tester.ExpectBucketCount("History.Embeddings.ScheduledJobCount", 1,
1);
histogram_tester.ExpectBucketCount("History.Embeddings.ScheduledJobCount", 2,
1);
histogram_tester.ExpectTotalCount("History.Embeddings.ScheduledPassageCount",
3);
histogram_tester.ExpectBucketCount("History.Embeddings.ScheduledPassageCount",
0, 1);
histogram_tester.ExpectBucketCount("History.Embeddings.ScheduledPassageCount",
1, 1);
// When the third job is enqueued, 1 + 2 = 3 passages are waiting in the
// previous two jobs.
histogram_tester.ExpectBucketCount("History.Embeddings.ScheduledPassageCount",
3, 1);
}
TEST_F(SchedulingEmbedderTest, LimitsJobCount) {
auto embedder = std::make_unique<SchedulingEmbedder>(
embedder_metadata_provider_.get(),
base::BindRepeating(&GetEmbeddingsStub::GetEmbeddings,
base::Unretained(&get_embeddings_stub_)),
/*max_jobs=*/2u,
/*max_batch_size=*/1u,
/*use_performance_scenario=*/false);
std::vector<SchedulingEmbedder::GetEmbeddingsResultCallback> callbacks;
const auto record_callback =
[&callbacks](std::vector<std::string> passages, PassagePriority priority,
SchedulingEmbedder::GetEmbeddingsResultCallback callback) {
callbacks.push_back(std::move(callback));
};
EXPECT_CALL(get_embeddings_stub_, GetEmbeddings)
.WillOnce(record_callback)
.WillOnce(record_callback);
ComputePassagesEmbeddingsFuture future1;
embedder->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"test passage 1"}, future1.GetCallback());
ComputePassagesEmbeddingsFuture future2;
embedder->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"test passage 2"}, future2.GetCallback());
ComputePassagesEmbeddingsFuture future3;
embedder->ComputePassagesEmbeddings(
PassagePriority::kPassive, {"test passage 3"}, future3.GetCallback());
ASSERT_EQ(callbacks.size(), 1u);
ASSERT_FALSE(callbacks.back().is_null());
std::move(callbacks.back())
.Run(GenerateExpectedServiceOutput({"test passage 1"}),
ComputeEmbeddingsStatus::kSuccess);
ASSERT_EQ(callbacks.size(), 2u);
ASSERT_FALSE(callbacks.back().is_null());
std::move(callbacks.back())
.Run(GenerateExpectedServiceOutput({"test passage 3"}),
ComputeEmbeddingsStatus::kSuccess);
// Final job interrupts the job at back of line when the limit is reached.
EXPECT_EQ(std::get<3>(future1.Take()), ComputeEmbeddingsStatus::kSuccess);
EXPECT_EQ(std::get<3>(future2.Take()), ComputeEmbeddingsStatus::kCanceled);
EXPECT_EQ(std::get<3>(future3.Take()), ComputeEmbeddingsStatus::kSuccess);
}
} // namespace passage_embeddings