| // Copyright 2024 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "components/history_embeddings/sql_database.h" |
| |
| #include <algorithm> |
| |
| #include "base/check.h" |
| #include "base/files/file_path.h" |
| #include "base/logging.h" |
| #include "base/metrics/histogram_functions.h" |
| #include "base/sequence_checker.h" |
| #include "base/strings/string_util.h" |
| #include "components/history/core/browser/history_backend.h" |
| #include "components/history_embeddings/history_embeddings_features.h" |
| #include "components/history_embeddings/passages_util.h" |
| #include "components/history_embeddings/proto/history_embeddings.pb.h" |
| #include "components/os_crypt/async/common/encryptor.h" |
| #include "sql/init_status.h" |
| #include "sql/meta_table.h" |
| #include "sql/transaction.h" |
| |
| namespace history_embeddings { |
| |
// Schema versions. Until migration code exists, the lowest supported version
// and the current version are expected to be bumped in lockstep.
constexpr int kLowestSupportedDatabaseVersion{1};
constexpr int kCurrentDatabaseVersion{1};

// Version stamp for the stored embeddings data. Rolling it forces every
// embedding to be recomputed, which is useful after changing how source
// passages are preprocessed. Prefer rolling this over
// `kCurrentDatabaseVersion`: source passages are preserved and only the
// embeddings table is cleared.
constexpr int kEmbeddingsDataVersion{1};
| |
| namespace { |
| |
| [[nodiscard]] bool InitSchema(sql::Database& db) { |
| static constexpr char kSqlCreateTablePassages[] = |
| "CREATE TABLE IF NOT EXISTS passages(" |
| // The URL associated with these passages, as stored in History. |
| "url_id INTEGER PRIMARY KEY NOT NULL," |
| // The Visit from which these passages were extracted. This is to allow |
| // us to properly expire and delete passage data when the associated |
| // visit is deleted. |
| "visit_id INTEGER NOT NULL," |
| // Store the associated visit time too, so we have a way to scrub expired |
| // entries if we ever miss deletion events from History. This can happen |
| // if Chrome shuts down unexpectedly or if History DB is razed. |
| "visit_time INTEGER NOT NULL," |
| // An opaque encrypted blob of passages. |
| "passages_blob BLOB NOT NULL);"; |
| if (!db.Execute(kSqlCreateTablePassages)) { |
| return false; |
| } |
| |
| // Create an index over visit_id so we can quickly delete passages associated |
| // with visits that get deleted. |
| if (!db.Execute("CREATE INDEX IF NOT EXISTS index_passages_visit_id ON " |
| "passages(visit_id)")) { |
| return false; |
| } |
| |
| static constexpr char kSqlCreateTableEmbeddings[] = |
| "CREATE TABLE IF NOT EXISTS embeddings(" |
| // The URL associated with these embeddings, as stored in History. |
| "url_id INTEGER PRIMARY KEY NOT NULL," |
| // The Visit from which these embeddings were computed. This is to allow |
| // us to properly expire and delete embedding data when the associated |
| // visit is deleted. |
| "visit_id INTEGER NOT NULL," |
| // Store the associated visit time too, so we have a way to scrub expired |
| // entries if we ever miss deletion events from History. This can happen |
| // if Chrome shuts down unexpectedly or if History DB is razed. |
| "visit_time INTEGER NOT NULL," |
| // A serialized proto::EmbeddingsValue message containing all embedding |
| // vectors from this URL/visit source. |
| "embeddings_blob BLOB NOT NULL);"; |
| if (!db.Execute(kSqlCreateTableEmbeddings)) { |
| return false; |
| } |
| |
| // Create an index over visit_id so we can quickly delete embeddings |
| // associated with visits that get deleted. |
| if (!db.Execute("CREATE INDEX IF NOT EXISTS index_embeddings_visit_id ON " |
| "embeddings(visit_id)")) { |
| return false; |
| } |
| |
| return true; |
| } |
| |
| } // namespace |
| |
| SqlDatabase::SqlDatabase(const base::FilePath& storage_dir, |
| bool erase_non_ascii_characters, |
| bool delete_embeddings) |
| : storage_dir_(storage_dir), |
| erase_non_ascii_characters_(erase_non_ascii_characters), |
| delete_embeddings_(delete_embeddings), |
| db_(/*tag=*/"HistoryEmbeddings"), |
| weak_ptr_factory_(this) {} |
| |
| SqlDatabase::~SqlDatabase() = default; |
| |
| void SqlDatabase::SetEmbedderMetadata( |
| passage_embeddings::EmbedderMetadata embedder_metadata, |
| os_crypt_async::Encryptor encryptor) { |
| embedder_metadata_ = embedder_metadata; |
| CHECK(!encryptor_.has_value()) << "Cannot call SetEmbedderMetadata twice."; |
| encryptor_.emplace(std::move(encryptor)); |
| } |
| |
| bool SqlDatabase::LazyInit(bool force_init_for_deletion) { |
| // Only use `force_init_for_deletion` if normal full initialization fails |
| // and a deletion request is being applied. Never use if normal init succeeds. |
| CHECK(!force_init_for_deletion || !db_init_status_.has_value()); |
| |
| if (!db_init_status_.has_value()) { |
| // Don't attempt initialization until ready, unless forced for the |
| // data deletion flow. |
| if (!embedder_metadata_ && !force_init_for_deletion) { |
| return false; |
| } |
| |
| db_init_status_ = InitInternal(storage_dir_, force_init_for_deletion); |
| base::UmaHistogramBoolean("History.Embeddings.DatabaseInitialized", |
| *db_init_status_ == sql::InitStatus::INIT_OK); |
| } |
| |
| return *db_init_status_ == sql::InitStatus::INIT_OK; |
| } |
| |
| sql::InitStatus SqlDatabase::InitInternal(const base::FilePath& storage_dir, |
| bool force_init_for_deletion) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| // base::Unretained is okay because `this` owns and outlives `db_`. |
| db_.set_error_callback(base::BindRepeating( |
| &SqlDatabase::DatabaseErrorCallback, base::Unretained(this))); |
| |
| base::FilePath db_file_path = storage_dir.Append(kHistoryEmbeddingsName); |
| if (!db_.Open(db_file_path)) { |
| return sql::InitStatus::INIT_FAILURE; |
| } |
| |
| // Raze old incompatible databases. |
| if (sql::MetaTable::RazeIfIncompatible(&db_, kLowestSupportedDatabaseVersion, |
| kCurrentDatabaseVersion) == |
| sql::RazeIfIncompatibleResult::kFailed) { |
| return sql::InitStatus::INIT_FAILURE; |
| } |
| |
| // Wrap initialization in a transaction to make it atomic. |
| sql::Transaction transaction(&db_); |
| if (!transaction.Begin()) { |
| return sql::InitStatus::INIT_FAILURE; |
| } |
| |
| // Initialize the current version meta table. Safest to leave the compatible |
| // version equal to the current version - unless we know we're making a very |
| // safe backwards-compatible schema change. |
| sql::MetaTable meta_table; |
| if (!meta_table.Init(&db_, kCurrentDatabaseVersion, |
| /*compatible_version=*/kCurrentDatabaseVersion)) { |
| return sql::InitStatus::INIT_FAILURE; |
| } |
| if (meta_table.GetCompatibleVersionNumber() > kCurrentDatabaseVersion) { |
| LOG(ERROR) << "HistoryEmbeddings database is too new."; |
| return sql::INIT_TOO_NEW; |
| } |
| |
| if (!InitSchema(db_)) { |
| return sql::INIT_FAILURE; |
| } |
| |
| if (!transaction.Commit()) { |
| return sql::INIT_FAILURE; |
| } |
| |
| // Delete passages and embeddings for visits that are beyond the data |
| // retention window. The history system automatically expires data while |
| // Chrome is running, but it's possible to miss events or start Chrome after |
| // some down time, so this prevents long term accidental retention edge cases. |
| DeleteExpiredData(/*expiration_time=*/base::Time::Now() - |
| base::Days(history::HistoryBackend::kExpireDaysThreshold)); |
| |
| // It's possible to get here without `embedder_metadata_` if forcing for |
| // data deletion. In that case, don't check or change meta table. |
| if (embedder_metadata_.has_value()) { |
| constexpr char kKeyModelVersion[] = "model_version"; |
| constexpr char kKeyEmbeddingsDataVersion[] = "embeddings_data_version"; |
| |
| int model_version = 0; |
| meta_table.GetValue(kKeyModelVersion, &model_version); |
| |
| bool delete_embeddings = |
| model_version != embedder_metadata_->model_version || |
| delete_embeddings_; |
| |
| // TODO(crbug.com/375502129): Remove this guard and the related guard below |
| // for more complete data version handling. |
| if (erase_non_ascii_characters_) { |
| int embeddings_data_version = 0; |
| meta_table.GetValue(kKeyEmbeddingsDataVersion, &embeddings_data_version); |
| delete_embeddings |= embeddings_data_version != kEmbeddingsDataVersion; |
| } |
| |
| if (delete_embeddings) { |
| // Old version embeddings can't be used with new model. Simply delete them |
| // all and set new version. Passages can be used for reconstruction later. |
| constexpr char kSqlDeleteFromEmbeddings[] = "DELETE FROM embeddings;"; |
| if (!db_.Execute(kSqlDeleteFromEmbeddings) || |
| !meta_table.SetValue(kKeyModelVersion, |
| embedder_metadata_->model_version)) { |
| return sql::InitStatus::INIT_FAILURE; |
| } |
| // Only write the meta table with this first data version change if |
| // doing so will result in the embeddings being rebuilt with the |
| // non-ASCII character changes. |
| // TODO(crbug.com/375502129): See above TODO comment; remove this guard. |
| if (erase_non_ascii_characters_ && |
| !meta_table.SetValue(kKeyEmbeddingsDataVersion, |
| kEmbeddingsDataVersion)) { |
| return sql::InitStatus::INIT_FAILURE; |
| } |
| } |
| } |
| |
| return sql::InitStatus::INIT_OK; |
| } |
| |
| void SqlDatabase::Close() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| db_.Close(); |
| db_.reset_error_callback(); |
| db_init_status_.reset(); |
| } |
| |
| // Gets the passages associated with `url_id`. Returns nullopt if there's |
| // nothing available. |
| std::optional<proto::PassagesValue> SqlDatabase::GetPassages( |
| history::URLID url_id) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| if (!LazyInit()) { |
| return std::nullopt; |
| } |
| |
| constexpr char kSqlSelectPassages[] = |
| "SELECT passages_blob FROM passages WHERE url_id = ?"; |
| DCHECK(db_.IsSQLValid(kSqlSelectPassages)); |
| sql::Statement statement( |
| db_.GetCachedStatement(SQL_FROM_HERE, kSqlSelectPassages)); |
| statement.BindInt64(0, url_id); |
| |
| if (statement.Step()) { |
| return PassagesBlobToProto(statement.ColumnBlob(0), *encryptor_); |
| } |
| |
| return std::nullopt; |
| } |
| |
| std::optional<UrlData> SqlDatabase::GetUrlData(history::URLID url_id) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| if (!LazyInit()) { |
| return {}; |
| } |
| |
| history::VisitID visit_id = 0; |
| base::Time visit_time; |
| std::optional<proto::PassagesValue> passages; |
| { |
| constexpr char kSqlSelectVisitIdAndPassages[] = |
| "SELECT visit_id, visit_time, passages_blob FROM passages WHERE url_id " |
| "= ?"; |
| DCHECK(db_.IsSQLValid(kSqlSelectVisitIdAndPassages)); |
| sql::Statement statement( |
| db_.GetCachedStatement(SQL_FROM_HERE, kSqlSelectVisitIdAndPassages)); |
| statement.BindInt64(0, url_id); |
| |
| if (statement.Step()) { |
| visit_id = statement.ColumnInt64(0); |
| visit_time = statement.ColumnTime(1); |
| passages = PassagesBlobToProto(statement.ColumnBlob(2), *encryptor_); |
| } |
| } |
| if (!passages.has_value() || visit_id == 0) { |
| return {}; |
| } |
| |
| UrlData url_data(url_id, visit_id, visit_time); |
| url_data.passages = std::move(passages.value()); |
| bool loaded_missized_embedding = false; |
| { |
| constexpr char kSqlSelectEmbeddings[] = |
| "SELECT embeddings_blob FROM embeddings " |
| "WHERE visit_id = ?"; |
| DCHECK(db_.IsSQLValid(kSqlSelectEmbeddings)); |
| sql::Statement statement( |
| db_.GetCachedStatement(SQL_FROM_HERE, kSqlSelectEmbeddings)); |
| statement.BindInt64(0, visit_id); |
| |
| if (statement.Step()) { |
| base::span<const uint8_t> blob = statement.ColumnBlob(0); |
| |
| proto::EmbeddingsValue value; |
| if (!value.ParseFromArray(blob.data(), blob.size())) { |
| return url_data; |
| } |
| for (const proto::EmbeddingVector& vector : value.vectors()) { |
| url_data.embeddings.emplace_back( |
| std::vector(vector.floats().cbegin(), vector.floats().cend()), |
| vector.passage_word_count()); |
| if (url_data.embeddings.back().Dimensions() != |
| GetEmbeddingDimensions()) { |
| url_data.embeddings.clear(); |
| loaded_missized_embedding = true; |
| break; |
| } |
| } |
| } |
| } |
| base::UmaHistogramBoolean("History.Embeddings.LoadedMissizedEmbedding", |
| loaded_missized_embedding); |
| return url_data; |
| } |
| |
| std::vector<UrlData> SqlDatabase::GetUrlDataInTimeRange(base::Time from_time, |
| base::Time to_time, |
| size_t limit, |
| size_t offset) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| if (!LazyInit()) { |
| return {}; |
| } |
| |
| constexpr char kSqlSelectOrderedPassagesAndEmbeddingsWithinTimeRange[] = |
| "SELECT passages.url_id, passages.visit_id, passages.visit_time, " |
| "passages.passages_blob, embeddings.embeddings_blob " |
| "FROM passages " |
| "INNER JOIN embeddings ON passages.url_id = embeddings.url_id " |
| "WHERE passages.visit_time >= ? AND passages.visit_time < ? " |
| "ORDER BY passages.visit_time LIMIT ? OFFSET ?"; |
| DCHECK(db_.IsSQLValid(kSqlSelectOrderedPassagesAndEmbeddingsWithinTimeRange)); |
| sql::Statement statement(db_.GetCachedStatement( |
| SQL_FROM_HERE, kSqlSelectOrderedPassagesAndEmbeddingsWithinTimeRange)); |
| statement.BindTime(0, from_time); |
| statement.BindTime(1, to_time); |
| statement.BindInt(2, static_cast<int>(limit)); |
| statement.BindInt(3, static_cast<int>(offset)); |
| |
| std::vector<UrlData> url_datas; |
| while (statement.Step()) { |
| history::URLID url_id = statement.ColumnInt64(0); |
| history::VisitID visit_id = statement.ColumnInt64(1); |
| base::Time visit_time = statement.ColumnTime(2); |
| UrlData& url_data = url_datas.emplace_back(url_id, visit_id, visit_time); |
| |
| std::optional<proto::PassagesValue> passages = |
| PassagesBlobToProto(statement.ColumnBlob(3), *encryptor_); |
| if (passages.has_value()) { |
| url_data.passages = std::move(passages.value()); |
| } |
| |
| proto::EmbeddingsValue value; |
| base::span<const uint8_t> embeddings_blob = statement.ColumnBlob(4); |
| if (value.ParseFromArray(embeddings_blob.data(), embeddings_blob.size())) { |
| for (const proto::EmbeddingVector& vector : value.vectors()) { |
| url_data.embeddings.emplace_back( |
| std::vector(vector.floats().cbegin(), vector.floats().cend()), |
| vector.passage_word_count()); |
| } |
| } |
| } |
| |
| return url_datas; |
| } |
| |
| std::vector<UrlData> SqlDatabase::GetUrlPassagesWithoutEmbeddings() { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| if (!LazyInit()) { |
| return {}; |
| } |
| |
| constexpr char kSqlSelectPassagesWithoutEmbeddings[] = |
| "SELECT url_id, visit_id, visit_time, passages_blob " |
| "FROM passages WHERE url_id NOT IN (SELECT url_id FROM embeddings);"; |
| DCHECK(db_.IsSQLValid(kSqlSelectPassagesWithoutEmbeddings)); |
| sql::Statement statement(db_.GetCachedStatement( |
| SQL_FROM_HERE, kSqlSelectPassagesWithoutEmbeddings)); |
| |
| std::vector<UrlData> all_url_passages; |
| while (statement.Step()) { |
| std::optional<proto::PassagesValue> passages_value = |
| PassagesBlobToProto(statement.ColumnBlob(3), *encryptor_); |
| if (passages_value.has_value()) { |
| UrlData& url_passages = all_url_passages.emplace_back( |
| statement.ColumnInt64(0), statement.ColumnInt64(1), |
| statement.ColumnTime(2)); |
| url_passages.passages = std::move(passages_value.value()); |
| } |
| } |
| return all_url_passages; |
| } |
| |
| bool SqlDatabase::AddAnyUrlDataForTesting(UrlData url_data) { |
| if (!LazyInit()) { |
| return false; |
| } |
| return InsertOrReplacePassages(url_data) && |
| InsertOrReplaceEmbeddings(url_data); |
| } |
| |
| size_t SqlDatabase::GetEmbeddingDimensions() const { |
| return embedder_metadata_->output_size; |
| } |
| |
| bool SqlDatabase::AddUrlData(UrlData url_data) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| if (!LazyInit()) { |
| return false; |
| } |
| |
| CHECK(static_cast<size_t>(url_data.passages.passages_size()) == |
| url_data.embeddings.size()); |
| sql::Transaction transaction(&db_); |
| return transaction.Begin() && InsertOrReplacePassages(url_data) && |
| InsertOrReplaceEmbeddings(url_data) && transaction.Commit(); |
| } |
| |
// Query used when iterating over all stored data: every passages row joined
// with the embeddings row that has the same url_id.
constexpr char kSqlSelectPassagesAndEmbeddings[] =
    "SELECT passages.url_id, passages.visit_id, passages.visit_time, "
    "passages.passages_blob, embeddings.embeddings_blob "
    "FROM passages INNER JOIN embeddings "
    "ON passages.url_id = embeddings.url_id";

