| // Copyright 2024 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #ifndef COMPONENTS_HISTORY_EMBEDDINGS_VECTOR_DATABASE_H_ |
| #define COMPONENTS_HISTORY_EMBEDDINGS_VECTOR_DATABASE_H_ |
| |
#include <cstddef>
#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "base/functional/callback.h"
#include "base/time/time.h"
#include "components/history/core/browser/history_types.h"
#include "components/history_embeddings/proto/history_embeddings.pb.h"
#include "components/keyed_service/core/keyed_service.h"
| |
| namespace history_embeddings { |
| |
| struct UrlPassages { |
| UrlPassages(history::URLID url_id, |
| history::VisitID visit_id, |
| base::Time visit_time); |
| ~UrlPassages(); |
| UrlPassages(const UrlPassages&); |
| UrlPassages& operator=(const UrlPassages&); |
| UrlPassages(UrlPassages&&); |
| UrlPassages& operator=(UrlPassages&&); |
| bool operator==(const UrlPassages&) const; |
| |
| history::URLID url_id; |
| history::VisitID visit_id; |
| base::Time visit_time; |
| proto::PassagesValue passages; |
| }; |
| |
// A vector of floats together with a passage word count.
// Copyable, movable, and equality-comparable.
class Embedding {
 public:
  Embedding();
  explicit Embedding(std::vector<float> data);
  Embedding(std::vector<float> data, size_t passage_word_count);
  Embedding(const Embedding& other);
  Embedding(Embedding&& other);
  Embedding& operator=(const Embedding& other);
  Embedding& operator=(Embedding&& other);
  ~Embedding();

  bool operator==(const Embedding& other) const;

  // How many elements the data vector holds.
  size_t Dimensions() const;

  // The vector's length.
  float Magnitude() const;

  // Rescales the vector so that it has unit length.
  void Normalize();

  // Returns a similarity measure between this embedding and
  // `other_embedding`.
  // NOTE(review): how `other_passage` influences the score is not visible in
  // this header; confirm against the implementation.
  float ScoreWith(const std::string& other_passage,
                  const Embedding& other_embedding) const;

  // Read-only access to the raw values; used for storage.
  const std::vector<float>& GetData() const {
    return data_;
  }

  // Passage word count; search uses it to filter out passages with a low
  // word count.
  size_t GetPassageWordCount() const {
    return passage_word_count_;
  }
  void SetPassageWordCount(size_t passage_word_count) {
    passage_word_count_ = passage_word_count;
  }

 private:
  std::vector<float> data_;
  size_t passage_word_count_ = 0;
};
| |
| struct UrlEmbeddings { |
| UrlEmbeddings(); |
| UrlEmbeddings(history::URLID url_id, |
| history::VisitID visit_id, |
| base::Time visit_time); |
| explicit UrlEmbeddings(const UrlPassages& url_passages); |
| ~UrlEmbeddings(); |
| UrlEmbeddings(UrlEmbeddings&&); |
| UrlEmbeddings& operator=(UrlEmbeddings&&); |
| UrlEmbeddings(const UrlEmbeddings&); |
| UrlEmbeddings& operator=(const UrlEmbeddings&); |
| bool operator==(const UrlEmbeddings&) const; |
| |
| // Finds score of embedding nearest to query, also taking passages |
| // into consideration since some should be skipped. The passages |
| // correspond to the embeddings 1:1 by index. |
| float BestScoreWith(const Embedding& query, |
| const proto::PassagesValue& passages, |
| size_t search_minimum_word_count) const; |
| |
| history::URLID url_id; |
| history::VisitID visit_id; |
| base::Time visit_time; |
| std::vector<Embedding> embeddings; |
| }; |
| |
| struct ScoredUrl { |
| ScoredUrl(history::URLID url_id, |
| history::VisitID visit_id, |
| base::Time visit_time, |
| float score); |
| ~ScoredUrl(); |
| ScoredUrl(ScoredUrl&&); |
| ScoredUrl& operator=(ScoredUrl&&); |
| ScoredUrl(const ScoredUrl&); |
| ScoredUrl& operator=(const ScoredUrl&); |
| |
| // Basic data about the found URL/visit. |
| history::URLID url_id; |
| history::VisitID visit_id; |
| base::Time visit_time; |
| |
| // A measure of how closely the query matched the found data. |
| float score; |
| }; |
| |
| struct SearchInfo { |
| SearchInfo(); |
| SearchInfo(SearchInfo&&); |
| ~SearchInfo(); |
| |
| // Result of the search, the best scored URLs. |
| std::vector<ScoredUrl> scored_urls; |
| |
| // The number of URLs searched to find this result. |
| size_t searched_url_count = 0u; |
| |
| // The number of embeddings searched to find this result. |
| size_t searched_embedding_count = 0u; |
| |
| // Whether the search completed without interruption. Starting a new search |
| // may cause a search to halt, and in that case this member will be false. |
| bool completed = false; |
| }; |
| |
| struct UrlPassagesEmbeddings { |
| UrlPassagesEmbeddings(history::URLID url_id, |
| history::VisitID visit_id, |
| base::Time visit_time); |
| UrlPassagesEmbeddings(const UrlPassagesEmbeddings&); |
| UrlPassagesEmbeddings& operator=(const UrlPassagesEmbeddings&); |
| bool operator==(const UrlPassagesEmbeddings&) const; |
| |
| UrlPassages url_passages; |
| UrlEmbeddings url_embeddings; |
| }; |
| |
| // This base class decouples storage classes and inverts the dependency so that |
| // a vector database can work with a SQLite database, simple in-memory storage, |
| // flat files, or whatever kinds of storage will work efficiently. |
| class VectorDatabase { |
| public: |
| struct UrlDataIterator { |
| virtual ~UrlDataIterator() = default; |
| |
| // Returns nullptr if none remain; otherwise advances the iterator |
| // and returns a pointer to the next instance (which may be owned |
| // by the iterator itself). |
| virtual const UrlPassagesEmbeddings* Next() = 0; |
| }; |
| |
| virtual ~VectorDatabase() = default; |
| |
| // Returns the expected number of dimensions for an embedding. |
| virtual size_t GetEmbeddingDimensions() const = 0; |
| |
| // Insert or update all embeddings for a URL's full set of passages. |
| // Returns true on success. |
| virtual bool AddUrlData(UrlPassagesEmbeddings url_passages_embeddings) = 0; |
| |
| // Create an iterator that steps through database items. |
| // Null may be returned if there are none. |
| virtual std::unique_ptr<UrlDataIterator> MakeUrlDataIterator( |
| std::optional<base::Time> time_range_start) = 0; |
| |
| // Searches the database for embeddings near given `query` and returns |
| // information about where they were found and how nearly the query matched. |
| SearchInfo FindNearest(std::optional<base::Time> time_range_start, |
| size_t count, |
| const Embedding& query, |
| base::RepeatingCallback<bool()> is_search_halted); |
| }; |
| |
| // This is an in-memory vector store that supports searching and saving to |
| // another persistent backing store. |
| class VectorDatabaseInMemory : public VectorDatabase { |
| public: |
| VectorDatabaseInMemory(); |
| ~VectorDatabaseInMemory() override; |
| |
| // Save this store's data to another given store. Most implementations don't |
| // need this, but it's useful for an in-memory store to work with a separate |
| // backing database on a worker sequence. |
| void SaveTo(VectorDatabase* database); |
| |
| // VectorDatabase: |
| size_t GetEmbeddingDimensions() const override; |
| bool AddUrlData(UrlPassagesEmbeddings url_passages_embeddings) override; |
| std::unique_ptr<UrlDataIterator> MakeUrlDataIterator( |
| std::optional<base::Time> time_range_start) override; |
| |
| private: |
| std::vector<UrlPassagesEmbeddings> data_; |
| }; |
| |
| } // namespace history_embeddings |
| |
| #endif // COMPONENTS_HISTORY_EMBEDDINGS_VECTOR_DATABASE_H_ |