| // Copyright 2024 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #ifndef COMPONENTS_HISTORY_EMBEDDINGS_HISTORY_EMBEDDINGS_SERVICE_H_ |
| #define COMPONENTS_HISTORY_EMBEDDINGS_HISTORY_EMBEDDINGS_SERVICE_H_ |
| |
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <vector>

#include "base/files/file_path.h"
#include "base/functional/callback.h"
#include "base/functional/callback_helpers.h"
#include "base/gtest_prod_util.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/scoped_observation.h"
#include "base/threading/sequence_bound.h"
#include "base/time/time.h"
#include "base/timer/elapsed_timer.h"
#include "components/history/core/browser/history_service.h"
#include "components/history/core/browser/history_service_observer.h"
#include "components/history/core/browser/history_types.h"
#include "components/history/core/browser/url_database.h"
#include "components/history/core/browser/url_row.h"
#include "components/history_embeddings/answerer.h"
#include "components/history_embeddings/intent_classifier.h"
#include "components/history_embeddings/sql_database.h"
#include "components/history_embeddings/vector_database.h"
#include "components/keyed_service/core/keyed_service.h"
#include "components/optimization_guide/core/model_quality/model_quality_log_entry.h"
#include "components/optimization_guide/proto/features/common_quality_data.pb.h"
#include "components/os_crypt/async/common/encryptor.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
| |
// Forward declarations for types this header only refers to indirectly
// (through pointers or references), avoiding heavier includes.
namespace optimization_guide {
class OptimizationGuideDecider;
}  // namespace optimization_guide

namespace page_content_annotations {
class BatchAnnotationResult;
class PageContentAnnotationsService;
}  // namespace page_content_annotations

