blob: 1fb06a49fed936a25ec8c2c0632131ef58fc81aa [file] [log] [blame]
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/passage_embeddings/passage_embeddings_coordinator.h"
#include <algorithm>
#include <string>
#include <string_view>
#include <vector>
#include "base/functional/bind.h"
#include "base/logging.h"
#include "base/metrics/histogram_functions.h"
#include "base/strings/strcat.h"
#include "chrome/browser/page_content_annotations/page_content_extraction_service.h"
#include "chrome/browser/passage_embeddings/chrome_passage_embeddings_service_controller.h"
#include "components/passage_embeddings/passage_embeddings_features.h"
#include "content/public/browser/render_frame_host.h"
#include "content/public/browser/web_contents.h"
namespace passage_embeddings {
namespace {
// Builds the passage embedder for `web_contents`, wired to `delegate`.
// The concrete implementation is selected by the
// `kUseBackgroundPassageEmbedder` feature param: the background variant when
// the param is set, the immediate variant otherwise.
std::unique_ptr<WebContentsPassageEmbedder> CreateWebContentsPassageEmbedder(
    content::WebContents* web_contents,
    WebContentsPassageEmbedder::Delegate& delegate) {
  const bool embed_in_background = kUseBackgroundPassageEmbedder.Get();
  if (!embed_in_background) {
    return std::make_unique<WebContentsImmediatePassageEmbedder>(web_contents,
                                                                 delegate);
  }
  return std::make_unique<WebContentsBackgroundPassageEmbedder>(web_contents,
                                                                delegate);
}
// Appends to `text` the single string carried directly by `node`, if any:
// the table name for table nodes, the text content for text nodes and the
// image caption for image nodes. Nodes without content attributes, nodes of
// any other attribute type, and empty strings contribute nothing. Children of
// `node` are not visited here.
void CollectTextForContentNode(
    const optimization_guide::proto::ContentNode& node,
    std::vector<std::string>& text) {
  using AttributeType = optimization_guide::proto::ContentAttributeType;

  if (!node.has_content_attributes()) {
    return;
  }

  // Shared by every case below: only non-empty strings are recorded.
  const auto keep_if_not_empty = [&text](const std::string& candidate) {
    if (!candidate.empty()) {
      text.push_back(candidate);
    }
  };

  const auto& attrs = node.content_attributes();
  switch (attrs.attribute_type()) {
    case AttributeType::CONTENT_ATTRIBUTE_TABLE:
      keep_if_not_empty(attrs.table_data().table_name());
      return;
    case AttributeType::CONTENT_ATTRIBUTE_TEXT:
      keep_if_not_empty(attrs.text_data().text_content());
      return;
    case AttributeType::CONTENT_ATTRIBUTE_IMAGE:
      keep_if_not_empty(attrs.image_data().image_caption());
      return;
    default:
      return;
  }
}
// Walks the content tree rooted at `node` in pre-order (a node first, then
// its children from first to last) and appends each visited node's own text
// to `text` via CollectTextForContentNode().
void CollectTextForContentNodesRecursively(
    const optimization_guide::proto::ContentNode& node,
    std::vector<std::string>& text) {
  // Explicit stack of nodes still to visit; the back element is visited next.
  std::vector<const optimization_guide::proto::ContentNode*> pending;
  pending.push_back(&node);

  while (!pending.empty()) {
    const optimization_guide::proto::ContentNode* current = pending.back();
    pending.pop_back();

    CollectTextForContentNode(*current, text);

    // Push children last-to-first so the first child is popped (and therefore
    // visited) first, which keeps the pre-order traversal order.
    const auto& children = current->children_nodes();
    for (auto it = children.rbegin(); it != children.rend(); ++it) {
      pending.push_back(&*it);
    }
  }
}
// Returns the number of words in `s`, where a word is a maximal run of
// characters other than the space character ' '. Only ' ' acts as a
// separator; tabs, newlines and other whitespace count as word characters.
int CountWords(std::string_view s) {
  int words = 0;
  // True at the start of the string and right after a space, i.e. whenever
  // the next non-space character would begin a new word.
  bool at_word_boundary = true;
  for (const char c : s) {
    const bool is_space = (c == ' ');
    if (!is_space && at_word_boundary) {
      ++words;
    }
    at_word_boundary = is_space;
  }
  return words;
}
// Provides a translation of annotated page content (APC) to passages for the
// purposes of measuring embeddings performance. This translation is extremely
// simple and not intended to be a full fidelity representation or to be reused
// outside of this limited experiment.
//
// The text items collected from the tree under
// `annotated_page_content.root_node()` are packed, in order, into passages
// whose items are joined by single spaces:
//  - Once the current passage holds at least the max-words-per-aggregate
//    param, a new passage is started, unless the page already has
//    `kMaxPassagesPerPage` passages, in which case packing stops.
//  - An item is appended when the current passage is still below
//    `kMinWordsPerPassage` words, or when appending keeps the passage within
//    the aggregate cap. An item satisfying neither is skipped; it is not
//    carried over to a later passage.
//
// Returns the passages, or an empty vector when the page has no text.
std::vector<std::string> CreatePassagesFromAnnotatedPageContent(
    const optimization_guide::proto::AnnotatedPageContent&
        annotated_page_content) {
  std::vector<std::string> text;
  CollectTextForContentNodesRecursively(annotated_page_content.root_node(),
                                        text);
  if (text.empty()) {
    return {};
  }

  // Appends `str_to_append` to `str`, inserting one space between them unless
  // `str` is empty or already ends with a space.
  const auto append_with_whitespace_separator =
      [](std::string& str, std::string_view str_to_append) {
        if (str_to_append.empty()) {
          return;
        }
        if (str.empty() || str.back() == ' ') {
          str.append(str_to_append);
          return;
        }
        base::StrAppend(&str, {" ", str_to_append});
      };

  // The aggregate cap must come from its own param. It was previously read
  // from kMinWordsPerPassage, which made the cap equal to the minimum and
  // closed every passage as soon as it reached the minimum word count.
  // NOTE(review): kMaxWordsPerAggregatePassage is assumed to be declared next
  // to kMinWordsPerPassage / kMaxPassagesPerPage - confirm the param name.
  const int max_words_per_aggregate_passage =
      kMaxWordsPerAggregatePassage.Get();
  const int min_words_per_passage = kMinWordsPerPassage.Get();
  const int max_passages_per_page = kMaxPassagesPerPage.Get();

  std::vector<std::string> passages;
  passages.emplace_back();
  int current_passage_words = 0;
  for (const std::string& item : text) {
    const int item_words = CountWords(item);

    // The current passage is full: move on to a fresh one, or stop when the
    // per-page passage budget is exhausted.
    if (current_passage_words >= max_words_per_aggregate_passage) {
      if (passages.size() >= static_cast<size_t>(max_passages_per_page)) {
        break;
      }
      passages.emplace_back();
      current_passage_words = 0;
    }

    // Always grow a passage that is still below the minimum; beyond that,
    // only accept items that keep the passage within the aggregate cap.
    const bool should_append =
        current_passage_words < min_words_per_passage ||
        current_passage_words + item_words <= max_words_per_aggregate_passage;
    if (should_append) {
      append_with_whitespace_separator(passages.back(), item);
      current_passage_words += item_words;
    }
  }

  // `passages` always holds at least the initial entry here; drop the trailing
  // passage if nothing was ever appended to it.
  if (passages.back().empty()) {
    passages.pop_back();
  }
  return passages;
}
} // namespace
// Constructs the coordinator: binds `omnibox_focus_changed_listener_` to
// OnOmniboxFocusChanged() and starts observing
// `page_content_extraction_service` through
// `page_content_extraction_observation_`.
// NOTE(review): the callback captures `this` with base::Unretained, so it
// presumably relies on the listener member not outliving this object -
// confirm against the listener's declaration in the header.
PassageEmbeddingsCoordinator::PassageEmbeddingsCoordinator(
page_content_annotations::PageContentExtractionService*
page_content_extraction_service)
: omnibox_focus_changed_listener_(base::BindRepeating(
&PassageEmbeddingsCoordinator::OnOmniboxFocusChanged,
base::Unretained(this))) {
page_content_extraction_observation_.Observe(page_content_extraction_service);
}
// Defaulted destructor: no explicit teardown is performed here; cleanup is
// left to the destructors of the data members.
PassageEmbeddingsCoordinator::~PassageEmbeddingsCoordinator() = default;
// Converts the extracted `page_content` into passages and hands them to the
// passage embedder associated with the WebContents that hosts `page`. The
// embedder is created and registered on first use for that WebContents.
void PassageEmbeddingsCoordinator::OnPageContentExtracted(
    content::Page& page,
    const optimization_guide::proto::AnnotatedPageContent& page_content) {
  std::vector<std::string> passages =
      CreatePassagesFromAnnotatedPageContent(page_content);
  VLOG(2) << "Received page content for url "
          << page_content.main_frame_data().url() << ". Generated "
          << passages.size() << " passages.";

  auto* const web_contents =
      content::WebContents::FromRenderFrameHost(&page.GetMainDocument());

  auto embedder_it = web_contents_passage_embedders_.find(web_contents);
  if (embedder_it == web_contents_passage_embedders_.end()) {
    // No embedder registered for this WebContents yet: create one only now,
    // so an existing embedder is never replaced or needlessly constructed.
    auto new_embedder = CreateWebContentsPassageEmbedder(web_contents, *this);
    embedder_it =
        web_contents_passage_embedders_
            .emplace(web_contents, std::move(new_embedder))
            .first;
  }
  embedder_it->second->AcceptPassages(std::move(passages));
}
// Forwards `passages` and `callback` to the embedder obtained from
// ChromePassageEmbeddingsServiceController, requesting the computation at
// `current_priority_`. Returns the task id reported by the embedder.
Embedder::TaskId PassageEmbeddingsCoordinator::ComputePassagesEmbeddings(
    std::vector<std::string> passages,
    Embedder::ComputePassagesEmbeddingsCallback callback) {
  auto* const embedder =
      ChromePassageEmbeddingsServiceController::Get()->GetEmbedder();
  return embedder->ComputePassagesEmbeddings(
      current_priority_, std::move(passages), std::move(callback));
}
// Asks the embedder obtained from ChromePassageEmbeddingsServiceController to
// cancel the task identified by `task_id` and returns the embedder's result.
bool PassageEmbeddingsCoordinator::TryCancel(Embedder::TaskId task_id) {
  auto* const embedder =
      ChromePassageEmbeddingsServiceController::Get()->GetEmbedder();
  return embedder->TryCancel(task_id);
}
// Removes the entry keyed by `web_contents` from
// `web_contents_passage_embedders_`, if one exists.
void PassageEmbeddingsCoordinator::OnWebContentsDestroyed(
content::WebContents* web_contents) {
web_contents_passage_embedders_.erase(web_contents);
}
void PassageEmbeddingsCoordinator::OnOmniboxFocusChanged(bool is_focused) {
current_priority_ = is_focused ? kUrgent : kPassive;
std::set<Embedder::TaskId> task_ids;
for (const auto& [web_contents, web_contents_passage_embedder] :
web_contents_passage_embedders_) {
if (current_priority_ == kUrgent) {
web_contents_passage_embedder
->MaybeProcessPendingPassagesOnPriorityIncrease();
}
std::optional<Embedder::TaskId> task_id =
web_contents_passage_embedder->current_task_id();
if (task_id) {
task_ids.insert(*task_id);
}
}
ChromePassageEmbeddingsServiceController::Get()
->GetEmbedder()
->ReprioritizeTasks(current_priority_, task_ids);
}
} // namespace passage_embeddings