blob: aa6b21f27e35e95e9bd1df2a6565e97f993510eb [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "services/passage_embeddings/passage_embeddings_service.h"
#include <utility>
#include "base/files/file.h"
#include "components/optimization_guide/machine_learning_tflite_buildflags.h"
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
#include "services/passage_embeddings/passage_embedder.h"
#endif
namespace passage_embeddings {
// Constructs the service around `receiver`. Ownership of the pending receiver
// is moved into `receiver_`, which is also handed `this` as its
// implementation pointer (presumably so incoming
// mojom::PassageEmbeddingsService calls are dispatched to this object).
PassageEmbeddingsService::PassageEmbeddingsService(
mojo::PendingReceiver<mojom::PassageEmbeddingsService> receiver)
: receiver_(this, std::move(receiver)) {}
// Defaulted out of line; no explicit teardown is performed here, so members
// are released by their own destructors.
PassageEmbeddingsService::~PassageEmbeddingsService() = default;
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
// Drops the currently held embedder, if any. This is the callback handed to
// each PassageEmbedder created by LoadModels(); it only exists in builds that
// link the TFLite library.
void PassageEmbeddingsService::OnEmbedderDisconnect() {
  embedder_ = nullptr;
}
#endif
// Creates a PassageEmbedder for `receiver` and asks it to load the models
// described by `model_params`. `callback` is always run exactly once with the
// outcome: true when loading succeeded, false otherwise. Builds without the
// TFLite library never create an embedder and always report false.
void PassageEmbeddingsService::LoadModels(
    mojom::PassageEmbeddingsLoadModelsParamsPtr model_params,
    mojom::PassageEmbedderParamsPtr embedder_params,
    mojo::PendingReceiver<mojom::PassageEmbedder> receiver,
    LoadModelsCallback callback) {
#if BUILDFLAG(BUILD_WITH_TFLITE_LIB)
  // Whatever embedder was held before is replaced. The disconnect callback
  // binds `this` unretained; the embedder receiving it is owned by
  // `embedder_`, a member of this object.
  embedder_ = std::make_unique<PassageEmbedder>(
      std::move(receiver), std::move(embedder_params),
      base::BindOnce(&PassageEmbeddingsService::OnEmbedderDisconnect,
                     base::Unretained(this)));

  // A zero input window size is rejected up front; the short-circuit means the
  // embedder is not asked to load any model files in that case.
  const bool loaded =
      model_params->input_window_size != 0 &&
      embedder_->LoadModels(std::move(model_params->embeddings_model),
                            std::move(model_params->sp_model),
                            model_params->input_window_size);

  // On failure, discard the embedder before reporting the result.
  if (!loaded) {
    embedder_.reset();
  }
  std::move(callback).Run(loaded);
#else
  std::move(callback).Run(false);
#endif
}
} // namespace passage_embeddings