blob: 8e792dfde809af0e8a1e951c05dd40dcd7b216ab [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/passage_embeddings/passage_embeddings_service_controller.h"
#include <ranges>
#include "base/functional/bind.h"
#include "base/metrics/histogram_functions.h"
#include "base/not_fatal_until.h"
#include "base/notreached.h"
#include "base/task/thread_pool.h"
#include "components/optimization_guide/core/optimization_guide_util.h"
#include "components/passage_embeddings/internal/scheduling_embedder.h"
#include "components/passage_embeddings/passage_embeddings_features.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "mojo/public/cpp/bindings/callback_helpers.h"
#include "services/passage_embeddings/public/mojom/passage_embeddings.mojom.h"
namespace passage_embeddings {
namespace {
mojom::PassageEmbeddingsLoadModelsParamsPtr MakeModelParams(
const base::FilePath& embeddings_path,
const base::FilePath& sp_path,
uint32_t input_window_size) {
auto params = mojom::PassageEmbeddingsLoadModelsParams::New();
params->embeddings_model = base::File(
embeddings_path, base::File::FLAG_OPEN | base::File::FLAG_READ);
params->sp_model =
base::File(sp_path, base::File::FLAG_OPEN | base::File::FLAG_READ);
params->input_window_size = input_window_size;
return params;
}
// Makes the parameters used to run the passage embedder.
mojom::PassageEmbedderParamsPtr MakeEmbedderParams() {
auto params = mojom::PassageEmbedderParams::New();
params->user_initiated_priority_num_threads =
kUserInitiatedPriorityNumThreads.Get();
params->passive_priority_num_threads = kPassivePriorityNumThreads.Get();
params->embedder_cache_size = kEmbedderCacheSize.Get();
params->allow_gpu_execution = kAllowGpuExecution.Get();
return params;
}
mojom::PassagePriority PassagePriorityToMojom(PassagePriority priority) {
switch (priority) {
case kUserInitiated:
return mojom::PassagePriority::kUserInitiated;
case kPassive:
case kLatent:
return mojom::PassagePriority::kPassive;
}
}
class ScopedEmbeddingsModelInfoStatusLogger {
public:
ScopedEmbeddingsModelInfoStatusLogger() = default;
~ScopedEmbeddingsModelInfoStatusLogger() {
CHECK_NE(EmbeddingsModelInfoStatus::kUnknown, status_);
base::UmaHistogramEnumeration(kModelInfoMetricName, status_);
}
void set_status(EmbeddingsModelInfoStatus status) { status_ = status; }
private:
EmbeddingsModelInfoStatus status_ = EmbeddingsModelInfoStatus::kUnknown;
};
} // namespace
PassageEmbeddingsServiceController::PassageEmbeddingsServiceController()
: embedder_(std::make_unique<SchedulingEmbedder>(
/*embedder_metadata_provider=*/this,
/*get_embeddings_callback=*/
base::BindRepeating(
&PassageEmbeddingsServiceController::GetEmbeddings,
base::Unretained(this)),
kSchedulerMaxJobs.Get(),
kSchedulerMaxBatchSize.Get(),
kUsePerformanceScenario.Get())) {}
PassageEmbeddingsServiceController::~PassageEmbeddingsServiceController() =
default;
bool PassageEmbeddingsServiceController::MaybeUpdateModelInfo(
base::optional_ref<const optimization_guide::ModelInfo> model_info) {
// Got the same version again. Do not run through rest of logic.
if (model_info && model_version_ == model_info->GetVersion()) {
return true;
}
// Reset everything, so if the model info is invalid, the service controller
// would stop accepting requests.
embeddings_model_path_.clear();
sp_model_path_.clear();
model_metadata_ = std::nullopt;
ResetEmbedderRemote();
ScopedEmbeddingsModelInfoStatusLogger logger;
if (!model_info.has_value()) {
logger.set_status(EmbeddingsModelInfoStatus::kEmpty);
return false;
}
// The only additional file should be the sentencepiece model.
base::flat_set<base::FilePath> additional_files =
model_info->GetAdditionalFiles();
if (additional_files.size() != 1u) {
logger.set_status(EmbeddingsModelInfoStatus::kInvalidAdditionalFiles);
return false;
}
// Check validity of model metadata.
const std::optional<optimization_guide::proto::Any>& metadata =
model_info->GetModelMetadata();
if (!metadata) {
logger.set_status(EmbeddingsModelInfoStatus::kNoMetadata);
return false;
}
std::optional<optimization_guide::proto::PassageEmbeddingsModelMetadata>
embeddings_metadata = optimization_guide::ParsedAnyMetadata<
optimization_guide::proto::PassageEmbeddingsModelMetadata>(*metadata);
if (!embeddings_metadata) {
logger.set_status(EmbeddingsModelInfoStatus::kInvalidMetadata);
return false;
}
model_version_ = model_info->GetVersion();
model_metadata_ = embeddings_metadata;
embeddings_model_path_ = model_info->GetModelFilePath();
sp_model_path_ = *(additional_files.begin());
CHECK(EmbedderReady());
logger.set_status(EmbeddingsModelInfoStatus::kValid);
observer_list_.Notify(&EmbedderMetadataObserver::EmbedderMetadataUpdated,
GetEmbedderMetadata());
return true;
}
void PassageEmbeddingsServiceController::LoadModelsToService(
mojo::PendingReceiver<mojom::PassageEmbedder> receiver,
mojom::PassageEmbeddingsLoadModelsParamsPtr params) {
if (!service_remote_) {
// Close the model files in a background thread.
base::ThreadPool::PostTaskAndReply(
FROM_HERE, {base::MayBlock()},
base::DoNothingWithBoundArgs(std::move(params)),
base::BindOnce(&PassageEmbeddingsServiceController::OnLoadModelsResult,
weak_ptr_factory_.GetWeakPtr(), /*success=*/false));
return;
}
service_remote_->LoadModels(
std::move(params), MakeEmbedderParams(), std::move(receiver),
base::BindOnce(&PassageEmbeddingsServiceController::OnLoadModelsResult,
weak_ptr_factory_.GetWeakPtr()));
}
void PassageEmbeddingsServiceController::OnLoadModelsResult(bool success) {
if (!success) {
ResetEmbedderRemote();
}
}
Embedder* PassageEmbeddingsServiceController::GetEmbedder() {
return embedder_.get();
}
void PassageEmbeddingsServiceController::SetEmbedderForTesting(
std::unique_ptr<Embedder> embedder) {
embedder_ = std::move(embedder);
}
void PassageEmbeddingsServiceController::AddObserver(
EmbedderMetadataObserver* observer) {
if (EmbedderReady()) {
observer->EmbedderMetadataUpdated(GetEmbedderMetadata());
}
observer_list_.AddObserver(observer);
}
void PassageEmbeddingsServiceController::RemoveObserver(
EmbedderMetadataObserver* observer) {
observer_list_.RemoveObserver(observer);
}
void PassageEmbeddingsServiceController::GetEmbeddings(
std::vector<std::string> passages,
PassagePriority priority,
GetEmbeddingsResultCallback callback) {
if (passages.empty()) {
std::move(callback).Run({}, ComputeEmbeddingsStatus::kSuccess);
return;
}
if (!EmbedderReady()) {
VLOG(1) << "Missing model path: embeddings='" << embeddings_model_path_
<< "'; sp='" << sp_model_path_ << "'";
std::move(callback).Run({}, ComputeEmbeddingsStatus::kModelUnavailable);
return;
}
if (!embedder_remote_) {
MaybeLaunchService();
auto receiver = embedder_remote_.BindNewPipeAndPassReceiver();
// Unretained is safe because `this` owns `embedder_remote_`, which
// synchronously calls the disconnect and idle handlers.
embedder_remote_.set_disconnect_handler(
base::BindOnce(&PassageEmbeddingsServiceController::ResetEmbedderRemote,
base::Unretained(this)));
embedder_remote_.set_idle_handler(
kEmbedderTimeout.Get(),
base::BindRepeating(
&PassageEmbeddingsServiceController::ResetEmbedderRemote,
base::Unretained(this)));
base::ThreadPool::PostTaskAndReplyWithResult(
FROM_HERE, {base::MayBlock()},
base::BindOnce(&MakeModelParams, embeddings_model_path_, sp_model_path_,
model_metadata_->input_window_size()),
base::BindOnce(&PassageEmbeddingsServiceController::LoadModelsToService,
weak_ptr_factory_.GetWeakPtr(), std::move(receiver)));
}
pending_requests_.push_back(next_request_id_);
embedder_remote_->GenerateEmbeddings(
std::move(passages), PassagePriorityToMojom(priority),
mojo::WrapCallbackWithDefaultInvokeIfNotRun(
base::BindOnce(&PassageEmbeddingsServiceController::OnGotEmbeddings,
weak_ptr_factory_.GetWeakPtr(), next_request_id_,
std::move(callback)),
std::vector<mojom::PassageEmbeddingsResultPtr>()));
next_request_id_++;
}
bool PassageEmbeddingsServiceController::EmbedderReady() {
return !sp_model_path_.empty() && !embeddings_model_path_.empty();
}
EmbedderMetadata PassageEmbeddingsServiceController::GetEmbedderMetadata() {
if (model_metadata_->score_threshold() > 0.0) {
return EmbedderMetadata(model_version_, model_metadata_->output_size(),
model_metadata_->score_threshold());
}
return EmbedderMetadata(model_version_, model_metadata_->output_size());
}
bool PassageEmbeddingsServiceController::EmbedderRunning() {
return !pending_requests_.empty();
}
void PassageEmbeddingsServiceController::ResetEmbedderRemote() {
embedder_remote_.reset();
}
void PassageEmbeddingsServiceController::OnGotEmbeddings(
RequestId request_id,
GetEmbeddingsResultCallback callback,
std::vector<mojom::PassageEmbeddingsResultPtr> results) {
// Mojo invokes the callbacks in the order in which `GenerateEmbeddings()` was
// called. Therefore, `request_id` should be expected at the front of
// `pending_requests_`. However, when `embedder_remote_` disconnects and the
// callbacks are dropped, `mojo::WrapCallbackWithDefaultInvokeIfNotRun()`
// invokes the callbacks in the reverse order in which they were bound.
auto it = std::ranges::find(pending_requests_, request_id);
if (it != pending_requests_.end()) {
pending_requests_.erase(it);
} else {
NOTREACHED(base::NotFatalUntil::M140);
}
auto status = results.empty() ? ComputeEmbeddingsStatus::kExecutionFailure
: ComputeEmbeddingsStatus::kSuccess;
std::move(callback).Run(std::move(results), status);
}
} // namespace passage_embeddings