// Same join, restricted to rows whose visit time is at or after the single
// bound parameter. Only a lower bound is applied.
constexpr char kSqlSelectPassagesAndEmbeddingsWithinTimeRange[] =
    "SELECT passages.url_id, passages.visit_id, passages.visit_time, "
    "passages.passages_blob, embeddings.embeddings_blob "
    "FROM passages INNER JOIN embeddings "
    "ON passages.url_id = embeddings.url_id "
    "WHERE passages.visit_time >= ?";
| |
| std::unique_ptr<VectorDatabase::UrlDataIterator> |
| SqlDatabase::MakeUrlDataIterator(std::optional<base::Time> time_range_start) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| if (!LazyInit()) { |
| return nullptr; |
| } |
| |
| DCHECK(db_.IsSQLValid(kSqlSelectPassagesAndEmbeddings)); |
| DCHECK(db_.IsSQLValid(kSqlSelectPassagesAndEmbeddingsWithinTimeRange)); |
| |
| struct RowDataIterator : public UrlDataIterator { |
| explicit RowDataIterator(base::WeakPtr<SqlDatabase> sql_database, |
| std::optional<base::Time> time_range_start) |
| : sql_database(sql_database), data(0, 0, base::Time()) { |
| CHECK(!sql_database->iteration_statement_); |
| if (time_range_start.has_value()) { |
| sql_database->iteration_statement_ = std::make_unique<sql::Statement>( |
| sql_database->db_.GetCachedStatement( |
| SQL_FROM_HERE, kSqlSelectPassagesAndEmbeddingsWithinTimeRange)); |
| sql_database->iteration_statement_->BindTime(0, |
| time_range_start.value()); |
| } else { |
| sql_database->iteration_statement_ = std::make_unique<sql::Statement>( |
| sql_database->db_.GetCachedStatement( |
| SQL_FROM_HERE, kSqlSelectPassagesAndEmbeddings)); |
| } |
| } |
| ~RowDataIterator() override { |
| if (sql_database) { |
| sql_database->iteration_statement_.reset(); |
| } |
| base::UmaHistogramCounts1000( |
| "History.Embeddings.DatabaseIterationSkippedPassages", |
| skipped_passages); |
| base::UmaHistogramCounts1000( |
| "History.Embeddings.DatabaseIterationSkippedEmbeddings", |
| skipped_embeddings); |
| base::UmaHistogramCounts1000( |
| "History.Embeddings.DatabaseIterationSkippedMismatches", |
| skipped_mismatches); |
| base::UmaHistogramCounts1000( |
| "History.Embeddings.DatabaseIterationSkippedMissizedEmbeddings", |
| skipped_missized); |
| base::UmaHistogramCounts10000( |
| "History.Embeddings.DatabaseIterationYielded", yielded); |
| } |
| |
| const UrlData* Next() override { |
| if (!sql_database) { |
| return nullptr; |
| } |
| sql::Statement* statement = sql_database->iteration_statement_.get(); |
| CHECK(statement); |
| // Don't expect perfect data; step until we find valid data. |
| while (statement->Step()) { |
| data = UrlData(/*url_id=*/statement->ColumnInt64(0), |
| /*visit_id=*/statement->ColumnInt64(1), |
| /*visit_time=*/statement->ColumnTime(2)); |
| // Passages |
| std::optional<proto::PassagesValue> passages_value = |
| PassagesBlobToProto(statement->ColumnBlob(3), |
| *sql_database->encryptor_); |
| if (!passages_value.has_value()) { |
| skipped_passages++; |
| continue; |
| } |
| data.passages = std::move(passages_value.value()); |
| |
| // Embeddings |
| base::span<const uint8_t> blob = statement->ColumnBlob(4); |
| proto::EmbeddingsValue value; |
| if (!value.ParseFromArray(blob.data(), blob.size())) { |
| skipped_embeddings++; |
| continue; |
| } |
| for (const proto::EmbeddingVector& vector : value.vectors()) { |
| data.embeddings.emplace_back( |
| std::vector(vector.floats().cbegin(), vector.floats().cend()), |
| vector.passage_word_count()); |
| } |
| const size_t expected_dimensions = |
| sql_database->GetEmbeddingDimensions(); |
| if (std::ranges::any_of( |
| data.embeddings, |
| [=](const passage_embeddings::Embedding& embedding) { |
| return embedding.Dimensions() != expected_dimensions; |
| })) { |
| skipped_missized++; |
| continue; |
| } |
| |
| // Confirm embeddings and passages are 1:1. |
| if (data.embeddings.empty() || |
| data.embeddings.size() != |
| static_cast<size_t>(data.passages.passages_size())) { |
| skipped_mismatches++; |
| continue; |
| } |
| |
| yielded++; |
| return &data; |
| } |
| return nullptr; |
| } |
| |
| base::WeakPtr<SqlDatabase> sql_database; |
| UrlData data; |
| // Keep stats on any data loading failures, and report histogram in dtor. |
| int skipped_passages = 0; |
| int skipped_embeddings = 0; |
| int skipped_mismatches = 0; |
| int skipped_missized = 0; |
| int yielded = 0; |
| }; |
| |
| return std::make_unique<RowDataIterator>(weak_ptr_factory_.GetWeakPtr(), |
| time_range_start); |
| } |
| |
| bool SqlDatabase::DeleteDataForUrlId(history::URLID url_id) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| bool close = false; |
| if (!LazyInit()) { |
| // Database isn't fully initialized. Attempt to force open it just for |
| // deletion, and then close so it can be initialized normally later. |
| if (!LazyInit(true)) { |
| return false; |
| } |
| close = true; |
| } |
| |
| bool delete_passages_success = false; |
| { |
| constexpr char kSqlDeleteFromPassagesByUrl[] = |
| "DELETE FROM passages WHERE url_id=?"; |
| DCHECK(db_.IsSQLValid(kSqlDeleteFromPassagesByUrl)); |
| sql::Statement statement( |
| db_.GetCachedStatement(SQL_FROM_HERE, kSqlDeleteFromPassagesByUrl)); |
| statement.BindInt64(0, url_id); |
| delete_passages_success = statement.Run(); |
| } |
| bool delete_embeddings_success = false; |
| { |
| constexpr char kSqlDeleteFromEmbeddingsByUrl[] = |
| "DELETE FROM embeddings WHERE url_id=?"; |
| DCHECK(db_.IsSQLValid(kSqlDeleteFromEmbeddingsByUrl)); |
| sql::Statement statement( |
| db_.GetCachedStatement(SQL_FROM_HERE, kSqlDeleteFromEmbeddingsByUrl)); |
| statement.BindInt64(0, url_id); |
| delete_embeddings_success = statement.Run(); |
| } |
| |
| if (close) { |
| Close(); |
| } |
| |
| return delete_passages_success && delete_embeddings_success; |
| } |
| |
| bool SqlDatabase::DeleteDataForVisitId(history::VisitID visit_id) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| bool close = false; |
| if (!LazyInit()) { |
| // Database isn't fully initialized. Attempt to force open it just for |
| // deletion, and then close so it can be initialized normally later. |
| if (!LazyInit(true)) { |
| return false; |
| } |
| close = true; |
| } |
| |
| bool delete_passages_success = false; |
| { |
| constexpr char kSqlDeleteFromPassagesByVisit[] = |
| "DELETE FROM passages WHERE visit_id=?"; |
| DCHECK(db_.IsSQLValid(kSqlDeleteFromPassagesByVisit)); |
| sql::Statement statement( |
| db_.GetCachedStatement(SQL_FROM_HERE, kSqlDeleteFromPassagesByVisit)); |
| statement.BindInt64(0, visit_id); |
| delete_passages_success = statement.Run(); |
| } |
| bool delete_embeddings_success = false; |
| { |
| constexpr char kSqlDeleteFromEmbeddingsByVisit[] = |
| "DELETE FROM embeddings WHERE visit_id=?"; |
| DCHECK(db_.IsSQLValid(kSqlDeleteFromEmbeddingsByVisit)); |
| sql::Statement statement( |
| db_.GetCachedStatement(SQL_FROM_HERE, kSqlDeleteFromEmbeddingsByVisit)); |
| statement.BindInt64(0, visit_id); |
| delete_embeddings_success = statement.Run(); |
| } |
| |
| if (close) { |
| Close(); |
| } |
| |
| return delete_passages_success && delete_embeddings_success; |
| } |
| |
| bool SqlDatabase::DeleteAllData(bool delete_passages, bool delete_embeddings) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| bool close = false; |
| if (!LazyInit()) { |
| // Database isn't fully initialized. Attempt to force open it just for |
| // deletion, and then close so it can be initialized normally later. |
| if (!LazyInit(true)) { |
| return false; |
| } |
| close = true; |
| } |
| |
| bool delete_passages_success = |
| !delete_passages || db_.Execute("DELETE FROM passages;"); |
| bool delete_embeddings_success = |
| !delete_embeddings || db_.Execute("DELETE FROM embeddings;"); |
| |
| if (close) { |
| Close(); |
| } |
| |
| return delete_passages_success && delete_embeddings_success; |
| } |
| |
| void SqlDatabase::DatabaseErrorCallback(int extended_error, |
| sql::Statement* statement) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| // TODO(crbug.com/325524013): Handle razing the database on catastrophic |
| // error. |
| |
| // The default handling is to assert on debug and to ignore on release. |
| // This is because database errors happen in the wild due to faulty hardware, |
| // or are sometimes transitory, and we want Chrome to carry on when possible. |
| if (!sql::Database::IsExpectedSqliteError(extended_error)) { |
| DLOG(FATAL) << db_.GetErrorMessage(); |
| } |
| } |
| |
| void SqlDatabase::DeleteExpiredData(base::Time expiration_time) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| constexpr char kSqlDeleteExpiredPassages[] = |
| "DELETE FROM passages WHERE visit_time < ?;"; |
| constexpr char kSqlDeleteExpiredEmbeddings[] = |
| "DELETE FROM embeddings WHERE visit_time < ?;"; |
| DCHECK(db_.IsSQLValid(kSqlDeleteExpiredPassages)); |
| DCHECK(db_.IsSQLValid(kSqlDeleteExpiredEmbeddings)); |
| |
| sql::Statement expire_passages( |
| db_.GetUniqueStatement(kSqlDeleteExpiredPassages)); |
| expire_passages.BindTime(0, expiration_time); |
| expire_passages.Run(); |
| |
| sql::Statement expire_embeddings( |
| db_.GetUniqueStatement(kSqlDeleteExpiredEmbeddings)); |
| expire_embeddings.BindTime(0, expiration_time); |
| expire_embeddings.Run(); |
| } |
| |
| bool SqlDatabase::InsertOrReplacePassages(const UrlData& url_passages) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| constexpr char kSqlInsertOrReplacePassages[] = |
| "INSERT OR REPLACE INTO passages " |
| "(url_id, visit_id, visit_time, passages_blob) " |
| "VALUES (?,?,?,?)"; |
| DCHECK(db_.IsSQLValid(kSqlInsertOrReplacePassages)); |
| sql::Statement statement( |
| db_.GetCachedStatement(SQL_FROM_HERE, kSqlInsertOrReplacePassages)); |
| statement.BindInt64(0, url_passages.url_id); |
| statement.BindInt64(1, url_passages.visit_id); |
| statement.BindTime(2, url_passages.visit_time); |
| |
| std::vector<uint8_t> blob = |
| PassagesProtoToBlob(url_passages.passages, *encryptor_); |
| if (blob.empty()) { |
| return false; |
| } |
| statement.BindBlob(3, std::move(blob)); |
| bool result = statement.Run(); |
| |
| if (result) { |
| size_t ascii_passages_count = 0; |
| size_t non_ascii_passages_count = 0; |
| for (const std::string& passage : url_passages.passages.passages()) { |
| if (base::IsStringASCII(passage)) { |
| ascii_passages_count++; |
| } else { |
| non_ascii_passages_count++; |
| } |
| } |
| base::UmaHistogramCounts100( |
| "History.Embeddings.DatabaseStoredAsciiPassages", ascii_passages_count); |
| base::UmaHistogramCounts100( |
| "History.Embeddings.DatabaseStoredNonAsciiPassages", |
| non_ascii_passages_count); |
| base::UmaHistogramPercentage( |
| "History.Embeddings.DatabaseStoredNonAsciiPassageRatio", |
| 100 * non_ascii_passages_count / |
| (ascii_passages_count + non_ascii_passages_count)); |
| } |
| |
| return result; |
| } |
| |
| bool SqlDatabase::InsertOrReplaceEmbeddings(const UrlData& url_embeddings) { |
| DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); |
| |
| if (url_embeddings.embeddings.size() == 0) { |
| return false; |
| } |
| |
| constexpr char kSqlInsertOrReplaceEmbeddings[] = |
| "INSERT OR REPLACE INTO embeddings " |
| "(url_id, visit_id, visit_time, embeddings_blob) " |
| "VALUES (?,?,?,?)"; |
| DCHECK(db_.IsSQLValid(kSqlInsertOrReplaceEmbeddings)); |
| sql::Statement statement( |
| db_.GetCachedStatement(SQL_FROM_HERE, kSqlInsertOrReplaceEmbeddings)); |
| statement.BindInt64(0, url_embeddings.url_id); |
| statement.BindInt64(1, url_embeddings.visit_id); |
| statement.BindTime(2, url_embeddings.visit_time); |
| |
| proto::EmbeddingsValue value; |
| for (const passage_embeddings::Embedding& embedding : |
| url_embeddings.embeddings) { |
| CHECK_EQ(GetEmbeddingDimensions(), embedding.Dimensions()); |
| proto::EmbeddingVector* vector = value.add_vectors(); |
| for (float f : embedding.GetData()) { |
| vector->add_floats(f); |
| } |
| vector->set_passage_word_count(embedding.GetPassageWordCount()); |
| } |
| statement.BindBlob(3, value.SerializeAsString()); |
| |
| return statement.Run(); |
| } |
| |
| } // namespace history_embeddings |