blob: b2baf5cd54a21c94e356924ee729ba8ab3ba311f [file] [log] [blame]
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/passage_embeddings/page_embeddings_service.h"
#include <algorithm>
#include <numeric>
#include <set>
#include <utility>
#include "content/public/browser/page.h"
#include "content/public/browser/web_contents.h"
namespace passage_embeddings {
namespace {
passage_embeddings::PassagePriority ConvertToPassagePriority(
PageEmbeddingsService::Priority priority) {
switch (priority) {
case PageEmbeddingsService::kUserBlocking:
return passage_embeddings::kUserInitiated;
case PageEmbeddingsService::kUrgent:
return passage_embeddings::kUrgent;
case PageEmbeddingsService::kDefault:
return passage_embeddings::kPassive;
case PageEmbeddingsService::kBackground:
return passage_embeddings::kLatent;
}
}
} // namespace
class PageEmbeddingsService::WebContentsEventsObserver
: public content::WebContentsObserver {
public:
WebContentsEventsObserver(content::WebContents* web_contents,
PageEmbeddingsService* page_embeddings_service)
: WebContentsObserver(web_contents),
page_embeddings_service_(page_embeddings_service) {}
~WebContentsEventsObserver() override = default;
void OnVisibilityChanged(content::Visibility visibility) override {
if (visibility == content::Visibility::HIDDEN) {
page_embeddings_service_->ComputeEmbeddings(web_contents());
}
}
void WebContentsDestroyed() override {
page_embeddings_service_->web_contents_state_.erase(web_contents());
}
bool IsWebContentsHidden() const {
return web_contents()->GetVisibility() == content::Visibility::HIDDEN;
}
private:
raw_ptr<PageEmbeddingsService> page_embeddings_service_;
};
struct PageEmbeddingsService::WebContentsState {
WebContentsState();
~WebContentsState();
std::unique_ptr<WebContentsEventsObserver> observer;
// pending_passages is non-empty from the time passages are produced via
// candidates_generator_ to the time that embeddings are requested.
std::vector<std::string> pending_passages;
// The currently active task for computing embeddings. Non-empty while the
// embedding computation is pending.
std::optional<Embedder::TaskId> active_task;
// passage_embeddings is empty until embeddings are received.
std::vector<PassageEmbedding> passage_embeddings;
};
PageEmbeddingsService::ScopedPriority::ScopedPriority(
PageEmbeddingsService* service,
Observer* observer,
Priority priority)
: service_(service), observer_(observer) {
// Only one scoped priority per observer is supported.
DCHECK_EQ(0u, service_->temporary_priority_.count(observer));
// We only support raising the priority.
DCHECK_LT(priority, observer->GetDefaultPriority());
service_->temporary_priority_[observer] = priority;
if (priority < service_->current_priority_) {
service_->current_priority_ = priority;
service_->UpdateTaskPriorities(service_->current_priority_);
}
}
PageEmbeddingsService::ScopedPriority::~ScopedPriority() {
if (!service_) {
// The object has been moved-from.
return;
}
service_->temporary_priority_.erase(observer_);
Priority next_priority =
GetActivePriority(service_->observers_, service_->temporary_priority_);
if (next_priority != service_->current_priority_) {
service_->current_priority_ = next_priority;
service_->UpdateTaskPriorities(service_->current_priority_);
}
}
PageEmbeddingsService::ScopedPriority::ScopedPriority(ScopedPriority&& other) {
*this = std::move(other);
}
PageEmbeddingsService::ScopedPriority&
PageEmbeddingsService::ScopedPriority::operator=(ScopedPriority&& other) {
service_ = other.service_;
observer_ = other.observer_;
other.service_ = nullptr;
other.observer_ = nullptr;
return *this;
}
PageEmbeddingsService::PageEmbeddingsService(
EmbeddingCandidatesGenerator candidates_generator,
page_content_annotations::PageContentExtractionService*
page_content_extraction_service,
passage_embeddings::Embedder* embedder)
: candidates_generator_(candidates_generator), embedder_(embedder) {}
PageEmbeddingsService::~PageEmbeddingsService() = default;
void PageEmbeddingsService::AddObserver(Observer* observer) {
observers_.AddObserver(observer);
UpdateTaskPriorities(GetActivePriority(observers_, temporary_priority_));
}
void PageEmbeddingsService::RemoveObserver(Observer* observer) {
observers_.RemoveObserver(observer);
UpdateTaskPriorities(GetActivePriority(observers_, temporary_priority_));
}
PageEmbeddingsService::ScopedPriority PageEmbeddingsService::RaisePriority(
Observer* observer,
Priority priority) {
return ScopedPriority(this, observer, priority);
}
void PageEmbeddingsService::ProcessAllEmbeddings() {
// For the computation of embeddings for all visible tabs, which are otherwise
// only lazily computed on being hidden.
for (const auto& [web_contents, web_contents_state] : web_contents_state_) {
if (!web_contents_state.observer->IsWebContentsHidden() &&
!web_contents_state.pending_passages.empty()) {
ComputeEmbeddings(web_contents);
}
}
}
std::vector<PassageEmbedding> PageEmbeddingsService::GetEmbeddings(
content::WebContents* web_content) const {
const auto loc = web_contents_state_.find(web_content);
if (loc == web_contents_state_.end()) {
return {};
}
return loc->second.passage_embeddings;
}
void PageEmbeddingsService::OnPageContentExtracted(
content::Page& page,
const optimization_guide::proto::AnnotatedPageContent& page_content) {
auto* const web_contents =
content::WebContents::FromRenderFrameHost(&page.GetMainDocument());
auto loc = web_contents_state_.find(web_contents);
if (loc == web_contents_state_.end()) {
web_contents_state_[web_contents].observer =
std::make_unique<WebContentsEventsObserver>(web_contents, this);
}
web_contents_state_[web_contents].pending_passages =
candidates_generator_.Run(page_content, 10);
if (web_contents_state_[web_contents].observer->IsWebContentsHidden()) {
// The WebContents may have transitioned from visible to hidden by the time
// we received the passages, so compute embeddings.
ComputeEmbeddings(web_contents);
}
}
void PageEmbeddingsService::ComputeEmbeddings(
content::WebContents* web_contents) {
WebContentsState& state = web_contents_state_[web_contents];
if (state.active_task.has_value()) {
embedder_->TryCancel(*state.active_task);
state.active_task.reset();
}
// Ensure that state.pending_passages is cleared before invoking
// ComputePassagesEmbeddings().
std::vector<std::string> pending_passages;
pending_passages.swap(state.pending_passages);
state.active_task = embedder_->ComputePassagesEmbeddings(
ConvertToPassagePriority(current_priority_), std::move(pending_passages),
base::BindOnce(&PageEmbeddingsService::OnEmbeddingsComputed,
weak_ptr_factory_.GetWeakPtr(),
web_contents->GetWeakPtr()));
}
void PageEmbeddingsService::OnEmbeddingsComputed(
base::WeakPtr<content::WebContents> web_contents,
std::vector<std::string> passages,
std::vector<Embedding> embeddings,
Embedder::TaskId task_id,
ComputeEmbeddingsStatus status) {
if (!web_contents) {
// The web contents was destroyed while computing the embeddings.
return;
}
CHECK_EQ(passages.size(), embeddings.size());
std::vector<PassageEmbedding> passage_embeddings;
for (size_t i = 0; i < passages.size(); ++i) {
passage_embeddings.push_back(
{std::move(passages[i]), std::move(embeddings[i])});
}
const auto loc = web_contents_state_.find(web_contents.get());
DCHECK(loc != web_contents_state_.end());
// Ignore stale embeddings from previously cancelled tasks.
if (loc->second.active_task != task_id) {
return;
}
loc->second.active_task.reset();
if (status != passage_embeddings::ComputeEmbeddingsStatus::kSuccess) {
loc->second.passage_embeddings.clear();
return;
}
loc->second.passage_embeddings = std::move(passage_embeddings);
for (Observer& observer : observers_) {
observer.OnPageEmbeddingsAvailable(web_contents.get());
}
}
// static
PageEmbeddingsService::Priority PageEmbeddingsService::GetActivePriority(
const base::ObserverList<Observer>& observers,
const std::map<Observer*, Priority>& temporary_priority) {
const Priority highest_default_priority = std::transform_reduce(
observers.begin(), observers.end(), kDefault,
[](Priority p1, Priority p2) { return std::min(p1, p2); },
[](const Observer& observer) { return observer.GetDefaultPriority(); });
return std::transform_reduce(
temporary_priority.begin(), temporary_priority.end(),
highest_default_priority,
[](Priority p1, Priority p2) { return std::min(p1, p2); },
[](const std::map<Observer*, Priority>::value_type& pair) {
return pair.second;
});
}
void PageEmbeddingsService::UpdateTaskPriorities(Priority priority) {
if (priority == current_priority_) {
return;
}
current_priority_ = priority;
std::set<Embedder::TaskId> tasks;
for (const auto& [web_contents, web_contents_state] : web_contents_state_) {
if (web_contents_state.active_task.has_value()) {
tasks.insert(*web_contents_state.active_task);
}
}
if (!tasks.empty()) {
embedder_->ReprioritizeTasks(ConvertToPassagePriority(priority), tasks);
}
}
PageEmbeddingsService::WebContentsState::WebContentsState() = default;
PageEmbeddingsService::WebContentsState::~WebContentsState() = default;
} // namespace passage_embeddings