// Source-viewer residue (not part of the original header):
// blob: 4804f768704168645b217ab2e3ac583e6486498d [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_PASSAGE_EMBEDDINGS_INTERNAL_SCHEDULING_EMBEDDER_H_
#define COMPONENTS_PASSAGE_EMBEDDINGS_INTERNAL_SCHEDULING_EMBEDDER_H_
#include <stddef.h>

#include <deque>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <vector>

#include "base/functional/callback.h"
#include "base/memory/weak_ptr.h"
#include "base/scoped_observation.h"
#include "base/time/time.h"
#include "base/timer/elapsed_timer.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "components/performance_manager/scenario_api/performance_scenario_observer.h"
#include "components/performance_manager/scenario_api/performance_scenarios.h"
#include "services/passage_embeddings/public/mojom/passage_embeddings.mojom.h"
namespace passage_embeddings {
// The SchedulingEmbedder adds scheduling control with batching and priorities
// so that high priority queries can be computed as soon as possible. Scheduling
// is also needed to avoid clogging the pipes for a slow remote embedder. Even
// single pages can take a while, and when the model changes, all existing
// passages need their embeddings recomputed, which can take a very long time
// and should be done at lower priority.
//
// Besides implementing `Embedder`, this class is an `EmbedderMetadataObserver`
// and a `PerformanceScenarioObserver`; the corresponding overrides are
// declared in the private section below.
class SchedulingEmbedder
: public Embedder,
public EmbedderMetadataObserver,
public performance_scenarios::PerformanceScenarioObserver {
public:
// Receives the outcome of one embeddings request: the computed `results`
// together with the overall `status`. Being a OnceCallback, it can be run at
// most once.
using GetEmbeddingsResultCallback = base::OnceCallback<void(
std::vector<mojom::PassageEmbeddingsResultPtr> results,
ComputeEmbeddingsStatus status)>;
// Performs the actual embeddings computation for `passages` at `priority`
// and reports the outcome through `callback`. Repeating, so the same
// callback can be used for every work submission.
using GetEmbeddingsCallback =
base::RepeatingCallback<void(std::vector<std::string> passages,
PassagePriority priority,
GetEmbeddingsResultCallback callback)>;
// `embedder_metadata_provider`: source of embedder metadata updates.
//   NOTE(review): presumably observed via `embedder_metadata_observation_`;
//   lifetime requirements are not visible here - confirm in the .cc file.
// `get_embeddings_callback`: does the actual embeddings computation.
// `max_jobs`: cap on the number of jobs held at once (see `max_jobs_`).
// `scheduled_max_batch_size`: presumably the initial value of
//   `max_batch_size_` - confirm in the .cc file.
// `use_performance_scenario`: whether embedding work submission is blocked
//   on the performance scenario (see `use_performance_scenario_`).
SchedulingEmbedder(EmbedderMetadataProvider* embedder_metadata_provider,
GetEmbeddingsCallback get_embeddings_callback,
size_t max_jobs,
size_t scheduled_max_batch_size,
bool use_performance_scenario);
~SchedulingEmbedder() override;
// Embedder:
// Requests embeddings for `passages` at `priority`; `callback` receives the
// outcome. Returns the TaskId that identifies the request in later calls to
// `ReprioritizeTasks` and `TryCancel`.
TaskId ComputePassagesEmbeddings(
PassagePriority priority,
std::vector<std::string> passages,
ComputePassagesEmbeddingsCallback callback) override;
// Applies `priority` to the tasks identified by `tasks`.
void ReprioritizeTasks(PassagePriority priority,
const std::set<TaskId>& tasks) override;
// Attempts to cancel the task identified by `task_id`.
// NOTE(review): the exact meaning of the return value (presumably true when
// the task was canceled) is defined by the `Embedder` interface - confirm.
bool TryCancel(TaskId task_id) override;
private:
// A job consists of multiple passages, and each passage must have its
// embedding computed. When all are finished, the job is done and its
// callback will be invoked. Multiple jobs may be batched together when
// submitting work to the `embedder_remote_proxy`, and jobs can also be broken
// down so that partial progress is made across multiple work submissions.
//
// NOTE(review): no `embedder_remote_proxy` is declared in this header; work
// appears to be submitted through `get_embeddings_callback_` instead - the
// name above may be stale.
//
// Jobs are move-only: copying is deleted and moves are user-declared.
struct Job {
Job(PassagePriority priority,
TaskId task_id,
std::vector<std::string> passages,
ComputePassagesEmbeddingsCallback callback);
~Job();
Job(const Job&) = delete;
Job& operator=(const Job&) = delete;
Job(Job&&);
Job& operator=(Job&&);
// Data for the job is saved from calls to `ComputePassagesEmbeddings`.
PassagePriority priority;
TaskId task_id;
std::vector<std::string> passages;
ComputePassagesEmbeddingsCallback callback;
// Starts out false for every newly constructed job.
// NOTE(review): presumably set while some of this job's passages are part
// of the batch currently submitted to the embedder - confirm in the .cc.
bool in_progress = false;
// Completed embeddings; may be partial.
std::vector<Embedding> embeddings;
// Measures total job duration, from creation to completion.
base::ElapsedTimer timer;
};
// EmbedderMetadataObserver:
void EmbedderMetadataUpdated(EmbedderMetadata metadata) override;
// PerformanceScenarioObserver:
void OnLoadingScenarioChanged(
performance_scenarios::ScenarioScope scope,
performance_scenarios::LoadingScenario old_scenario,
performance_scenarios::LoadingScenario new_scenario) override;
void OnInputScenarioChanged(
performance_scenarios::ScenarioScope scope,
performance_scenarios::InputScenario old_scenario,
performance_scenarios::InputScenario new_scenario) override;
// Invoked after the embedding for the current job has been computed.
// Continues processing next job if one is pending.
// The signature matches `GetEmbeddingsResultCallback`.
void OnEmbeddingsComputed(
std::vector<mojom::PassageEmbeddingsResultPtr> results,
ComputeEmbeddingsStatus status);
// Stable-sort jobs by priority and submit a batch of work to embedder.
// This will only submit new work if the embedder is not already working.
void SubmitWorkToEmbedder();
// Returns true if currently in a work ready performance scenario state.
bool IsPerformanceScenarioReady();
// Call the callback with status, etc. and record relevant histograms.
// Static, and takes `job` by value: the job is handed over to this function
// rather than looked up in `jobs_`.
static void FinishJob(Job job, ComputeEmbeddingsStatus status);
// When this is non-empty, the embedder is working and its results will be
// applied from front to back when `OnEmbeddingsComputed` is called. Not all
// of these jobs are necessarily being worked on by the embedder. It may
// contain a mix of in-progress, partially completed, and not-yet-started
// jobs. In-progress jobs are ordered first, and in the same order as
// submitted to the embedder. Partially completed jobs may follow,
// still in the order they were last submitted to the embedder.
// Not-yet-started jobs are ordered last. All jobs will be re-ordered by
// priority before submitting the next batch to the embedder.
std::deque<Job> jobs_;
// ID to assign to the next Job. The first ID handed out is 1.
TaskId next_task_id_ = 1;
// Whether the embedder is currently working on some passages. Note, this
// is not the same concept as having a job in progress since multiple
// embedder work submissions may be required to complete a job.
bool work_submitted_ = false;
// The callback that does the actual embeddings computations.
// May be slow; await results before sending the next request.
GetEmbeddingsCallback get_embeddings_callback_;
// Metadata about the embedder; Set when valid metadata is received from
// `embedder_metadata_provider`.
// NOTE(review): the `{0, 0}` initializer is presumably a "not yet valid"
// placeholder; the meaning of the two fields is not visible in this header.
EmbedderMetadata embedder_metadata_{0, 0};
// The maximum number of jobs to hold at once. Exceeding the cap
// will cause job failures on last pending jobs to avoid very high memory use.
// When the limit is reached, the last pending job is canceled instead of
// failing to accept the new job so that queries can still be accepted even
// if the queue is full of lower priority jobs awaiting performance scenario.
size_t max_jobs_;
// The maximum number of embeddings to submit to the primary embedder.
size_t max_batch_size_;
// Whether to block embedding work submission on performance scenario.
bool use_performance_scenario_;
// Scoped registrations of `this` as an observer. Both are declared after the
// state above, so they are destroyed before that state (members are
// destroyed in reverse declaration order).
base::ScopedObservation<
performance_scenarios::PerformanceScenarioObserverList,
performance_scenarios::PerformanceScenarioObserver>
performance_scenario_observation_{this};
base::ScopedObservation<EmbedderMetadataProvider, EmbedderMetadataObserver>
embedder_metadata_observation_{this};
// Keep this as the last member: it is then destroyed first, which
// invalidates outstanding weak pointers before any other member is torn
// down.
base::WeakPtrFactory<SchedulingEmbedder> weak_ptr_factory_{this};
};
} // namespace passage_embeddings
#endif // COMPONENTS_PASSAGE_EMBEDDINGS_INTERNAL_SCHEDULING_EMBEDDER_H_