blob: f1b7b7ca48314bf28c8bb701a0c760f60abe2c3e [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/passage_embeddings/passage_embeddings_service_controller.h"
#include <memory>
#include "base/memory/raw_ptr.h"
#include "base/path_service.h"
#include "base/scoped_observation.h"
#include "base/test/bind.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "components/optimization_guide/core/optimization_guide_proto_util.h"
#include "components/optimization_guide/core/test_model_info_builder.h"
#include "components/passage_embeddings/passage_embeddings_test_util.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "mojo/public/cpp/bindings/receiver.h"
#include "services/passage_embeddings/public/mojom/passage_embeddings.mojom.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace passage_embeddings {
namespace {
// Future that captures the four values an Embedder hands to its
// ComputePassagesEmbeddingsCallback: the passages, their embeddings, the task
// id, and the overall status.
using ComputePassagesEmbeddingsFuture =
    base::test::TestFuture<std::vector<std::string>,
                           std::vector<Embedding>,
                           Embedder::TaskId,
                           ComputeEmbeddingsStatus>;
// In-process stand-in for the mojom::PassageEmbedder implementation. Every
// input passage gets a constant embedding of kEmbeddingsModelOutputSize ones.
class FakePassageEmbedder : public mojom::PassageEmbedder {
 public:
  explicit FakePassageEmbedder(
      mojo::PendingReceiver<mojom::PassageEmbedder> receiver)
      : receiver_(this, std::move(receiver)) {}

 private:
  // mojom::PassageEmbedder:
  void GenerateEmbeddings(const std::vector<std::string>& inputs,
                          mojom::PassagePriority priority,
                          GenerateEmbeddingsCallback callback) override {
    // A passage equal to "error" anywhere in the batch fails the whole batch
    // with an empty result list, simulating a failed model execution.
    for (const std::string& passage : inputs) {
      if (passage == "error") {
        std::move(callback).Run({});
        return;
      }
    }

    std::vector<mojom::PassageEmbeddingsResultPtr> batch_results;
    batch_results.reserve(inputs.size());
    for (const std::string& passage : inputs) {
      mojom::PassageEmbeddingsResultPtr entry =
          mojom::PassageEmbeddingsResult::New();
      entry->embeddings = std::vector<float>(kEmbeddingsModelOutputSize, 1.0);
      entry->passage = passage;
      batch_results.push_back(std::move(entry));
    }
    std::move(callback).Run(std::move(batch_results));
  }

  mojo::Receiver<mojom::PassageEmbedder> receiver_;
};
// In-process stand-in for the mojom::PassageEmbeddingsService. Model loading
// succeeds unless the requested input window size is zero.
class FakePassageEmbeddingsService : public mojom::PassageEmbeddingsService {
 public:
  explicit FakePassageEmbeddingsService(
      mojo::PendingReceiver<mojom::PassageEmbeddingsService> receiver)
      : receiver_(this, std::move(receiver)) {}

 private:
  // mojom::PassageEmbeddingsService:
  void LoadModels(mojom::PassageEmbeddingsLoadModelsParamsPtr model_params,
                  mojom::PassageEmbedderParamsPtr embedder_params,
                  mojo::PendingReceiver<mojom::PassageEmbedder> receiver,
                  LoadModelsCallback callback) override {
    // Tests use a zero input window size as the signal to fail the request.
    // In that case `receiver` is simply dropped.
    if (model_params->input_window_size == 0) {
      std::move(callback).Run(false);
      return;
    }
    embedder_ = std::make_unique<FakePassageEmbedder>(std::move(receiver));
    std::move(callback).Run(true);
  }

  mojo::Receiver<mojom::PassageEmbeddingsService> receiver_;
  // The embedder bound by the most recent successful LoadModels() call.
  std::unique_ptr<FakePassageEmbedder> embedder_;
};
// Controller under test. Instead of launching the real service, it binds
// `service_remote_` to an in-process FakePassageEmbeddingsService. It also
// re-exports some non-public base-class members so tests can call them.
class FakePassageEmbeddingsServiceController
    : public PassageEmbeddingsServiceController {
 public:
  using PassageEmbeddingsServiceController::GetEmbeddings;
  using PassageEmbeddingsServiceController::GetEmbeddingsCallback;
  using PassageEmbeddingsServiceController::ResetEmbedderRemote;

  FakePassageEmbeddingsServiceController() = default;
  ~FakePassageEmbeddingsServiceController() override = default;

  // PassageEmbeddingsServiceController:
  void MaybeLaunchService() override {
    // Always start from a fresh connection to a newly created fake service.
    service_remote_.reset();
    auto pending_service = service_remote_.BindNewPipeAndPassReceiver();
    service_ = std::make_unique<FakePassageEmbeddingsService>(
        std::move(pending_service));
  }

  void ResetServiceRemote() override {
    ResetEmbedderRemote();
    service_remote_.reset();
  }

 private:
  std::unique_ptr<FakePassageEmbeddingsService> service_;
};
// Test Embedder that forwards each request to the controller through
// `get_embeddings_callback`. It records every embedder metadata update from
// `embedder_metadata_provider` into `embedder_metadata_future`.
class FakeEmbedder : public TestEmbedder, public EmbedderMetadataObserver {
 public:
  // `embedder_metadata_future` is stored as a raw_ptr and must outlive this
  // object.
  explicit FakeEmbedder(
      EmbedderMetadataProvider* embedder_metadata_provider,
      FakePassageEmbeddingsServiceController::GetEmbeddingsCallback
          get_embeddings_callback,
      base::test::TestFuture<EmbedderMetadata>* embedder_metadata_future)
      : get_embeddings_callback_(std::move(get_embeddings_callback)),
        embedder_metadata_future_(embedder_metadata_future) {
    embedder_metadata_observation_.Observe(embedder_metadata_provider);
  }

  // Embedder:
  // Always returns kInvalidTaskId. Only the status reported by the
  // controller is propagated. The service `results` are ignored: on success,
  // embeddings are produced by ComputeEmbeddingsForPassages() instead.
  Embedder::TaskId ComputePassagesEmbeddings(
      PassagePriority priority,
      std::vector<std::string> passages,
      ComputePassagesEmbeddingsCallback callback) override {
    // Build the completion closure first. It keeps its own copy of
    // `passages`, so the original vector can then be moved into the request.
    auto on_embeddings_computed = base::BindOnce(
        [](std::vector<std::string> passages,
           ComputePassagesEmbeddingsCallback callback,
           std::vector<mojom::PassageEmbeddingsResultPtr> results,
           ComputeEmbeddingsStatus status) {
          std::vector<Embedding> embeddings;
          if (status == ComputeEmbeddingsStatus::kSuccess) {
            embeddings = ComputeEmbeddingsForPassages(passages);
          }
          std::move(callback).Run(std::move(passages), std::move(embeddings),
                                  kInvalidTaskId, status);
        },
        passages, std::move(callback));
    get_embeddings_callback_.Run(std::move(passages), priority,
                                 std::move(on_embeddings_computed));
    return kInvalidTaskId;
  }

 private:
  // EmbedderMetadataObserver:
  void EmbedderMetadataUpdated(EmbedderMetadata metadata) override {
    embedder_metadata_future_->SetValue(metadata);
  }

  base::ScopedObservation<EmbedderMetadataProvider, EmbedderMetadataObserver>
      embedder_metadata_observation_{this};
  FakePassageEmbeddingsServiceController::GetEmbeddingsCallback
      get_embeddings_callback_;
  raw_ptr<base::test::TestFuture<EmbedderMetadata>> embedder_metadata_future_;
};
} // namespace
// Fixture that owns a FakePassageEmbeddingsServiceController whose embedder
// is a FakeEmbedder. That embedder routes requests back into the controller
// and reports metadata updates into `embedder_metadata_future_`.
class PassageEmbeddingsServiceControllerTest : public testing::Test {
 public:
  void SetUp() override {
    service_controller_ =
        std::make_unique<FakePassageEmbeddingsServiceController>();
    FakePassageEmbeddingsServiceController* controller =
        service_controller_.get();

    // The embedder calls straight back into the controller under test.
    auto route_to_controller = base::BindRepeating(
        &FakePassageEmbeddingsServiceController::GetEmbeddings,
        base::Unretained(controller));
    auto fake_embedder = std::make_unique<FakeEmbedder>(
        /*embedder_metadata_provider=*/controller,
        /*get_embeddings_callback=*/std::move(route_to_controller),
        /*embedder_metadata_future=*/embedder_metadata_future());
    controller->SetEmbedderForTesting(std::move(fake_embedder));

    // No model info has been supplied yet, so no metadata is expected.
    EXPECT_FALSE(embedder_metadata_future()->IsReady());
  }

 protected:
  base::test::TestFuture<EmbedderMetadata>* embedder_metadata_future() {
    return &embedder_metadata_future_;
  }

  Embedder* embedder() { return service_controller_->GetEmbedder(); }

  // Member order matters for destruction. `embedder_metadata_future_` is
  // declared before `service_controller_`, so it outlives the FakeEmbedder
  // that holds a raw_ptr to it.
  base::test::TaskEnvironment task_environment_;
  base::HistogramTester histogram_tester_;
  base::test::TestFuture<EmbedderMetadata> embedder_metadata_future_;
  std::unique_ptr<FakePassageEmbeddingsServiceController> service_controller_;
};
// Valid model info is accepted and published to embedder metadata observers
// with the expected version and output size. It is recorded once as kValid.
TEST_F(PassageEmbeddingsServiceControllerTest, ReceivesValidModelInfo) {
  // ASSERT rather than EXPECT: a rejected update publishes no metadata, and
  // Take() below would then wait for a value that never arrives.
  ASSERT_TRUE(service_controller_->MaybeUpdateModelInfo(
      *GetBuilderWithValidModelInfo().Build()));
  auto metadata = embedder_metadata_future()->Take();
  EXPECT_TRUE(metadata.IsValid());
  EXPECT_EQ(metadata.model_version, kEmbeddingsModelVersion);
  EXPECT_EQ(metadata.output_size, kEmbeddingsModelOutputSize);
  histogram_tester_.ExpectTotalCount(kModelInfoMetricName, 1);
  histogram_tester_.ExpectUniqueSample(kModelInfoMetricName,
                                       EmbeddingsModelInfoStatus::kValid, 1);
}
// An empty model info is rejected, recorded once as kEmpty, and never reaches
// embedder metadata observers.
TEST_F(PassageEmbeddingsServiceControllerTest, ReceivesEmptyModelInfo) {
  EXPECT_FALSE(service_controller_->MaybeUpdateModelInfo({}));

  histogram_tester_.ExpectTotalCount(kModelInfoMetricName, 1);
  histogram_tester_.ExpectUniqueSample(kModelInfoMetricName,
                                       EmbeddingsModelInfoStatus::kEmpty, 1);

  EXPECT_FALSE(embedder_metadata_future()->IsReady());
}
// Model info whose metadata cannot be parsed is rejected and recorded as
// kInvalidMetadata. Observers are not notified.
TEST_F(PassageEmbeddingsServiceControllerTest,
       ReceivesModelInfoWithInvalidModelMetadata) {
  optimization_guide::TestModelInfoBuilder model_info_builder =
      GetBuilderWithValidModelInfo();

  // Replace the valid metadata with an unparseable type URL and payload.
  optimization_guide::proto::Any garbage_metadata;
  garbage_metadata.set_type_url("not a valid type url");
  garbage_metadata.set_value("not a valid serialized metadata");
  model_info_builder.SetModelMetadata(garbage_metadata);

  EXPECT_FALSE(
      service_controller_->MaybeUpdateModelInfo(*model_info_builder.Build()));
  EXPECT_FALSE(embedder_metadata_future()->IsReady());
  histogram_tester_.ExpectTotalCount(kModelInfoMetricName, 1);
  histogram_tester_.ExpectUniqueSample(
      kModelInfoMetricName, EmbeddingsModelInfoStatus::kInvalidMetadata, 1);
}
// Model info with no metadata at all is rejected and recorded as kNoMetadata.
// Observers are not notified.
TEST_F(PassageEmbeddingsServiceControllerTest,
       ReceivesModelInfoWithoutModelMetadata) {
  optimization_guide::TestModelInfoBuilder model_info_builder =
      GetBuilderWithValidModelInfo();
  model_info_builder.SetModelMetadata(std::nullopt);

  EXPECT_FALSE(
      service_controller_->MaybeUpdateModelInfo(*model_info_builder.Build()));

  histogram_tester_.ExpectTotalCount(kModelInfoMetricName, 1);
  histogram_tester_.ExpectUniqueSample(
      kModelInfoMetricName, EmbeddingsModelInfoStatus::kNoMetadata, 1);
  EXPECT_FALSE(embedder_metadata_future()->IsReady());
}
// Despite the test name, this supplies additional files ("foo" and "bar"
// under the source test data root). It expects the controller to reject them
// and record kInvalidAdditionalFiles without notifying observers.
TEST_F(PassageEmbeddingsServiceControllerTest,
       ReceivesModelInfoWithoutAdditionalFiles) {
  base::FilePath test_data_dir;
  // Fail clearly if the test data root cannot be resolved, rather than
  // silently building paths from an empty directory.
  ASSERT_TRUE(
      base::PathService::Get(base::DIR_SRC_TEST_DATA_ROOT, &test_data_dir));
  optimization_guide::TestModelInfoBuilder builder =
      GetBuilderWithValidModelInfo();
  builder.SetAdditionalFiles(
      {test_data_dir.AppendASCII("foo"), test_data_dir.AppendASCII("bar")});
  EXPECT_FALSE(service_controller_->MaybeUpdateModelInfo(*builder.Build()));
  EXPECT_FALSE(embedder_metadata_future()->IsReady());
  histogram_tester_.ExpectTotalCount(kModelInfoMetricName, 1);
  histogram_tester_.ExpectUniqueSample(
      kModelInfoMetricName, EmbeddingsModelInfoStatus::kInvalidAdditionalFiles,
      1);
}
// An empty batch completes successfully with no passages and no embeddings.
TEST_F(PassageEmbeddingsServiceControllerTest, ReceivesEmptyPassages) {
  EXPECT_TRUE(service_controller_->MaybeUpdateModelInfo(
      *GetBuilderWithValidModelInfo().Build()));

  ComputePassagesEmbeddingsFuture result_future;
  embedder()->ComputePassagesEmbeddings(PassagePriority::kPassive, {},
                                        result_future.GetCallback());

  // Tuple slots: 0 = passages, 1 = embeddings, 3 = status.
  EXPECT_EQ(result_future.Get<3>(), ComputeEmbeddingsStatus::kSuccess);
  EXPECT_TRUE(result_future.Get<0>().empty());
  EXPECT_TRUE(result_future.Get<1>().empty());
}
// With valid model info, each passage gets an embedding of the model's output
// size. The passages are returned in request order.
TEST_F(PassageEmbeddingsServiceControllerTest, ReturnsEmbeddings) {
  ASSERT_TRUE(service_controller_->MaybeUpdateModelInfo(
      *GetBuilderWithValidModelInfo().Build()));
  ComputePassagesEmbeddingsFuture future;
  embedder()->ComputePassagesEmbeddings(PassagePriority::kPassive,
                                        {"foo", "bar"}, future.GetCallback());
  const auto& [passages, embeddings, task_id, status] = future.Get();
  EXPECT_EQ(status, ComputeEmbeddingsStatus::kSuccess);
  // Check sizes before indexing so a short result fails cleanly instead of
  // reading out of bounds.
  ASSERT_EQ(passages.size(), 2u);
  EXPECT_EQ(passages[0], "foo");
  EXPECT_EQ(passages[1], "bar");
  ASSERT_EQ(embeddings.size(), 2u);
  EXPECT_EQ(embeddings[0].Dimensions(), kEmbeddingsModelOutputSize);
  EXPECT_EQ(embeddings[1].Dimensions(), kEmbeddingsModelOutputSize);
}
// When the model info is rejected (here: missing metadata), requests complete
// with kModelUnavailable. They still echo the passages but carry no
// embeddings.
TEST_F(PassageEmbeddingsServiceControllerTest,
       ReturnsModelUnavailableErrorIfModelInfoNotValid) {
  optimization_guide::TestModelInfoBuilder builder =
      GetBuilderWithValidModelInfo();
  builder.SetModelMetadata(std::nullopt);
  EXPECT_FALSE(service_controller_->MaybeUpdateModelInfo(*builder.Build()));
  ComputePassagesEmbeddingsFuture future;
  embedder()->ComputePassagesEmbeddings(PassagePriority::kPassive,
                                        {"foo", "bar"}, future.GetCallback());
  const auto& [passages, embeddings, task_id, status] = future.Get();
  EXPECT_EQ(status, ComputeEmbeddingsStatus::kModelUnavailable);
  // Check the size before indexing so a short result fails cleanly.
  ASSERT_EQ(passages.size(), 2u);
  EXPECT_EQ(passages[0], "foo");
  EXPECT_EQ(passages[1], "bar");
  EXPECT_TRUE(embeddings.empty());
}
// An "error" passage makes the fake embedder fail the whole batch. The request
// then completes with kExecutionFailure, echoes the passages, and has no
// embeddings.
TEST_F(PassageEmbeddingsServiceControllerTest, ReturnsExecutionFailure) {
  ASSERT_TRUE(service_controller_->MaybeUpdateModelInfo(
      *GetBuilderWithValidModelInfo().Build()));
  ComputePassagesEmbeddingsFuture future;
  embedder()->ComputePassagesEmbeddings(PassagePriority::kPassive,
                                        {"error", "baz"}, future.GetCallback());
  const auto& [passages, embeddings, task_id, status] = future.Get();
  EXPECT_EQ(status, ComputeEmbeddingsStatus::kExecutionFailure);
  // Check the size before indexing so a short result fails cleanly.
  ASSERT_EQ(passages.size(), 2u);
  EXPECT_EQ(passages[0], "error");
  EXPECT_EQ(passages[1], "baz");
  EXPECT_TRUE(embeddings.empty());
}
// Tracks EmbedderRunning() across four phases:
//   1. normal completion;
//   2. embedder remote reset;
//   3. normal completion again, after the service relaunches;
//   4. service remote reset.
TEST_F(PassageEmbeddingsServiceControllerTest, EmbedderRunningStatus) {
  EXPECT_TRUE(service_controller_->MaybeUpdateModelInfo(
      *GetBuilderWithValidModelInfo().Build()));

  // Submits two batches. The embedder must report running after each one.
  auto submit_two_batches = [this](ComputePassagesEmbeddingsFuture& first,
                                   ComputePassagesEmbeddingsFuture& second) {
    embedder()->ComputePassagesEmbeddings(
        PassagePriority::kPassive, {"foo", "bar"}, first.GetCallback());
    EXPECT_TRUE(service_controller_->EmbedderRunning());
    embedder()->ComputePassagesEmbeddings(
        PassagePriority::kPassive, {"baz", "qux"}, second.GetCallback());
    EXPECT_TRUE(service_controller_->EmbedderRunning());
  };

  // Both batches succeed. The embedder keeps running until the last one has
  // finished.
  auto expect_both_succeed = [this, &submit_two_batches]() {
    ComputePassagesEmbeddingsFuture first;
    ComputePassagesEmbeddingsFuture second;
    submit_two_batches(first, second);
    EXPECT_EQ(first.Get<3>(), ComputeEmbeddingsStatus::kSuccess);
    EXPECT_TRUE(service_controller_->EmbedderRunning());
    EXPECT_EQ(second.Get<3>(), ComputeEmbeddingsStatus::kSuccess);
    EXPECT_FALSE(service_controller_->EmbedderRunning());
  };

  // Disconnecting runs the pending callbacks synchronously with
  // kExecutionFailure. The embedder stops running immediately.
  auto expect_both_fail_on_disconnect = [this, &submit_two_batches](
                                            bool reset_whole_service) {
    ComputePassagesEmbeddingsFuture first;
    ComputePassagesEmbeddingsFuture second;
    submit_two_batches(first, second);
    if (reset_whole_service) {
      service_controller_->ResetServiceRemote();
    } else {
      service_controller_->ResetEmbedderRemote();
    }
    EXPECT_FALSE(service_controller_->EmbedderRunning());
    EXPECT_EQ(first.Get<3>(), ComputeEmbeddingsStatus::kExecutionFailure);
    EXPECT_EQ(second.Get<3>(), ComputeEmbeddingsStatus::kExecutionFailure);
  };

  expect_both_succeed();
  expect_both_fail_on_disconnect(/*reset_whole_service=*/false);
  // Submitting work again launches the service again.
  expect_both_succeed();
  expect_both_fail_on_disconnect(/*reset_whole_service=*/true);
}
} // namespace passage_embeddings