// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "components/passage_embeddings/internal/scheduling_embedder.h"

#include <algorithm>
#include <atomic>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "base/check_op.h"
#include "base/logging.h"
#include "base/metrics/histogram_functions.h"
#include "base/strings/stringprintf.h"
#include "base/time/time.h"
#include "components/passage_embeddings/passage_embeddings_types.h"

namespace passage_embeddings {

namespace {
using ScenarioScope = performance_scenarios::ScenarioScope;
using LoadingScenario = performance_scenarios::LoadingScenario;
using InputScenario = performance_scenarios::InputScenario;
std::string PassagePriorityToString(PassagePriority priority) {
switch (priority) {
case PassagePriority::kUserInitiated:
return "UserInitiated";
case PassagePriority::kUrgent:
return "Urgent";
case PassagePriority::kPassive:
return "Passive";
case PassagePriority::kLatent:
return "Latent";
}
}
void RecordDurationHistograms(PassagePriority priority,
base::TimeDelta duration) {
base::UmaHistogramTimes("History.Embeddings.ScheduledJobDuration", duration);
base::UmaHistogramTimes(
base::StringPrintf("History.Embeddings.ScheduledJobDuration.%s",
PassagePriorityToString(priority)),
duration);
}
void RecordStatusHistograms(PassagePriority priority,
ComputeEmbeddingsStatus status) {
base::UmaHistogramEnumeration("History.Embeddings.ScheduledJobStatus",
status);
base::UmaHistogramEnumeration(
base::StringPrintf("History.Embeddings.ScheduledJobStatus.%s",
PassagePriorityToString(priority)),
status);
}
} // namespace
SchedulingEmbedder::Job::Job(PassagePriority priority,
TaskId task_id,
std::vector<std::string> passages,
ComputePassagesEmbeddingsCallback callback)
: priority(priority),
task_id(task_id),
passages(std::move(passages)),
callback(std::move(callback)) {}
SchedulingEmbedder::Job::~Job() = default;
SchedulingEmbedder::Job::Job(Job&&) = default;
SchedulingEmbedder::Job& SchedulingEmbedder::Job::operator=(Job&&) = default;
////////////////////////////////////////////////////////////////////////////////
SchedulingEmbedder::SchedulingEmbedder(
EmbedderMetadataProvider* embedder_metadata_provider,
GetEmbeddingsCallback get_embeddings_callback,
size_t max_jobs,
size_t max_batch_size,
bool use_performance_scenario)
: get_embeddings_callback_(get_embeddings_callback),
max_jobs_(max_jobs),
max_batch_size_(max_batch_size),
use_performance_scenario_(use_performance_scenario) {
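  // Observe embedder metadata updates; queued jobs can only start once the
  // underlying embedder reports valid metadata.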
if (embedder_metadata_provider) {
embedder_metadata_observation_.Observe(embedder_metadata_provider);
}
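  // Observe global performance scenario changes so that deferred work can
  // resume when the scenario becomes suitable.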
if (use_performance_scenario_) {
performance_scenario_observation_.Observe(
performance_scenarios::PerformanceScenarioObserverList::GetForScope(
ScenarioScope::kGlobal)
.get());
}
}
SchedulingEmbedder::~SchedulingEmbedder() = default;
SchedulingEmbedder::TaskId SchedulingEmbedder::ComputePassagesEmbeddings(
PassagePriority priority,
std::vector<std::string> passages,
ComputePassagesEmbeddingsCallback callback) {
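  // Record the queue depth and the number of passages still awaiting
  // embedding at the time this new request arrives.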
base::UmaHistogramCounts1000("History.Embeddings.ScheduledJobCount",
jobs_.size());
base::UmaHistogramCounts1000(
"History.Embeddings.ScheduledPassageCount",
std::accumulate(
jobs_.begin(), jobs_.end(), 0u, [](size_t sum, const Job& job) {
return sum + job.passages.size() - job.embeddings.size();
}));
const TaskId task_id = next_task_id_++;
  // Zero-size jobs are expected and can be called back immediately instead
  // of waiting in the queue for nothing.
if (passages.empty()) {
std::move(callback).Run(
/*passages=*/{}, /*embeddings=*/{}, task_id,
ComputeEmbeddingsStatus::kSuccess);
return task_id;
}
  // Limit the number of jobs accepted to avoid high memory use when the
  // queue takes a long time to drain. Cancel jobs that have not started yet,
  // starting from the back of the queue, to make room.
while (jobs_.size() >= max_jobs_ && !jobs_.back().in_progress) {
FinishJob(std::move(jobs_.back()), ComputeEmbeddingsStatus::kCanceled);
jobs_.pop_back();
}
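  // Enqueue the new job; it will be ordered by priority the next time work
  // is submitted to the embedder.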
jobs_.emplace_back(priority, task_id, std::move(passages),
std::move(callback));
SubmitWorkToEmbedder();
return task_id;
}
void SchedulingEmbedder::SubmitWorkToEmbedder() {
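  // Only one batch is in flight at a time. This is called whenever state
  // changes in a way that might allow work to start: a new job arrives,
  // embedder metadata becomes valid, the performance scenario changes, or a
  // batch completes.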
if (!embedder_metadata_.IsValid()) {
// Underlying embedder not ready yet. Wait for it.
VLOG(5) << "SubmitWorkToEmbedder: embedder not ready";
return;
}
if (work_submitted_) {
// Waiting for work in progress to complete.
VLOG(5) << "SubmitWorkToEmbedder: work already in progress";
return;
}
if (jobs_.empty()) {
// No jobs to start.
VLOG(5) << "SubmitWorkToEmbedder: no jobs";
return;
}
if (use_performance_scenario_ && !IsPerformanceScenarioReady()) {
// Waiting for a suitable performance scenario.
VLOG(5) << "SubmitWorkToEmbedder: unsuitable scenario";
return;
}
// Put higher priority jobs at the front. This may suspend partially
// completed jobs of lower priority by pushing them toward the back.
std::stable_sort(jobs_.begin(), jobs_.end(), [](const Job& a, const Job& b) {
return a.priority < b.priority;
});
  // Submit a batch of passages taken from jobs at the front of the queue.
  // A batch only contains passages from jobs of a single priority, even if
  // that leaves it below the maximum batch size.
PassagePriority priority = jobs_.front().priority;
std::vector<std::string> passages;
size_t job_index = 0;
while (passages.size() < max_batch_size_ && job_index < jobs_.size() &&
jobs_.at(job_index).priority == priority) {
Job& job = jobs_.at(job_index);
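    // Mark the job as started so it can no longer be canceled or evicted.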
job.in_progress = true;
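    // Take as many of this job's not-yet-embedded passages as fit in the
    // remaining batch capacity.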
size_t accept = std::min(max_batch_size_ - passages.size(),
job.passages.size() - job.embeddings.size());
VLOG(3) << "Batching range [" << job.embeddings.size() << ','
<< job.embeddings.size() + accept << ") of " << job.passages.size()
<< " passages from job " << job_index << '/' << jobs_.size();
for (size_t i = job.embeddings.size();
i < job.passages.size() && accept > 0; i++, accept--) {
passages.push_back(job.passages[i]);
}
job_index++;
}
work_submitted_ = true;
get_embeddings_callback_.Run(
std::move(passages), priority,
base::BindOnce(&SchedulingEmbedder::OnEmbeddingsComputed,
weak_ptr_factory_.GetWeakPtr()));
}
bool SchedulingEmbedder::IsPerformanceScenarioReady() {
if (!jobs_.empty() &&
(jobs_.front().priority == PassagePriority::kUserInitiated ||
jobs_.front().priority == PassagePriority::kUrgent)) {
// Do not block on performance scenario if user initiated a query or it's
// urgent.
return true;
}
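  // Otherwise, only proceed when at most background pages are loading and no
  // user input is being handled.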
LoadingScenario loading_scenario =
performance_scenarios::GetLoadingScenario(ScenarioScope::kGlobal)
->load(std::memory_order_relaxed);
InputScenario input_scenario =
performance_scenarios::GetInputScenario(ScenarioScope::kGlobal)
->load(std::memory_order_relaxed);
return (loading_scenario == LoadingScenario::kNoPageLoading ||
loading_scenario == LoadingScenario::kBackgroundPageLoading) &&
input_scenario == InputScenario::kNoInput;
}
void SchedulingEmbedder::ReprioritizeTasks(PassagePriority priority,
const std::set<TaskId>& tasks) {
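  // Apply the new priority to every queued job whose task id is in `tasks`.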
for (Job& job : jobs_) {
const auto loc = tasks.find(job.task_id);
if (loc != tasks.end()) {
job.priority = priority;
}
}
// Note: the jobs will be reordered to account for the new priorities on the
// next call to SubmitWorkToEmbedder().
}
bool SchedulingEmbedder::TryCancel(TaskId task_id) {
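  // Only jobs that have not yet been submitted to the embedder can be
  // canceled; a job that is in progress runs to completion.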
for (auto itr = jobs_.begin(); itr < jobs_.end(); itr++) {
Job& job = *itr;
if (task_id == job.task_id && !job.in_progress) {
VLOG(2) << "Aborted embedding work for " << job.passages.size()
<< " passages starting with `"
<< (job.passages.empty() ? "" : job.passages[0]) << "`";
std::move(job.callback)
.Run(std::move(job.passages), {}, job.task_id,
ComputeEmbeddingsStatus::kCanceled);
RecordStatusHistograms(job.priority, ComputeEmbeddingsStatus::kCanceled);
jobs_.erase(itr);
return true;
}
}
return false;
}
void SchedulingEmbedder::EmbedderMetadataUpdated(EmbedderMetadata metadata) {
VLOG(4) << "SchedulingEmbedder received metadata with version: "
<< metadata.model_version;
embedder_metadata_ = metadata;
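  // Jobs may have queued up while waiting for the embedder to become ready;
  // try to start them now.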
SubmitWorkToEmbedder();
}
void SchedulingEmbedder::OnLoadingScenarioChanged(
ScenarioScope scope,
LoadingScenario old_scenario,
LoadingScenario new_scenario) {
VLOG(5) << "SchedulingEmbedder using new loading scenario: "
<< static_cast<int>(new_scenario);
SubmitWorkToEmbedder();
}
void SchedulingEmbedder::OnInputScenarioChanged(ScenarioScope scope,
InputScenario old_scenario,
InputScenario new_scenario) {
VLOG(5) << "SchedulingEmbedder using new input scenario: "
<< static_cast<int>(new_scenario);
SubmitWorkToEmbedder();
}
void SchedulingEmbedder::OnEmbeddingsComputed(
std::vector<mojom::PassageEmbeddingsResultPtr> results,
ComputeEmbeddingsStatus status) {
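  // Convert the raw results into normalized embeddings before distributing
  // them to the waiting jobs.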
std::vector<Embedding> embeddings;
for (auto& result : results) {
embeddings.emplace_back(result->embeddings);
embeddings.back().Normalize();
}
VLOG(3) << embeddings.size() << " embeddings computed with status "
<< static_cast<int>(status);
if (embeddings.empty()) {
FinishJob(std::move(jobs_.front()), status);
jobs_.pop_front();
    // Continue on so that any remaining jobs can be resumed. This upholds
    // the 1:1 callback requirement and gives jobs another chance to succeed
    // even when the primary embedder fails a batch.
    // Note, we don't fail all jobs here, only the first. Failing no jobs at
    // all could result in retry loops that would require special handling to
    // keep the 1:1 callback guarantee. And failing more than the first is
    // unnecessary since progress can be made while giving the later jobs
    // another chance to succeed. Note, if a failure is caused by a passage
    // from a later job in a batch, failing the first job may not be the
    // optimal recovery strategy, but the underlying embedder is not expected
    // to fail at all.
}
  // Distribute the computed embeddings to the jobs and pop jobs as they are
  // filled. The !jobs_.empty() check ensures we don't overrun the available
  // jobs if the service were to maliciously send too many embeddings.
size_t read_index = 0;
while (read_index < embeddings.size() && !jobs_.empty()) {
Job& job = jobs_.front();
while (job.embeddings.size() < job.passages.size() &&
read_index < embeddings.size()) {
job.embeddings.push_back(std::move(embeddings[read_index]));
read_index++;
}
if (job.embeddings.size() == job.passages.size()) {
FinishJob(std::move(job), status);
jobs_.pop_front();
}
}
  // Clear the in-flight flag and submit more work. Note, depending on the
  // embedder, the next batch could call back into this method either
  // asynchronously or synchronously.
work_submitted_ = false;
SubmitWorkToEmbedder();
}
// static
void SchedulingEmbedder::FinishJob(Job job, ComputeEmbeddingsStatus status) {
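  // Jobs are never enqueued with an empty passage list (empty requests are
  // answered immediately in ComputePassagesEmbeddings()), so passages[0] is
  // safe to log here.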
VLOG(2) << "Finished embedding work with status " << static_cast<int>(status)
<< " for " << job.passages.size() << " passages starting with `"
<< job.passages[0] << "`";
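  // A job finished with an error or canceled may have only some of its
  // embeddings computed; report all-or-nothing to the caller.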
if (job.passages.size() != job.embeddings.size()) {
job.embeddings.clear();
}
std::move(job.callback)
.Run(std::move(job.passages), std::move(job.embeddings), job.task_id,
status);
if (status == ComputeEmbeddingsStatus::kSuccess) {
RecordDurationHistograms(job.priority, job.timer.Elapsed());
}
RecordStatusHistograms(job.priority, status);
}
} // namespace passage_embeddings