blob: 464290eec75cfa808bb5bd58312dbe4f5c4f234f [file] [log] [blame]
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/contextual_tasks/contextual_tasks_context_service.h"
#include "base/metrics/histogram_functions.h"
#include "base/strings/stringprintf.h"
#include "base/task/single_thread_task_runner.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service.h"
#include "chrome/browser/passage_embeddings/page_embeddings_service.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/browser/ui/browser_window/public/browser_window_interface.h"
#include "chrome/browser/ui/browser_window/public/browser_window_interface_iterator.h"
#include "chrome/browser/ui/tabs/tab_strip_model.h"
#include "components/contextual_tasks/public/features.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "content/public/browser/web_contents.h"
namespace contextual_tasks {
namespace {
// Convenience macro for emitting OPTIMIZATION_GUIDE_LOGs where
// optimization_keyed_service_ is defined.
#define AUTO_CONTEXT_LOG(message) \
OPTIMIZATION_GUIDE_LOG( \
optimization_guide_common::mojom::LogSource::CONTEXTUAL_TASKS_CONTEXT, \
optimization_guide_keyed_service_->GetOptimizationGuideLogger(), \
(message))
void RecordContextDeterminationStatus(ContextDeterminationStatus status) {
base::UmaHistogramEnumeration(
"ContextualTasks.Context.ContextDeterminationStatus", status);
}
} // namespace
ContextualTasksContextService::ContextualTasksContextService(
Profile* profile,
passage_embeddings::PageEmbeddingsService* page_embeddings_service,
passage_embeddings::EmbedderMetadataProvider* embedder_metadata_provider,
passage_embeddings::Embedder* embedder,
OptimizationGuideKeyedService* optimization_guide_keyed_service)
: profile_(profile),
page_embeddings_service_(page_embeddings_service),
embedder_metadata_provider_(embedder_metadata_provider),
embedder_(embedder),
optimization_guide_keyed_service_(optimization_guide_keyed_service) {
scoped_observation_.Observe(embedder_metadata_provider_);
}
ContextualTasksContextService::~ContextualTasksContextService() = default;
void ContextualTasksContextService::GetRelevantTabsForQuery(
const std::string& query,
base::OnceCallback<void(std::vector<content::WebContents*>)> callback) {
base::TimeTicks now = base::TimeTicks::Now();
AUTO_CONTEXT_LOG(base::StringPrintf("Processing query %s", query));
if (!is_embedder_available_) {
AUTO_CONTEXT_LOG("Embedder not available");
RecordContextDeterminationStatus(
ContextDeterminationStatus::kEmbedderNotAvailable);
base::SingleThreadTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, base::BindOnce(std::move(callback),
std::vector<content::WebContents*>({})));
return;
}
// Force active tab embedding to be processed.
page_embeddings_service_->ProcessAllEmbeddings();
AUTO_CONTEXT_LOG("Submitted query to embedder");
embedder_->ComputePassagesEmbeddings(
passage_embeddings::PassagePriority::kUrgent, {query},
base::BindOnce(&ContextualTasksContextService::OnQueryEmbeddingReady,
weak_ptr_factory_.GetWeakPtr(), query, now,
std::move(callback)));
}
void ContextualTasksContextService::EmbedderMetadataUpdated(
passage_embeddings::EmbedderMetadata metadata) {
is_embedder_available_ = metadata.IsValid();
}
void ContextualTasksContextService::OnQueryEmbeddingReady(
const std::string& query,
base::TimeTicks start_time,
base::OnceCallback<void(std::vector<content::WebContents*>)> callback,
std::vector<std::string> passages,
std::vector<passage_embeddings::Embedding> embeddings,
passage_embeddings::Embedder::TaskId task_id,
passage_embeddings::ComputeEmbeddingsStatus status) {
// Query embedding was not successfully generated.
if (status != passage_embeddings::ComputeEmbeddingsStatus::kSuccess) {
AUTO_CONTEXT_LOG(
base::StringPrintf("Query embedding for %s failed", query));
RecordContextDeterminationStatus(
ContextDeterminationStatus::kQueryEmbeddingFailed);
std::move(callback).Run({});
return;
}
// Unexpected output size. Just return.
if (embeddings.size() != 1u) {
AUTO_CONTEXT_LOG(base::StringPrintf(
"Query embedding for %s had unexpected output", query));
RecordContextDeterminationStatus(
ContextDeterminationStatus::kQueryEmbeddingOutputMalformed);
std::move(callback).Run({});
return;
}
RecordContextDeterminationStatus(ContextDeterminationStatus::kSuccess);
AUTO_CONTEXT_LOG(
base::StringPrintf("Processing query embedding for %s", query));
passage_embeddings::Embedding query_embedding = embeddings[0];
// Collect relevant web contents.
// TODO: crbug.com/452056256 - Include other criteria other than embedding
// score.
std::vector<content::WebContents*> relevant_web_contents;
int all_browsers_tab_count = 0;
ForEachCurrentBrowserWindowInterfaceOrderedByActivation(
[this, profile = profile_, &relevant_web_contents,
&all_browsers_tab_count, &query_embedding,
&query](BrowserWindowInterface* browser) {
if (browser->GetProfile() != profile) {
return true;
}
TabStripModel* const tab_strip_model = browser->GetTabStripModel();
const int tab_count = tab_strip_model->count();
all_browsers_tab_count += tab_count;
for (int i = 0; i < tab_count; i++) {
content::WebContents* web_contents =
tab_strip_model->GetWebContentsAt(i);
if (!web_contents) {
continue;
}
if (!web_contents->GetLastCommittedURL().SchemeIsHTTPOrHTTPS()) {
continue;
}
// See if any passage embeddings are closely related to the query
// embedding. Just add if at least one is high enough.
std::vector<passage_embeddings::PassageEmbedding>
web_contents_embeddings =
page_embeddings_service_->GetEmbeddings(web_contents);
AUTO_CONTEXT_LOG(base::StringPrintf(
"Comparing query embedding to %llu embeddings for %s",
web_contents_embeddings.size(),
web_contents->GetLastCommittedURL().spec()));
for (const auto& embedding : web_contents_embeddings) {
if (kOnlyUseTitlesForSimilarity.Get() &&
embedding.passage.second !=
passage_embeddings::PassageType::kTitle) {
continue;
}
float similarity_score =
embedding.embedding.ScoreWith(query_embedding);
AUTO_CONTEXT_LOG(base::StringPrintf(
"Similarity with passage %s and query %s: %f",
embedding.passage.first, query, similarity_score));
if (similarity_score > kMinEmbeddingSimilarityScore.Get()) {
AUTO_CONTEXT_LOG(base::StringPrintf(
"Adding %s to relevant set",
web_contents->GetLastCommittedURL().spec()));
relevant_web_contents.push_back(web_contents);
break;
}
}
}
return true;
});
AUTO_CONTEXT_LOG(base::StringPrintf("Number of open tabs for query %s: %d",
query, all_browsers_tab_count));
AUTO_CONTEXT_LOG(
base::StringPrintf("Number of relevant tabs for query %s: %d", query,
relevant_web_contents.size()));
base::UmaHistogramTimes("ContextualTasks.Context.ContextCalculationLatency",
base::TimeTicks::Now() - start_time);
base::UmaHistogramCounts100("ContextualTasks.Context.RelevantTabsCount",
relevant_web_contents.size());
std::move(callback).Run(std::move(relevant_web_contents));
}
} // namespace contextual_tasks