blob: a4cdf12ff9bd3bca8cf066a9486fffc7b2c0298b [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/history_embeddings/history_embeddings_service.h"
#include <numeric>
#include "base/feature_list.h"
#include "base/files/file_path.h"
#include "base/functional/bind.h"
#include "base/metrics/histogram_functions.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/task_traits.h"
#include "base/task/thread_pool.h"
#include "base/time/time.h"
#include "components/history/core/browser/history_types.h"
#include "components/history/core/browser/url_database.h"
#include "components/history/core/browser/url_row.h"
#include "components/history_embeddings/history_embeddings_features.h"
#include "components/history_embeddings/mock_embedder.h"
#include "components/history_embeddings/sql_database.h"
#include "components/history_embeddings/vector_database.h"
#include "components/page_content_annotations/core/page_content_annotations_service.h"
#include "mojo/public/cpp/bindings/callback_helpers.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "services/service_manager/public/cpp/interface_provider.h"
#include "third_party/blink/public/mojom/content_extraction/inner_text.mojom.h"
namespace history_embeddings {
void OnGotInnerText(mojo::Remote<blink::mojom::InnerTextAgent> remote,
base::TimeTicks start_time,
base::OnceCallback<void(std::vector<std::string>)> callback,
blink::mojom::InnerTextFramePtr mojo_frame) {
std::vector<std::string> valid_passages;
const base::TimeDelta extraction_time = base::TimeTicks::Now() - start_time;
if (mojo_frame) {
for (const auto& segment : mojo_frame->segments) {
if (segment->is_text()) {
valid_passages.emplace_back(segment->get_text());
}
}
base::UmaHistogramTimes("History.Embeddings.Passages.ExtractionTime",
extraction_time);
}
// Save passages
const size_t total_text_size =
std::accumulate(valid_passages.cbegin(), valid_passages.cend(), 0u,
[](size_t acc, const std::string& passage) {
return acc + passage.size();
});
base::UmaHistogramCounts1000("History.Embeddings.Passages.PassageCount",
valid_passages.size());
base::UmaHistogramCounts10M("History.Embeddings.Passages.TotalTextSize",
total_text_size);
std::move(callback).Run(std::move(valid_passages));
}
// This is run on the HistoryService's worker thread to access the full URL
// database and finish the results for a completed embeddings search.
// Finished results are then sent to the given callback using the task_runner.
void FinishSearchResultWithHistory(
const scoped_refptr<base::SequencedTaskRunner> task_runner,
SearchResultCallback callback,
std::vector<ScoredUrl> scored_urls,
history::HistoryBackend* history_backend,
history::URLDatabase* url_database) {
SearchResult result;
if (url_database) {
// Move each ScoredUrl into a more complete ScoredUrlRow with more info from
// the history database.
result.reserve(scored_urls.size());
for (ScoredUrl& scored_url : scored_urls) {
result.emplace_back(std::move(scored_url));
if (!url_database->GetURLRow(result.back().scored_url.url_id,
&result.back().row)) {
// This omission covers an edge case and should generally not happen
// unless a notification was missed or the history database and
// history_embeddings database went out of sync. It's theoretically
// possible since operations across separate databases are not atomic.
result.pop_back();
}
}
}
task_runner->PostTask(FROM_HERE,
base::BindOnce(std::move(callback), std::move(result)));
}
////////////////////////////////////////////////////////////////////////////////
HistoryEmbeddingsService::HistoryEmbeddingsService(
history::HistoryService* history_service,
page_content_annotations::PageContentAnnotationsService*
page_content_annotations_service)
: history_service_(history_service),
page_content_annotations_service_(page_content_annotations_service),
query_id_(0u),
query_id_weak_ptr_factory_(&query_id_),
weak_ptr_factory_(this) {
if (!base::FeatureList::IsEnabled(kHistoryEmbeddings)) {
// If the feature flag is disabled, skip initialization. Note we don't also
// check the pref here, because the pref can change at runtime.
return;
}
CHECK(history_service_);
history_service_observation_.Observe(history_service_);
// Notify page content annotations service that we will need the content
// visibility model during the session.
if (page_content_annotations_service_) {
page_content_annotations_service_->RequestAndNotifyWhenModelAvailable(
page_content_annotations::AnnotationType::kContentVisibility,
base::DoNothing());
}
// TODO: b/333094780 - Swap this to the model-backed embedder once ready.
embedder_ = std::make_unique<MockEmbedder>();
storage_ = base::SequenceBound<Storage>(
base::ThreadPool::CreateSequencedTaskRunner(
{base::MayBlock(), base::TaskPriority::USER_BLOCKING,
base::TaskShutdownBehavior::BLOCK_SHUTDOWN}),
history_service_->history_dir());
}
HistoryEmbeddingsService::~HistoryEmbeddingsService() = default;
void HistoryEmbeddingsService::RetrievePassages(
const history::VisitRow& visit_row,
content::RenderFrameHost& host) {
const base::TimeTicks start_time = base::TimeTicks::Now();
mojo::Remote<blink::mojom::InnerTextAgent> agent;
host.GetRemoteInterfaces()->GetInterface(agent.BindNewPipeAndPassReceiver());
auto params = blink::mojom::InnerTextParams::New();
params->max_words_per_aggregate_passage =
std::max(0, kPassageExtractionMaxWordsPerAggregatePassage.Get());
auto* agent_ptr = agent.get();
agent_ptr->GetInnerText(
std::move(params),
mojo::WrapCallbackWithDefaultInvokeIfNotRun(
base::BindOnce(
&OnGotInnerText, std::move(agent), start_time,
base::BindOnce(&HistoryEmbeddingsService::OnPassagesRetrieved,
weak_ptr_factory_.GetWeakPtr(),
UrlPassages(visit_row.url_id, visit_row.visit_id,
visit_row.visit_time))),
nullptr));
}
void HistoryEmbeddingsService::Search(
std::string query,
std::optional<base::Time> time_range_start,
size_t count,
SearchResultCallback callback) {
embedder_->ComputePassagesEmbeddings(
{std::move(query)},
base::BindOnce(&HistoryEmbeddingsService::OnQueryEmbeddingComputed,
weak_ptr_factory_.GetWeakPtr(), time_range_start, count,
std::move(callback)));
}
void HistoryEmbeddingsService::OnQueryEmbeddingComputed(
std::optional<base::Time> time_range_start,
size_t count,
SearchResultCallback callback,
std::vector<std::string> query_passages,
std::vector<Embedding> query_embeddings) {
bool succeeded = !query_embeddings.empty();
base::UmaHistogramBoolean("History.Embeddings.QueryEmbeddingSucceeded",
succeeded);
if (!succeeded) {
// Query embedding failed. Just return no search results.
std::move(callback).Run({});
return;
}
CHECK_EQ(query_embeddings.size(), 1u);
query_id_++;
storage_.AsyncCall(&Storage::Search)
.WithArgs(query_id_weak_ptr_factory_.GetWeakPtr(), query_id_.load(),
std::move(query_embeddings.front()), time_range_start, count)
.Then(base::BindOnce(&HistoryEmbeddingsService::OnSearchCompleted,
weak_ptr_factory_.GetWeakPtr(),
std::move(callback)));
}
base::WeakPtr<HistoryEmbeddingsService> HistoryEmbeddingsService::AsWeakPtr() {
return weak_ptr_factory_.GetWeakPtr();
}
void HistoryEmbeddingsService::Shutdown() {
query_id_weak_ptr_factory_.InvalidateWeakPtrs();
weak_ptr_factory_.InvalidateWeakPtrs();
storage_.Reset();
}
void HistoryEmbeddingsService::OnHistoryDeletions(
history::HistoryService* history_service,
const history::DeletionInfo& deletion_info) {
storage_.AsyncCall(&Storage::HandleHistoryDeletions)
.WithArgs(deletion_info.IsAllHistory(), deletion_info.deleted_rows(),
deletion_info.deleted_visit_ids());
}
HistoryEmbeddingsService::Storage::Storage(const base::FilePath& storage_dir)
: sql_database(storage_dir) {}
void HistoryEmbeddingsService::Storage::ProcessAndStorePassages(
UrlPassages url_passages,
std::vector<Embedding> passages_embeddings) {
// Compute and save embeddings vectors.
UrlEmbeddings url_embeddings(url_passages);
url_embeddings.embeddings = std::move(passages_embeddings);
vector_database.AddUrlEmbeddings(std::move(url_embeddings));
vector_database.SaveTo(&sql_database);
sql_database.InsertOrReplacePassages(url_passages);
}
std::vector<ScoredUrl> HistoryEmbeddingsService::Storage::Search(
base::WeakPtr<std::atomic<size_t>> weak_latest_query_id,
size_t query_id,
Embedding query_embedding,
std::optional<base::Time> time_range_start,
size_t count) {
std::vector<ScoredUrl> scored_urls = sql_database.FindNearest(
time_range_start, count, std::move(query_embedding),
base::BindRepeating(
[](base::WeakPtr<std::atomic<size_t>> weak_latest_query_id,
size_t query_id) {
// If the service shut down or started a new query, this one is no
// longer needed. Signal to exit early. Best result so far will be
// returned.
return !weak_latest_query_id || *weak_latest_query_id != query_id;
},
std::move(weak_latest_query_id), query_id));
// Populate source passages.
for (ScoredUrl& scored_url : scored_urls) {
std::optional<proto::PassagesValue> value =
sql_database.GetPassages(scored_url.url_id);
if (value &&
scored_url.index < static_cast<size_t>(value.value().passages_size())) {
scored_url.passage = value.value().passages(scored_url.index);
}
}
return scored_urls;
}
void HistoryEmbeddingsService::Storage::HandleHistoryDeletions(
bool for_all_history,
history::URLRows deleted_rows,
std::set<history::VisitID> deleted_visit_ids) {
if (for_all_history) {
sql_database.DeleteAllData();
return;
}
for (history::URLRow url_row : deleted_rows) {
sql_database.DeleteDataForUrlId(url_row.id());
}
for (history::VisitID visit_id : deleted_visit_ids) {
sql_database.DeleteDataForVisitId(visit_id);
}
}
void HistoryEmbeddingsService::OnPassagesRetrieved(
UrlPassages url_passages,
std::vector<std::string> passages) {
embedder_->ComputePassagesEmbeddings(
std::move(passages),
base::BindOnce(&HistoryEmbeddingsService::OnPassagesEmbeddingsComputed,
weak_ptr_factory_.GetWeakPtr(), std::move(url_passages)));
}
void HistoryEmbeddingsService::OnPassagesEmbeddingsComputed(
UrlPassages url_passages,
std::vector<std::string> passages,
std::vector<Embedding> passages_embeddings) {
url_passages.passages.mutable_passages()->Assign(
std::make_move_iterator(passages.begin()),
std::make_move_iterator(passages.end()));
storage_.AsyncCall(&Storage::ProcessAndStorePassages)
.WithArgs(url_passages, std::move(passages_embeddings))
.Then(base::BindOnce(callback_for_tests_, url_passages));
}
void HistoryEmbeddingsService::OnSearchCompleted(
SearchResultCallback callback,
std::vector<ScoredUrl> scored_urls) {
// TODO(b/330925683): Handle search interruption. This may not still need to
// happen by now.
DeterminePassageVisibility(std::move(callback), std::move(scored_urls));
}
void HistoryEmbeddingsService::DeterminePassageVisibility(
SearchResultCallback callback,
std::vector<ScoredUrl> scored_urls) {
bool is_visibility_model_available =
page_content_annotations_service_ &&
page_content_annotations_service_->GetModelInfoForType(
page_content_annotations::AnnotationType::kContentVisibility);
base::UmaHistogramCounts100("History.Embeddings.NumUrlsMatched",
scored_urls.size());
base::UmaHistogramBoolean(
"History.Embeddings.VisibilityModelAvailableAtQuery",
is_visibility_model_available);
if (!is_visibility_model_available) {
OnPassageVisibilityCalculated(std::move(callback), std::move(scored_urls),
{});
return;
}
std::vector<std::string> inputs;
inputs.reserve(scored_urls.size());
for (const ScoredUrl& url : scored_urls) {
inputs.emplace_back(url.passage);
}
page_content_annotations_service_->BatchAnnotate(
base::BindOnce(&HistoryEmbeddingsService::OnPassageVisibilityCalculated,
weak_ptr_factory_.GetWeakPtr(), std::move(callback),
std::move(scored_urls)),
std::move(inputs),
page_content_annotations::AnnotationType::kContentVisibility);
}
void HistoryEmbeddingsService::OnPassageVisibilityCalculated(
SearchResultCallback callback,
std::vector<ScoredUrl> scored_urls,
const std::vector<page_content_annotations::BatchAnnotationResult>&
annotation_results) {
if (annotation_results.empty()) {
scored_urls.clear();
} else {
CHECK_EQ(scored_urls.size(), annotation_results.size());
// Filter for scored URLs that are ok to be shown to the user.
auto urls_it = scored_urls.begin();
for (const page_content_annotations::BatchAnnotationResult& result :
annotation_results) {
if (result.visibility_score().value_or(0.0) <=
kContentVisibilityThreshold.Get()) {
urls_it = scored_urls.erase(urls_it);
} else {
++urls_it;
}
}
}
base::UmaHistogramCounts100("History.Embeddings.NumMatchedUrlsVisible",
scored_urls.size());
if (scored_urls.empty()) {
std::move(callback).Run({});
return;
}
// Use the callback task mechanism for simplicity and easier control with
// other standard async machinery.
history_service_->ScheduleDBTaskForUI(
base::BindOnce(&FinishSearchResultWithHistory,
base::SequencedTaskRunner::GetCurrentDefault(),
std::move(callback), std::move(scored_urls)));
}
} // namespace history_embeddings