blob: 9455c032ac71da7ab432c354fe704db3d3a58d4b [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef COMPONENTS_HISTORY_EMBEDDINGS_VECTOR_DATABASE_H_
#define COMPONENTS_HISTORY_EMBEDDINGS_VECTOR_DATABASE_H_
#include <memory>
#include <optional>
#include <string>
#include <vector>
#include "base/functional/callback.h"
#include "base/time/time.h"
#include "components/history/core/browser/history_types.h"
#include "components/history_embeddings/proto/history_embeddings.pb.h"
#include "components/keyed_service/core/keyed_service.h"
namespace history_embeddings {
// The passages recorded for one visit to a URL, together with the identifiers
// and timestamp of that visit.
struct UrlPassages {
  UrlPassages(history::URLID url_id,
              history::VisitID visit_id,
              base::Time visit_time);
  UrlPassages(const UrlPassages&);
  UrlPassages(UrlPassages&&);
  ~UrlPassages();

  UrlPassages& operator=(const UrlPassages&);
  UrlPassages& operator=(UrlPassages&&);

  bool operator==(const UrlPassages&) const;

  // Which URL, which visit to it, and when that visit happened.
  history::URLID url_id;
  history::VisitID visit_id;
  base::Time visit_time;

  // The passage collection for this visit, kept in its proto representation.
  proto::PassagesValue passages;
};
// An embedding vector of float values, optionally tagged with the word count
// of the passage it represents.
class Embedding {
 public:
  Embedding();
  explicit Embedding(std::vector<float> data);
  Embedding(std::vector<float> data, size_t passage_word_count);
  Embedding(const Embedding&);
  Embedding(Embedding&&);
  ~Embedding();

  Embedding& operator=(const Embedding&);
  Embedding& operator=(Embedding&&);

  bool operator==(const Embedding&) const;

  // How many elements the underlying data vector holds.
  size_t Dimensions() const;

  // The vector's length (magnitude).
  float Magnitude() const;

  // Rescales this vector in place to unit length.
  void Normalize();

  // Produces a similarity measure between this embedding and
  // `other_embedding`. `other_passage` is presumably the passage text that
  // `other_embedding` was computed from -- confirm against the implementation.
  float ScoreWith(const std::string& other_passage,
                  const Embedding& other_embedding) const;

  // Read-only access to the raw values; used when persisting to storage.
  const std::vector<float>& GetData() const { return data_; }

  // Word count of the source passage. Search uses it to filter out passages
  // with too few words.
  size_t GetPassageWordCount() const { return passage_word_count_; }
  void SetPassageWordCount(size_t passage_word_count) {
    passage_word_count_ = passage_word_count;
  }

 private:
  std::vector<float> data_;
  // Zero unless set otherwise.
  size_t passage_word_count_{0};
};
// The embeddings computed for the passages of one URL visit, along with the
// identifiers and timestamp of that visit.
struct UrlEmbeddings {
  UrlEmbeddings();
  UrlEmbeddings(history::URLID url_id,
                history::VisitID visit_id,
                base::Time visit_time);
  // Builds from the identifying fields of `url_passages`. Presumably
  // `embeddings` is left for the caller to fill -- confirm in the .cc.
  explicit UrlEmbeddings(const UrlPassages& url_passages);
  UrlEmbeddings(const UrlEmbeddings&);
  UrlEmbeddings(UrlEmbeddings&&);
  ~UrlEmbeddings();

  UrlEmbeddings& operator=(const UrlEmbeddings&);
  UrlEmbeddings& operator=(UrlEmbeddings&&);

  bool operator==(const UrlEmbeddings&) const;

  // Returns the score of the stored embedding closest to `query`. `passages`
  // lines up with `embeddings` index-for-index and is consulted because some
  // entries must be skipped. `search_minimum_word_count` is presumably the
  // word-count threshold below which a passage is ignored.
  float BestScoreWith(const Embedding& query,
                      const proto::PassagesValue& passages,
                      size_t search_minimum_word_count) const;

  // Which URL, which visit to it, and when that visit happened.
  history::URLID url_id;
  history::VisitID visit_id;
  base::Time visit_time;

  // One embedding per passage, in passage order.
  std::vector<Embedding> embeddings;
};
// A single search hit: one URL visit paired with how well it matched.
struct ScoredUrl {
  ScoredUrl(history::URLID url_id,
            history::VisitID visit_id,
            base::Time visit_time,
            float score);
  ScoredUrl(const ScoredUrl&);
  ScoredUrl(ScoredUrl&&);
  ~ScoredUrl();

  ScoredUrl& operator=(const ScoredUrl&);
  ScoredUrl& operator=(ScoredUrl&&);

  // Identity and timing of the URL visit that was found.
  history::URLID url_id;
  history::VisitID visit_id;
  base::Time visit_time;

  // Similarity between the query and the found data.
  float score;
};
// The result of a vector database search, plus statistics about how much data
// was examined and whether the search ran to completion.
struct SearchInfo {
  SearchInfo();
  SearchInfo(SearchInfo&&);
  // Declaring the move constructor suppresses the implicit copy operations
  // AND the implicit move assignment, which previously left this type
  // impossible to assign. Default move assignment explicitly so results can be
  // reassigned into existing objects. Copying intentionally stays disabled.
  SearchInfo& operator=(SearchInfo&&) = default;
  ~SearchInfo();

  // Result of the search, the best scored URLs.
  std::vector<ScoredUrl> scored_urls;

  // The number of URLs searched to find this result.
  size_t searched_url_count = 0u;

  // The number of embeddings searched to find this result.
  size_t searched_embedding_count = 0u;

  // Whether the search completed without interruption. Starting a new search
  // may cause a search to halt, and in that case this member will be false.
  bool completed = false;
};
// Pairs the passages of one URL visit with the embeddings computed for them.
//
// NOTE(review): only copy operations are declared, so the compiler generates
// no move constructor/assignment and `std::move()` of an instance (e.g. into
// `VectorDatabase::AddUrlData()`) silently performs a full copy of the
// passages proto and every embedding vector. Consider declaring move
// operations and defining them out-of-line in the .cc.
struct UrlPassagesEmbeddings {
  UrlPassagesEmbeddings(history::URLID url_id,
                        history::VisitID visit_id,
                        base::Time visit_time);
  UrlPassagesEmbeddings(const UrlPassagesEmbeddings&);

  UrlPassagesEmbeddings& operator=(const UrlPassagesEmbeddings&);

  bool operator==(const UrlPassagesEmbeddings&) const;

  // The visit's passages.
  UrlPassages url_passages;
  // The visit's embeddings; presumably aligned index-for-index with the
  // passages in `url_passages`.
  UrlEmbeddings url_embeddings;
};
// Abstract interface over embedding storage. This base class decouples storage
// classes and inverts the dependency so that a vector database can work with a
// SQLite database, simple in-memory storage, flat files, or whatever kinds of
// storage will work efficiently. Subclasses provide storage and iteration;
// the search entry point `FindNearest()` is non-virtual and shared by all of
// them (presumably built on `MakeUrlDataIterator()` -- confirm in the .cc).
class VectorDatabase {
 public:
  // Cursor over the stored per-URL passages and embeddings.
  struct UrlDataIterator {
    virtual ~UrlDataIterator() = default;

    // Returns nullptr if none remain; otherwise advances the iterator and
    // returns a pointer to the next instance. Because the pointee may be owned
    // by the iterator itself, callers should not assume it stays valid after
    // the next `Next()` call or after the iterator is destroyed.
    virtual const UrlPassagesEmbeddings* Next() = 0;
  };

  virtual ~VectorDatabase() = default;

  // Returns the expected number of dimensions for an embedding.
  virtual size_t GetEmbeddingDimensions() const = 0;

  // Inserts or updates all embeddings for a URL's full set of passages.
  // Returns true on success.
  virtual bool AddUrlData(UrlPassagesEmbeddings url_passages_embeddings) = 0;

  // Creates an iterator that steps through database items. A set
  // `time_range_start` presumably restricts iteration to items visited at or
  // after that time (confirm per implementation). Null may be returned if
  // there are no items.
  virtual std::unique_ptr<UrlDataIterator> MakeUrlDataIterator(
      std::optional<base::Time> time_range_start) = 0;

  // Searches the database for embeddings near the given `query` and returns
  // information about where they were found and how nearly the query matched.
  //   `time_range_start`: optional lower bound on visit time for candidates.
  //   `count`: presumably the maximum number of `scored_urls` to return.
  //   `is_search_halted`: callback the search may poll; when it reports true
  //       the search stops early and the result's `completed` is false.
  // NOTE(review): parameter semantics above are inferred from names and the
  // `SearchInfo` docs; confirm against the implementation.
  SearchInfo FindNearest(std::optional<base::Time> time_range_start,
                         size_t count,
                         const Embedding& query,
                         base::RepeatingCallback<bool()> is_search_halted);
};
// A vector store held entirely in memory. It is searchable through the
// VectorDatabase interface and can write its contents out to another,
// persistent, backing store.
class VectorDatabaseInMemory : public VectorDatabase {
 public:
  VectorDatabaseInMemory();
  ~VectorDatabaseInMemory() override;

  // VectorDatabase:
  size_t GetEmbeddingDimensions() const override;
  bool AddUrlData(UrlPassagesEmbeddings url_passages_embeddings) override;
  std::unique_ptr<UrlDataIterator> MakeUrlDataIterator(
      std::optional<base::Time> time_range_start) override;

  // Writes this store's data into `database`. Most implementations have no
  // need for this, but it lets an in-memory store cooperate with a separate
  // backing database that lives on a worker sequence.
  void SaveTo(VectorDatabase* database);

 private:
  // Backing storage for everything added to this store.
  std::vector<UrlPassagesEmbeddings> data_;
};
} // namespace history_embeddings
#endif // COMPONENTS_HISTORY_EMBEDDINGS_VECTOR_DATABASE_H_