namespace os_crypt_async {
class OSCryptAsync;
}  // namespace os_crypt_async
| |
| namespace history_embeddings { |
| |
// Counts the # of ' ' vanilla-space characters in `s`.
// The count is this function's only result, so it is marked [[nodiscard]]:
// calling it and discarding the value is almost certainly a mistake.
// TODO(crbug.com/343256907): Should work on international inputs which may:
//   a) use special whitespace, OR
//   b) not use whitespace for word breaks (e.g. Thai).
// `String16VectorFromString16()` is the omnibox solution. We could probably
// just replace-all `CountWords(s)` ->
// `String16VectorFromString16(CleanUpTitleForMatching(s, nullptr)).size()`.
[[nodiscard]] size_t CountWords(const std::string& s);
| |
| // A single item that forms part of a search result; combines metadata found in |
| // the history embeddings database with additional info from history database. |
| struct ScoredUrlRow { |
| explicit ScoredUrlRow(ScoredUrl scored_url); |
| ScoredUrlRow(const ScoredUrlRow&); |
| ScoredUrlRow(ScoredUrlRow&&); |
| ~ScoredUrlRow(); |
| ScoredUrlRow& operator=(const ScoredUrlRow&); |
| ScoredUrlRow& operator=(ScoredUrlRow&&); |
| |
| // Returns the highest scored passage in `passages_embeddings`. |
| std::string GetBestPassage() const; |
| |
| // Finds the indices of the top scores, ordered descending by score. |
| // This is useful for selecting a subset of `passages_embeddings` for use as |
| // answerer context. The size of the returned vector will be at least |
| // `min_count` provided there is sufficient data available. The |
| // `min_word_count` parameter will also be used to ensure the |
| // passages for returned indices have word counts adding up to at |
| // least this minimum. |
| std::vector<size_t> GetBestScoreIndices(size_t min_count, |
| size_t min_word_count) const; |
| |
| // Basic scoring and history data for this URL. |
| ScoredUrl scored_url; |
| history::URLRow row; |
| bool is_url_known_to_sync = false; |
| |
| // All passages and embeddings for this URL (i.e. not a partial set). |
| UrlData passages_embeddings; |
| |
| // All scores against the query for `passages_embeddings`. |
| std::vector<float> scores; |
| }; |
| |
| struct SearchResult { |
| SearchResult(); |
| SearchResult(SearchResult&&); |
| ~SearchResult(); |
| SearchResult& operator=(SearchResult&&); |
| |
| // Explicit copy only, since the `answerer_result` contains a log entry. |
| // This should only be called if `answerer_result` is not populated with |
| // a log entry yet, for example after initial search and before answering. |
| SearchResult Clone(); |
| |
| // Returns true if this search result is related to the given `other` |
| // result returned by HistoryEmbeddingsService::Search (same session/query). |
| bool IsContinuationOf(const SearchResult& other); |
| |
| // Gets the answer text from within the `answerer_result`. |
| const std::string& AnswerText() const; |
| |
| // Finds the index in `scored_url_rows` that has the URL selected by the |
| // `answerer_result`, indicating where the answer came from. |
| size_t AnswerIndex() const; |
| |
| // Session ID to associate query with answers. |
| std::string session_id; |
| |
| // Keep context for search parameters requested, to make logging easier. |
| std::string query; |
| std::optional<base::Time> time_range_start; |
| size_t count = 0; |
| SearchParams search_params; |
| |
| // The actual search result data. Note that the size of this vector will |
| // not necessarily match the above requested `count`. |
| std::vector<ScoredUrlRow> scored_url_rows; |
| |
| // This may be empty for initial embeddings search results, as the answer |
| // isn't ready yet. When the answerer finishes work, a second search |
| // result is provided with this answer filled. |
| AnswererResult answerer_result; |
| }; |
| |
| using UrlDataCallback = base::OnceCallback<void(std::optional<UrlData>)>; |
| |
| using PassagesStoredCallback = base::RepeatingCallback<void(UrlData)>; |
| |
| using SearchResultCallback = base::RepeatingCallback<void(SearchResult)>; |
| |
| using QualityLogEntry = |
| std::unique_ptr<optimization_guide::ModelQualityLogEntry>; |
| |
| class HistoryEmbeddingsService |
| : public KeyedService, |
| public history::HistoryServiceObserver, |
| public passage_embeddings::EmbedderMetadataObserver { |
| public: |
| // Number of low-order bits to use in session_id for sequence number. |
| static constexpr uint64_t kSessionIdSequenceBits = 16; |
| static constexpr uint64_t kSessionIdSequenceBitMask = |
| (1 << kSessionIdSequenceBits) - 1; |
| |
| // `history_service` is never nullptr and must outlive `this`. |
| // Storage uses its `history_dir() location for the database. |
| HistoryEmbeddingsService( |
| os_crypt_async::OSCryptAsync* os_crypt_async, |
| history::HistoryService* history_service, |
| page_content_annotations::PageContentAnnotationsService* |
| page_content_annotations_service, |
| optimization_guide::OptimizationGuideDecider* optimization_guide_decider, |
| passage_embeddings::EmbedderMetadataProvider* embedder_metadata_provider, |
| passage_embeddings::Embedder* embedder, |
| std::unique_ptr<Answerer> answerer, |
| std::unique_ptr<IntentClassifier> intent_classifier); |
| HistoryEmbeddingsService(const HistoryEmbeddingsService&) = delete; |
| HistoryEmbeddingsService& operator=(const HistoryEmbeddingsService&) = delete; |
| ~HistoryEmbeddingsService() override; |
| |
| // Identify if the given URL is eligible for history embeddings. |
| bool IsEligible(const GURL& url); |
| |
| // Called by `HistoryEmbeddingsTabHelper` when passage extraction completes. |
| // Retrieves existing passages and embeddings for `url_id` from the database |
| // before calling |
| // `ComputeAndStorePassageEmbeddingsWithExistingData()`. |
| void ComputeAndStorePassageEmbeddings(history::URLID url_id, |
| history::VisitID visit_id, |
| base::Time visit_time, |
| std::vector<std::string> passages); |
| |
| // Finds the top `count` URL visit info entries nearest to `query`. Passes the |
| // results to `callback` when search completes, whether successfully or not. |
| // Search will be narrowed to a time range if `time_range_start` is provided. |
| // In that case, the start of the time range is inclusive and the end is |
| // unbounded. Practically, this can be thought of as [start, now) but now |
| // isn't fixed. |
| // The `callback` may be called a second time with another search result |
| // containing an answer, only if `skip_answering` is false and an answer is |
| // successfully generated. This two-phase result callback scheme lets callers |
| // receive initial search results without having to wait longer for answers. |
| // The `previous_search_result` may be nullptr to signal the beginning of a |
| // completely new search session; if it is non-null and the session_id is set, |
| // the new session_id is set based on the previous to indicate a continuing |
| // search session. |
| // Returns a stub result that can be used to detect if a later published |
| // SearchResult instance is related to this search. |
| // Virtual for testing. |
| virtual SearchResult Search(SearchResult* previous_search_result, |
| std::string query, |
| std::optional<base::Time> time_range_start, |
| size_t count, |
| bool skip_answering, |
| SearchResultCallback callback); |
| |
| // Weak `this` provider method. |
| base::WeakPtr<HistoryEmbeddingsService> AsWeakPtr(); |
| |
| // Submit quality logging data after user selects an item from search result. |
| // Note, the `result` contains a log entry that will be consumed by this call. |
| void SendQualityLog(SearchResult& result, |
| std::set<size_t> selections, |
| size_t num_entered_characters, |
| optimization_guide::proto::UserFeedback user_feedback, |
| optimization_guide::proto::UiSurface ui_surface); |
| |
| // KeyedService: |
| void Shutdown() override; |
| |
| // history::HistoryServiceObserver: |
| void OnHistoryDeletions(history::HistoryService* history_service, |
| const history::DeletionInfo& deletion_info) override; |
| |
| // This can be overridden to gate answer generation for some accounts. |
| virtual bool IsAnswererUseAllowed() const; |
| |
| // Asynchronously gets passages and embeddings from storage for given |
| // `url_id`. Calls `callback` with the data or nullopt if no data is found in |
| // the HistoryEmbeddings database. |
| void GetUrlData(history::URLID url_id, UrlDataCallback callback) const; |
| |
| // Asynchronously gets passages and embeddings from storage where visits |
| // are within a given time range. Calls `callback` with the data. |
| // The `limit` and `offset` can be used to control data range with |
| // standard SQL style paging. |
| void GetUrlDataInTimeRange( |
| base::Time from_time, |
| base::Time to_time, |
| size_t limit, |
| size_t offset, |
| base::OnceCallback<void(std::vector<UrlData>)> callback) const; |
| |
| // Targeted deletion for testing scenarios like model version change. |
| void DeleteDataForTesting(bool delete_passages, |
| bool delete_embeddings, |
| base::OnceClosure callback); |
| |
| // Set a callback to be called when `ProcessAndStorePassages` completes. |
| void SetPassagesStoredCallbackForTesting(PassagesStoredCallback callback); |
| |
| private: |
| friend class HistoryEmbeddingsServicePublic; |
| |
| // A utility container to wrap anything that should be accessed on |
| // the separate storage worker sequence. |
| struct Storage { |
| Storage(const base::FilePath& storage_dir, |
| bool erase_non_ascii_characters, |
| bool delete_embeddings); |
| |
| // Associate the given metadata with this Storage instance. The storage is |
| // not considered initialized until this metadata is supplied. |
| void SetEmbedderMetadata(passage_embeddings::EmbedderMetadata metadata, |
| os_crypt_async::Encryptor encryptor); |
| |
| // Called on the worker sequence to persist passages and embeddings. |
| void ProcessAndStorePassages(UrlData url_data); |
| |
| // Runs search on worker sequence. |
| std::vector<ScoredUrlRow> Search( |
| base::WeakPtr<std::atomic<size_t>> weak_latest_query_id, |
| size_t query_id, |
| SearchParams search_params, |
| passage_embeddings::Embedding query_embedding, |
| std::optional<base::Time> time_range_start, |
| size_t count); |
| |
| // Handles the History deletions on the worker thread. |
| void HandleHistoryDeletions(bool for_all_history, |
| history::URLRows deleted_rows, |
| std::set<history::VisitID> deleted_visit_ids); |
| |
| // Targeted deletion for testing scenarios like model version change. |
| void DeleteDataForTesting(bool delete_passages, bool delete_embeddings); |
| |
| // Gathers URL and passage data from the database where corresponding |
| // embeddings are absent. This is used to rebuild the embeddings table |
| // when the model changes. |
| std::vector<UrlData> CollectPassagesWithoutEmbeddings(); |
| |
| // Retrieves passages and embeddings from the database for use as a cache |
| // to avoid recomputing embeddings that exist for identical passages. |
| std::optional<UrlData> GetUrlData(history::URLID url_id); |
| |
| // Retrieves passages and embeddings from the database that have visit times |
| // within specified range. |
| std::vector<UrlData> GetUrlDataInTimeRange(base::Time from_time, |
| base::Time to_time, |
| size_t limit, |
| size_t offset); |
| |
| // A VectorDatabase implementation that holds data in memory. |
| VectorDatabaseInMemory vector_database; |
| |
| // The underlying SQL database for persistent storage. |
| SqlDatabase sql_database; |
| }; |
| |
| // passage_embeddings::EmbedderMetadataObserver: |
| // Passes the metadata to the internal storage. |
| void EmbedderMetadataUpdated( |
| passage_embeddings::EmbedderMetadata metadata) override; |
| |
| void OnOsCryptAsyncReady(os_crypt_async::Encryptor encryptor); |
| |
| // This can be overridden to prepare a log entry that will then be filled |
| // with data and sent on destruction. Default implementation returns null. |
| virtual QualityLogEntry PrepareQualityLogEntry(); |
| |
| // Called by `ComputeAndStorePassageEmbeddings()` after retrieving existing |
| // passages and embeddings for `url_data.url_id` from the database. |
| // `existing_url_data` may be nullopt if no existing data was found. |
| void ComputeAndStorePassageEmbeddingsWithExistingData( |
| UrlData url_data, |
| std::vector<std::string> passages, |
| base::ElapsedTimer database_access_timer, |
| std::optional<UrlData> existing_url_data); |
| |
| // Invoked after the embeddings for `passages` has been computed. Stores the |
| // passages along with their embeddings in the database. |
| void OnPassagesEmbeddingsComputed( |
| UrlData url_passages, |
| std::vector<std::string> passages, |
| std::vector<passage_embeddings::Embedding> embeddings, |
| passage_embeddings::Embedder::TaskId task_id, |
| passage_embeddings::ComputeEmbeddingsStatus status); |
| |
| // Invoked after the embedding for the original search query has been |
| // computed. |
| void OnQueryEmbeddingComputed( |
| SearchResultCallback callback, |
| SearchResult result, |
| std::vector<std::string> query_passages, |
| std::vector<passage_embeddings::Embedding> query_embedding, |
| passage_embeddings::Embedder::TaskId task_id, |
| passage_embeddings::ComputeEmbeddingsStatus status); |
| |
| // Finishes a search result by combining found data with additional data from |
| // history database. Moves each ScoredUrl into a more complete structure with |
| // a history URLRow. Omits any entries that don't have corresponding data in |
| // the history database. |
| void OnSearchCompleted(SearchResultCallback callback, |
| SearchResult result, |
| std::vector<ScoredUrlRow> scored_url_rows); |
| |
| // Calls `page_content_annotation_service_` to determine whether the passage |
| // of each ScoredUrl should be shown to the user. |
| void DeterminePassageVisibility(SearchResultCallback callback, |
| SearchResult result, |
| std::vector<ScoredUrlRow> scored_url_rows); |
| |
| // Called after `page_content_annotation_service_` has determined visibility |
| // for the passage of each ScoredUrl. This will filter `scored_urls` to only |
| // contain entries that can be shown to the user. |
| void OnPassageVisibilityCalculated( |
| SearchResultCallback callback, |
| SearchResult result, |
| std::vector<ScoredUrlRow> scored_url_rows, |
| const std::vector<page_content_annotations::BatchAnnotationResult>& |
| annotation_results); |
| |
| // Called on main sequence after the history worker thread finalizes |
| // the initial search result with URL rows. Calls the `callback` and |
| // then proceeds to intent check and v2 answer generation if needed. |
| void OnPrimarySearchResultReady(SearchResultCallback callback, |
| SearchResult result); |
| |
| // Invoked after the intent classifier computes query answerability. |
| void OnQueryIntentComputed(SearchResultCallback callback, |
| SearchResult result, |
| ComputeIntentStatus status, |
| bool query_is_answerable); |
| |
| // Called after the answerer finishes computing an answer. Combines |
| // the `answer_result` into `search_result` and invokes `callback` |
| // with new search result complete with answer. |
| void OnAnswerComputed(base::Time start_time, |
| SearchResultCallback callback, |
| SearchResult search_result, |
| AnswererResult answerer_result); |
| |
| // Rebuild absent embeddings from source passages. |
| void RebuildAbsentEmbeddings(std::vector<UrlData> all_url_passages); |
| |
| // Returns true if query should be filtered. If false, then `search_params` |
| // will have its query_terms set. |
| bool QueryIsFiltered(const std::string& raw_query, |
| SearchParams& search_params) const; |
| |
| raw_ptr<os_crypt_async::OSCryptAsync> os_crypt_async_; |
| |
| // The history service is used to fill in details about URLs and visits |
| // found via search. It strictly outlives this due to the dependency |
| // specified in HistoryEmbeddingsServiceFactory. |
| raw_ptr<history::HistoryService> history_service_; |
| |
| // The page content annotations service is used to determine whether the |
| // content is safe. It strictly outlives this due to the dependency specified |
| // in `HistoryEmbeddingsServiceFactory`. Can be nullptr if the underlying |
| // capabilities are not supported. |
| raw_ptr<page_content_annotations::PageContentAnnotationsService> |
| page_content_annotations_service_; |
| |
| // Used to determine whether a page should be excluded from history |
| // embeddings. |
| raw_ptr<optimization_guide::OptimizationGuideDecider> |
| optimization_guide_decider_; |
| |
| // Tracks the observed history service, for cleanup. |
| base::ScopedObservation<history::HistoryService, |
| history::HistoryServiceObserver> |
| history_service_observation_{this}; |
| |
| // The embedder used to compute embeddings. Outlives this. |
| raw_ptr<passage_embeddings::Embedder> embedder_; |
| |
| // The answerer used to answer queries with context. May be nullptr if |
| // the kHistoryEmbeddingsAnswers feature is disabled. |
| std::unique_ptr<Answerer> answerer_; |
| |
| // The intent classifier used to determine query intent and answerability. |
| std::unique_ptr<IntentClassifier> intent_classifier_; |
| |
| // Metadata about the embedder; Set when valid metadata is received from |
| // `embedder_metadata_provider`. |
| passage_embeddings::EmbedderMetadata embedder_metadata_{0, 0}; |
| |
| // Storage is bound to a separate sequence. |
| // This will be null if the feature flag is disabled. |
| base::SequenceBound<Storage> storage_; |
| |
| // Callback called when `ProcessAndStorePassages` completes. Needed for tests |
| // as the blink dependency doesn't have a 'wait for pending requests to |
| // complete' mechanism. |
| PassagesStoredCallback passages_stored_callback_for_tests_ = |
| base::DoNothing(); |
| |
| // A thread-safe invalidation mechanism to halt searches for stale queries: |
| // Each query is run with the current `query_id_` and a weak pointer to the |
| // atomic value itself. When it changes, any queries other than the latest |
| // can be halted. Note this is not task cancellation, it breaks the inner |
| // search loop while running so the atomic is needed for thread safety. |
| std::atomic<size_t> query_id_ = 0u; |
| |
| // Used to cancel the in-flight embedding task for the previous stale query. |
| std::optional<passage_embeddings::Embedder::TaskId> query_embedding_task_id_; |
| |
| // Scoped observation for when the embedder metadata is available. |
| base::ScopedObservation<passage_embeddings::EmbedderMetadataProvider, |
| passage_embeddings::EmbedderMetadataObserver> |
| embedder_metadata_observation_{this}; |
| |
| base::WeakPtrFactory<std::atomic<size_t>> query_id_weak_ptr_factory_; |
| |
| base::WeakPtrFactory<HistoryEmbeddingsService> weak_ptr_factory_; |
| }; |
| |
| } // namespace history_embeddings |
| |
| #endif // COMPONENTS_HISTORY_EMBEDDINGS_HISTORY_EMBEDDINGS_SERVICE_H_ |