blob: c26a7ed778343da5452dc21bae56e7695b4760b9 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDINGS_SERVICE_CONTROLLER_H_
#define COMPONENTS_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDINGS_SERVICE_CONTROLLER_H_
#include <memory>
#include <vector>
#include "base/callback_list.h"
#include "base/observer_list.h"
#include "base/timer/elapsed_timer.h"
#include "base/types/optional_ref.h"
#include "components/optimization_guide/core/delivery/model_info.h"
#include "components/optimization_guide/proto/passage_embeddings_model_metadata.pb.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "services/passage_embeddings/public/mojom/passage_embeddings.mojom.h"
namespace passage_embeddings {
class PassageEmbeddingsServiceController : public EmbedderMetadataProvider {
public:
PassageEmbeddingsServiceController();
~PassageEmbeddingsServiceController() override;
// Updates the paths and the metadata needed for executing the passage
// embeddings model. The original paths and metadata will be erased regardless
// of the validity of the new model paths.
// Returns true and notifies the observers if the given paths are valid.
// Virtual for testing.
virtual bool MaybeUpdateModelInfo(
base::optional_ref<const optimization_guide::ModelInfo> model_info);
// Returns true if the embedder is currently running.
bool EmbedderRunning();
// Returns the embedder used to generate embeddings.
Embedder* GetEmbedder();
protected:
// EmbedderMetadataProvider:
void AddObserver(EmbedderMetadataObserver* observer) override;
void RemoveObserver(EmbedderMetadataObserver* observer) override;
// Computes embeddings for each entry in `passages`. Will invoke `callback`
// when done. If successful, it is guaranteed that `results` will have the
// same number of passages and embeddings and in the same order as
// `passages`. Otherwise `results` will have empty passages and embeddings.
using GetEmbeddingsResultCallback = base::OnceCallback<void(
std::vector<mojom::PassageEmbeddingsResultPtr> results,
ComputeEmbeddingsStatus status)>;
using GetEmbeddingsCallback =
base::RepeatingCallback<void(std::vector<std::string> passages,
PassagePriority priority,
GetEmbeddingsResultCallback callback)>;
void GetEmbeddings(std::vector<std::string> passages,
PassagePriority priority,
GetEmbeddingsResultCallback callback);
// Returns true if this service controller is ready for embeddings generation.
bool EmbedderReady();
// Returns the metadata about the embeddings model. This is only valid when
// EmbedderReady() returns true.
EmbedderMetadata GetEmbedderMetadata();
// Launches the passage embeddings service and binds `cpu_logger_` to the
// service process. Does nothing if the service is already launched.
virtual void MaybeLaunchService() = 0;
// Resets `service_remote_` and `cpu_logger_`. Called when the service remote
// is idle or disconnects.
virtual void ResetServiceRemote() = 0;
// Resets `embedder_remote_`. Called when the model info is updated, when
// models fail to load, or when the embedder remote is idle or disconnects.
void ResetEmbedderRemote();
mojo::Remote<mojom::PassageEmbeddingsService> service_remote_;
private:
// uint64_t is large enough to never overflow.
using RequestId = uint64_t;
RequestId next_request_id_ = 0;
// Called when the model files on disks are opened and ready to be sent to
// the service.
void LoadModelsToService(
mojo::PendingReceiver<mojom::PassageEmbedder> receiver,
base::ElapsedTimer service_launch_timer,
mojom::PassageEmbeddingsLoadModelsParamsPtr params);
// Called when an attempt to load models to service finishes.
void OnLoadModelsResult(base::ElapsedTimer service_launch_timer,
bool success);
// Called when an attempt to generate embeddings finishes.
void OnGotEmbeddings(RequestId request_id,
GetEmbeddingsResultCallback callback,
base::ElapsedTimer generate_embeddings_timer,
PassagePriority priority,
std::vector<mojom::PassageEmbeddingsResultPtr> results);
// Version of the embeddings model.
int64_t model_version_;
// Metadata of the embeddings model.
std::optional<optimization_guide::proto::PassageEmbeddingsModelMetadata>
model_metadata_;
base::FilePath embeddings_model_path_;
base::FilePath sp_model_path_;
mojo::Remote<mojom::PassageEmbedder> embedder_remote_;
// Pending requests to generate embeddings.
std::vector<RequestId> pending_requests_;
// Notifies embedders that model metadata updated.
base::ObserverList<EmbedderMetadataObserver> observer_list_;
// This holds the main scheduler that receives requests from multiple clients,
// prioritizes all the jobs, and ultimately submits batches of work via
// `GetEmbeddings` when the time is right.
const std::unique_ptr<Embedder> embedder_;
// Used to generate weak pointers to self.
base::WeakPtrFactory<PassageEmbeddingsServiceController> weak_ptr_factory_{
this};
};
} // namespace passage_embeddings
#endif // COMPONENTS_PASSAGE_EMBEDDINGS_PASSAGE_EMBEDDINGS_SERVICE_CONTROLLER_H_