blob: 0f311fd6810a52ffcc789023d156fb9d7ac27c57 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/history_embeddings/sql_database.h"
#include <array>
#include "base/files/file_path.h"
#include "base/files/file_util.h"
#include "base/files/scoped_temp_dir.h"
#include "base/run_loop.h"
#include "base/test/bind.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "components/history_embeddings/history_embeddings_features.h"
#include "components/history_embeddings/proto/history_embeddings.pb.h"
#include "components/os_crypt/async/browser/test_utils.h"
#include "components/os_crypt/async/common/encryptor.h"
#include "components/passage_embeddings/passage_embeddings_types.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace history_embeddings {
using passage_embeddings::Embedding;
namespace {
// Embedder model version passed as the first field of the metadata given to
// `SqlDatabase::SetEmbedderMetadata()` throughout these tests.
constexpr int64_t kEmbeddingsVersion = 1;
// Number of floats in every test embedding; also passed as the second field of
// the embedder metadata so stored vectors match the declared size.
constexpr size_t kEmbeddingsSize = 768ul;
// Builds a test embedding of `kEmbeddingsSize` components, all initialized to
// 1.0f, and normalizes it before returning.
Embedding FakeEmbedding() {
  Embedding normalized_ones(std::vector<float>(kEmbeddingsSize, 1.0f));
  normalized_ones.Normalize();
  return normalized_ones;
}
} // namespace
// Test fixture for `SqlDatabase`. Each test runs against a unique temporary
// directory (created in SetUp, deleted in TearDown) and has access to a test
// OSCryptAsync instance used to obtain encryptors for
// `SqlDatabase::SetEmbedderMetadata()`.
class HistoryEmbeddingsSqlDatabaseTest : public testing::Test {
 public:
  HistoryEmbeddingsSqlDatabaseTest()
      : os_crypt_(os_crypt_async::GetTestOSCryptAsyncForTesting()) {}

  void SetUp() override {
    CHECK(history_dir_.CreateUniqueTempDir());
  }

  void TearDown() override {
    CHECK(history_dir_.Delete());
  }

  // Creates a new `SqlDatabase` rooted at the fixture's temporary directory,
  // configured from the current feature parameters. Embedder metadata is NOT
  // set here; tests call `SetEmbedderMetadata()` themselves.
  std::unique_ptr<SqlDatabase> MakeDatabase() {
    return std::make_unique<SqlDatabase>(
        history_dir_.GetPath(),
        GetFeatureParameters().erase_non_ascii_characters,
        GetFeatureParameters().delete_embeddings);
  }

  // Adds mock data for url_id = 1 tied to visit_id = 10, and url_id = 2 tied to
  // visit_id = 11. Each URL gets two passages and two embeddings.
  //
  // This helper uses ASSERT_* macros, which only return from this function on
  // failure. Callers that must stop when the setup fails should wrap the call
  // in ASSERT_NO_FATAL_FAILURE().
  void AddBasicMockData(SqlDatabase* sql_database) {
    {
      UrlData url_data_1(1, 10, base::Time::Now());
      url_data_1.passages.add_passages("fake passage 1");
      url_data_1.passages.add_passages("fake passage 2");
      url_data_1.embeddings.emplace_back(
          std::vector<float>(kEmbeddingsSize, 1.0f));
      url_data_1.embeddings.emplace_back(
          std::vector<float>(kEmbeddingsSize, 1.0f));
      ASSERT_TRUE(sql_database->AddUrlData(url_data_1));
    }
    {
      UrlData url_data_2(2, 11, base::Time::Now());
      url_data_2.passages.add_passages("fake passage 3");
      url_data_2.passages.add_passages("fake passage 4");
      url_data_2.embeddings.emplace_back(
          std::vector<float>(kEmbeddingsSize, 1.0f));
      url_data_2.embeddings.emplace_back(
          std::vector<float>(kEmbeddingsSize, 1.0f));
      ASSERT_TRUE(sql_database->AddUrlData(url_data_2));
    }
    // Sanity-check that both rows are readable before the test proceeds.
    ASSERT_TRUE(sql_database->GetPassages(1));
    ASSERT_TRUE(sql_database->GetPassages(2));
    ASSERT_EQ(GetEmbeddingCount(sql_database), 2U);
  }

  // Returns the number of `UrlData` entries yielded by iterating the database
  // from start to end (one per stored URL row, not one per embedding vector).
  // Records a test failure and returns 0 if no iterator could be created,
  // rather than dereferencing a null iterator.
  size_t GetEmbeddingCount(SqlDatabase* sql_database) {
    auto iterator = sql_database->MakeUrlDataIterator({});
    EXPECT_TRUE(iterator);
    if (!iterator) {
      return 0;
    }
    size_t count = 0;
    while (iterator->Next()) {
      count++;
    }
    return count;
  }

 protected:
  // Requests an `Encryptor` from the test OSCryptAsync instance through a
  // callback and returns the value captured by the TestFuture.
  os_crypt_async::Encryptor GetEncryptorInstance() {
    base::test::TestFuture<os_crypt_async::Encryptor> future;
    os_crypt_->GetInstance(future.GetCallback());
    return future.Take();
  }

  base::test::TaskEnvironment env_;
  std::unique_ptr<os_crypt_async::OSCryptAsync> os_crypt_;
  // Temporary directory that holds the database file for the current test.
  base::ScopedTempDir history_dir_;
};
// Passages written through one database instance must be readable from a
// second instance opened on the same directory, and the database file must
// remain on disk after the instance is destroyed.
TEST_F(HistoryEmbeddingsSqlDatabaseTest, WriteCloseAndThenReadPassages) {
  std::unique_ptr<SqlDatabase> database = MakeDatabase();
  database->SetEmbedderMetadata({kEmbeddingsVersion, kEmbeddingsSize},
                                GetEncryptorInstance());

  // Store a single URL carrying two passages with one embedding each.
  UrlData written(1, 1, base::Time::Now());
  written.passages.add_passages("fake passage 1");
  written.passages.add_passages("fake passage 2");
  written.embeddings.push_back(FakeEmbedding());
  written.embeddings.push_back(FakeEmbedding());
  EXPECT_TRUE(database->AddUrlData(written));

  // Drop the first instance and open a fresh one over the same directory.
  database.reset();
  database = MakeDatabase();
  database->SetEmbedderMetadata({kEmbeddingsVersion, kEmbeddingsSize},
                                GetEncryptorInstance());

  // The stored passages come back in the order they were added.
  auto stored_passages = database->GetPassages(1);
  ASSERT_TRUE(stored_passages);
  ASSERT_EQ(stored_passages->passages_size(), 2);
  EXPECT_EQ(stored_passages->passages(0), "fake passage 1");
  EXPECT_EQ(stored_passages->passages(1), "fake passage 2");

  database.reset();
  const base::FilePath database_file =
      history_dir_.GetPath().Append(kHistoryEmbeddingsName);
  EXPECT_TRUE(base::PathExists(database_file))
      << "DB file is still there after destruction.";
}
// Full `UrlData` rows (passages plus embeddings) written through one database
// instance must be returned unchanged, and in insertion order, by an iterator
// created on a second instance opened on the same directory.
TEST_F(HistoryEmbeddingsSqlDatabaseTest, WriteCloseAndThenReadUrlData) {
  auto sql_database = MakeDatabase();
  sql_database->SetEmbedderMetadata({kEmbeddingsVersion, kEmbeddingsSize},
                                    GetEncryptorInstance());

  // Write embeddings.
  auto url_datas = std::to_array<UrlData>({
      UrlData(1, 1, base::Time::Now()),
      UrlData(2, 2, base::Time::Now()),
  });
  url_datas[0].passages.add_passages("data 0 passage 0");
  url_datas[0].embeddings.push_back(FakeEmbedding());
  url_datas[1].passages.add_passages("data 1 passage 0");
  url_datas[1].passages.add_passages("data 1 passage 1");
  url_datas[1].embeddings.push_back(FakeEmbedding());
  url_datas[1].embeddings.push_back(FakeEmbedding());
  EXPECT_TRUE(sql_database->AddUrlData(url_datas[0]));
  EXPECT_TRUE(sql_database->AddUrlData(url_datas[1]));

  // Reset and reload.
  sql_database.reset();
  sql_database = MakeDatabase();
  sql_database->SetEmbedderMetadata({kEmbeddingsVersion, kEmbeddingsSize},
                                    GetEncryptorInstance());

  // Read embeddings.
  {
    // Block scope destructs iterator before database is closed.
    std::unique_ptr<VectorDatabase::UrlDataIterator> iterator =
        sql_database->MakeUrlDataIterator({});
    // Fatal assertion: the iterator is dereferenced below.
    ASSERT_TRUE(iterator);
    for (const UrlData& url_data : url_datas) {
      const UrlData* read_url_data = iterator->Next();
      // Fatal assertion: a null result must not be dereferenced for the
      // equality comparison.
      ASSERT_TRUE(read_url_data);
      EXPECT_EQ(*read_url_data, url_data);
    }
    // Nothing beyond the rows written above should be yielded.
    EXPECT_FALSE(iterator->Next());
  }

  sql_database.reset();
  EXPECT_TRUE(
      base::PathExists(history_dir_.GetPath().Append(kHistoryEmbeddingsName)))
      << "DB file is still there after destruction.";
}
// Stores three URLs whose visit times are `now`, `now + 1min` and
// `now + 2min`, then checks how many results `FindNearest()` returns as the
// time-range start passed as its first argument moves forward.
TEST_F(HistoryEmbeddingsSqlDatabaseTest, TimeRangeNarrowsSearchResult) {
  auto sql_database = MakeDatabase();
  sql_database->SetEmbedderMetadata({kEmbeddingsVersion, kEmbeddingsSize},
                                    GetEncryptorInstance());

  // Write embeddings.
  const base::Time now = base::Time::Now();
  for (size_t i = 0; i < 3; i++) {
    UrlData url_data(i + 1, i + 1, now + base::Minutes(i));
    for (size_t j = 0; j < 3; j++) {
      url_data.passages.add_passages("fake passage");
      url_data.embeddings.push_back(FakeEmbedding());
    }
    // The expectations below are meaningless unless every row was stored.
    ASSERT_TRUE(sql_database->AddUrlData(url_data));
  }

  Embedding query = FakeEmbedding();
  SearchParams search_params;

  // An ordinary search with full results:
  {
    std::vector<ScoredUrl> scored_urls =
        sql_database
            ->FindNearest({}, 3, search_params, query,
                          base::BindRepeating([]() { return false; }))
            .scored_urls;
    EXPECT_EQ(scored_urls.size(), 3u);
  }

  // Narrowed searches with time range.
  // Starting exactly at the first visit time still includes all three.
  {
    std::vector<ScoredUrl> scored_urls =
        sql_database
            ->FindNearest(now, 3, search_params, query,
                          base::BindRepeating([]() { return false; }))
            .scored_urls;
    EXPECT_EQ(scored_urls.size(), 3u);
  }
  // Starting between the first and second visits drops the first.
  {
    std::vector<ScoredUrl> scored_urls =
        sql_database
            ->FindNearest(now + base::Seconds(30), 3, search_params, query,
                          base::BindRepeating([]() { return false; }))
            .scored_urls;
    EXPECT_EQ(scored_urls.size(), 2u);
  }
  // Starting between the second and third visits leaves only the third.
  {
    std::vector<ScoredUrl> scored_urls =
        sql_database
            ->FindNearest(now + base::Seconds(90), 3, search_params, query,
                          base::BindRepeating([]() { return false; }))
            .scored_urls;
    EXPECT_EQ(scored_urls.size(), 1u);
  }
  // Starting exactly at the third visit time is expected to include it.
  {
    std::vector<ScoredUrl> scored_urls =
        sql_database
            ->FindNearest(now + base::Minutes(2), 3, search_params, query,
                          base::BindRepeating([]() { return false; }))
            .scored_urls;
    EXPECT_EQ(scored_urls.size(), 1u);
  }
  // Starting one second after the last visit yields nothing.
  {
    std::vector<ScoredUrl> scored_urls =
        sql_database
            ->FindNearest(now + base::Seconds(121), 3, search_params, query,
                          base::BindRepeating([]() { return false; }))
            .scored_urls;
    EXPECT_EQ(scored_urls.size(), 0u);
  }
}
// Adding data a second time for the same url_id (with a different visit_id and
// one extra passage) must leave the newer passages stored under that url_id.
TEST_F(HistoryEmbeddingsSqlDatabaseTest, InsertOrReplacePassages) {
  std::unique_ptr<SqlDatabase> database = MakeDatabase();
  database->SetEmbedderMetadata({kEmbeddingsVersion, kEmbeddingsSize},
                                GetEncryptorInstance());

  // First write: url_id = 1, visit_id = 1, two passages.
  UrlData entry(1, 1, base::Time::Now());
  entry.passages.add_passages("fake passage 1");
  entry.passages.add_passages("fake passage 2");
  entry.embeddings.push_back(FakeEmbedding());
  entry.embeddings.push_back(FakeEmbedding());
  EXPECT_TRUE(database->AddUrlData(entry));

  // Second write: same url_id, visit_id = 2, a third passage appended.
  entry.visit_id = 2;
  entry.passages.add_passages("fake passage 3");
  entry.embeddings.push_back(FakeEmbedding());
  EXPECT_TRUE(database->AddUrlData(entry));

  // Verify that the new one has replaced the old one.
  auto stored_passages = database->GetPassages(1);
  ASSERT_TRUE(stored_passages);
  ASSERT_EQ(stored_passages->passages_size(), 3);
  EXPECT_EQ(stored_passages->passages(0), "fake passage 1");
  EXPECT_EQ(stored_passages->passages(1), "fake passage 2");
  EXPECT_EQ(stored_passages->passages(2), "fake passage 3");

  // Nothing was ever stored under url_id = 2.
  EXPECT_FALSE(database->GetPassages(2));
}
// An iterator that is still alive when its database is destroyed must end
// iteration gracefully instead of crashing.
TEST_F(HistoryEmbeddingsSqlDatabaseTest, IteratorMaySafelyOutliveDatabase) {
  auto sql_database = MakeDatabase();
  sql_database->SetEmbedderMetadata({kEmbeddingsVersion, kEmbeddingsSize},
                                    GetEncryptorInstance());
  // The helper uses ASSERT_* internally; stop here if setup failed.
  ASSERT_NO_FATAL_FAILURE(AddBasicMockData(sql_database.get()));

  // Without database reset, iteration reads data.
  {
    std::unique_ptr<VectorDatabase::UrlDataIterator> iterator =
        sql_database->MakeUrlDataIterator({});
    // Fatal assertion: the iterator is dereferenced on the next line.
    ASSERT_TRUE(iterator);
    EXPECT_TRUE(iterator->Next());
  }

  // With database reset, iteration gracefully ends.
  {
    std::unique_ptr<VectorDatabase::UrlDataIterator> iterator =
        sql_database->MakeUrlDataIterator({});
    ASSERT_TRUE(iterator);

    // Reset database while iterator is still in scope.
    sql_database.reset();

    // Iterator access with dead database doesn't crash, just ends iteration.
    EXPECT_FALSE(iterator->Next());
  }
}
// Deleting by url_id removes only the matching row; deleting an unknown
// url_id reports success and leaves the stored data untouched.
TEST_F(HistoryEmbeddingsSqlDatabaseTest, DeleteDataForUrlId) {
  auto sql_database = MakeDatabase();
  sql_database->SetEmbedderMetadata({kEmbeddingsVersion, kEmbeddingsSize},
                                    GetEncryptorInstance());
  // The helper uses ASSERT_* internally; stop here if setup failed.
  ASSERT_NO_FATAL_FAILURE(AddBasicMockData(sql_database.get()));

  EXPECT_TRUE(sql_database->DeleteDataForUrlId(3))
      << "Deleting a non-existing url_id = 3 should return true but do "
         "nothing.";
  EXPECT_TRUE(sql_database->GetPassages(1));
  EXPECT_TRUE(sql_database->GetPassages(2));
  EXPECT_EQ(GetEmbeddingCount(sql_database.get()), 2U);

  EXPECT_TRUE(sql_database->DeleteDataForUrlId(2))
      << "Succeeds. url_id = 2 does exist.";
  EXPECT_TRUE(sql_database->GetPassages(1));
  EXPECT_FALSE(sql_database->GetPassages(2));
  EXPECT_EQ(GetEmbeddingCount(sql_database.get()), 1U);
}
// Deleting by visit_id removes only the row tied to that visit (url_id = 2 is
// tied to visit_id = 11 in the mock data); deleting an unknown visit_id
// reports success and leaves the stored data untouched.
TEST_F(HistoryEmbeddingsSqlDatabaseTest, DeleteDataForVisitId) {
  auto sql_database = MakeDatabase();
  sql_database->SetEmbedderMetadata({kEmbeddingsVersion, kEmbeddingsSize},
                                    GetEncryptorInstance());
  // The helper uses ASSERT_* internally; stop here if setup failed.
  ASSERT_NO_FATAL_FAILURE(AddBasicMockData(sql_database.get()));

  EXPECT_TRUE(sql_database->DeleteDataForVisitId(40))
      << "Deleting a non-existing visit_id = 40 should return true but do "
         "nothing.";
  EXPECT_TRUE(sql_database->GetPassages(1));
  EXPECT_TRUE(sql_database->GetPassages(2));
  EXPECT_EQ(GetEmbeddingCount(sql_database.get()), 2U);

  EXPECT_TRUE(sql_database->DeleteDataForVisitId(11))
      << "Succeeds. visit_id = 11 does exist.";
  EXPECT_TRUE(sql_database->GetPassages(1));
  EXPECT_FALSE(sql_database->GetPassages(2));
  EXPECT_EQ(GetEmbeddingCount(sql_database.get()), 1U);
}
// `DeleteAllData(true, true)` leaves neither passages nor iterable rows.
TEST_F(HistoryEmbeddingsSqlDatabaseTest, DeleteAllData) {
  auto sql_database = MakeDatabase();
  sql_database->SetEmbedderMetadata({kEmbeddingsVersion, kEmbeddingsSize},
                                    GetEncryptorInstance());
  // The helper uses ASSERT_* internally; stop here if setup failed, otherwise
  // the "nothing left" expectations below could pass vacuously.
  ASSERT_NO_FATAL_FAILURE(AddBasicMockData(sql_database.get()));

  EXPECT_TRUE(sql_database->DeleteAllData(true, true));
  EXPECT_FALSE(sql_database->GetPassages(1));
  EXPECT_FALSE(sql_database->GetPassages(2));
  EXPECT_EQ(GetEmbeddingCount(sql_database.get()), 0U);
}
// Reads and writes require embedder metadata, but `DeleteAllData()` is
// expected to succeed on a database that never received any.
TEST_F(HistoryEmbeddingsSqlDatabaseTest, DeleteDataWithoutEmbedderMetadata) {
  UrlData single_passage_data(1, 10, base::Time::Now());
  single_passage_data.passages.add_passages("fake passage 1");
  single_passage_data.embeddings.emplace_back(
      std::vector<float>(kEmbeddingsSize, 1.0f));

  // Phase 1: populate the database and close it with the data resident.
  {
    std::unique_ptr<SqlDatabase> database = MakeDatabase();

    // Before metadata is supplied the database can't fully initialize, so the
    // write is rejected.
    ASSERT_FALSE(database->AddUrlData(single_passage_data));

    // Once metadata is supplied the same write goes through.
    database->SetEmbedderMetadata({kEmbeddingsVersion, kEmbeddingsSize},
                                  GetEncryptorInstance());
    ASSERT_TRUE(database->AddUrlData(single_passage_data));

    // Deliberately no deletion here; it would succeed as normal.
    EXPECT_TRUE(database->GetPassages(1));
  }

  // Phase 2: reopen and confirm reads only work after metadata is provided.
  {
    std::unique_ptr<SqlDatabase> database = MakeDatabase();
    EXPECT_FALSE(database->GetPassages(1));

    database->SetEmbedderMetadata({kEmbeddingsVersion, kEmbeddingsSize},
                                  GetEncryptorInstance());
    EXPECT_TRUE(database->GetPassages(1));
    EXPECT_EQ(GetEmbeddingCount(database.get()), 1U);
    // Deletion would again work as normal at this point.
  }

  // Phase 3: reopen with no metadata and delete everything anyway.
  {
    std::unique_ptr<SqlDatabase> database = MakeDatabase();
    EXPECT_FALSE(database->GetPassages(1));

    // Deletion succeeds even though metadata was never provided.
    EXPECT_TRUE(database->DeleteAllData(true, true));

    // Supplying metadata afterwards reveals that nothing is left to read.
    database->SetEmbedderMetadata({kEmbeddingsVersion, kEmbeddingsSize},
                                  GetEncryptorInstance());
    EXPECT_FALSE(database->GetPassages(1));
    EXPECT_EQ(GetEmbeddingCount(database.get()), 0U);
  }
}
// `GetUrlData()` returns the stored row for a known url_id, still returns the
// passages after `DeleteAllData(false, true)` has left the row without
// embeddings, and returns no value for an unknown url_id.
TEST_F(HistoryEmbeddingsSqlDatabaseTest, GetUrlData) {
  auto sql_database = MakeDatabase();
  sql_database->SetEmbedderMetadata({kEmbeddingsVersion, kEmbeddingsSize},
                                    GetEncryptorInstance());
  {
    UrlData url_data(1, 1, base::Time::Now());
    url_data.passages.add_passages("fake passage 1");
    url_data.passages.add_passages("fake passage 2");
    url_data.passages.add_passages("fake passage 3");
    url_data.embeddings.push_back(FakeEmbedding());
    url_data.embeddings.push_back(FakeEmbedding());
    url_data.embeddings.push_back(FakeEmbedding());
    EXPECT_TRUE(sql_database->AddUrlData(url_data));
  }
  {
    auto data = sql_database->GetUrlData(1);
    // Fatal assertion: the value is accessed below.
    ASSERT_TRUE(data.has_value());
    EXPECT_EQ(data->url_id, 1);
    EXPECT_EQ(data->visit_id, 1);
    // Fatal assertion: the passages are indexed below.
    ASSERT_EQ(data->passages.passages_size(), 3);
    EXPECT_EQ(data->embeddings.size(), 3u);
    EXPECT_EQ(data->passages.passages(0), "fake passage 1");
    EXPECT_EQ(data->passages.passages(1), "fake passage 2");
    EXPECT_EQ(data->passages.passages(2), "fake passage 3");
  }

  // It's also possible to have passages but no embeddings, for example
  // during a database rebuild.
  EXPECT_TRUE(sql_database->DeleteAllData(false, true));
  {
    auto data = sql_database->GetUrlData(1);
    ASSERT_TRUE(data.has_value());
    EXPECT_EQ(data->url_id, 1);
    EXPECT_EQ(data->visit_id, 1);
    ASSERT_EQ(data->passages.passages_size(), 3);
    EXPECT_EQ(data->embeddings.size(), 0u);
    EXPECT_EQ(data->passages.passages(0), "fake passage 1");
    EXPECT_EQ(data->passages.passages(1), "fake passage 2");
    EXPECT_EQ(data->passages.passages(2), "fake passage 3");
  }

  // Absent `url_id` returns std::nullopt.
  EXPECT_FALSE(sql_database->GetUrlData(2).has_value());
}
// A stored row whose embedding count does not match its passage count must be
// skipped during iteration, and both the skipped and yielded counts must be
// recorded in histograms.
TEST_F(HistoryEmbeddingsSqlDatabaseTest, IterationSkipsAndReportsMismatches) {
  auto sql_database = MakeDatabase();
  sql_database->SetEmbedderMetadata({kEmbeddingsVersion, kEmbeddingsSize},
                                    GetEncryptorInstance());

  // Write embeddings.
  auto url_datas = std::to_array<UrlData>({
      UrlData(1, 1, base::Time::Now()),
      UrlData(2, 2, base::Time::Now()),
  });
  url_datas[0].passages.add_passages("data 0 passage 0");
  url_datas[0].embeddings.push_back(FakeEmbedding());
  url_datas[1].passages.add_passages("data 1 passage 0");
  url_datas[1].passages.add_passages("data 1 passage 1");
  url_datas[1].embeddings.push_back(FakeEmbedding());
  url_datas[1].embeddings.push_back(FakeEmbedding());
  // Add one too many embeddings to trigger a mismatch.
  url_datas[1].embeddings.push_back(FakeEmbedding());
  EXPECT_TRUE(sql_database->AddUrlData(url_datas[0]));
  // The mismatched row is written through the testing-only entry point.
  EXPECT_TRUE(sql_database->AddAnyUrlDataForTesting(url_datas[1]));

  // Created after the writes so only samples from the iteration are observed.
  base::HistogramTester histogram_tester;
  int observed = 0;
  {
    // Iterate through stored data once.
    std::unique_ptr<VectorDatabase::UrlDataIterator> iterator =
        sql_database->MakeUrlDataIterator({});
    // Fatal assertion: the iterator is dereferenced in the loop below.
    ASSERT_TRUE(iterator);
    while (iterator->Next()) {
      observed++;
    }
  }

  // Only the well-formed row is yielded; the mismatched one is counted as
  // skipped.
  EXPECT_EQ(observed, 1);
  histogram_tester.ExpectUniqueSample(
      "History.Embeddings.DatabaseIterationSkippedMismatches", 1, 1);
  histogram_tester.ExpectUniqueSample(
      "History.Embeddings.DatabaseIterationYielded", 1, 1);
}
// After the database is closed and reopened, data for a visit dated 100 days
// in the past is no longer returned, while data for a current visit is.
TEST_F(HistoryEmbeddingsSqlDatabaseTest, OldVisitsAreExpired) {
  std::unique_ptr<SqlDatabase> database = MakeDatabase();
  database->SetEmbedderMetadata({kEmbeddingsVersion, kEmbeddingsSize},
                                GetEncryptorInstance());

  // Store one entry for an old visit and one for a fresh visit.
  UrlData old_visit(/*url_id=*/1, /*visit_id=*/1,
                    base::Time::Now() - base::Days(100));
  UrlData new_visit(/*url_id=*/2, /*visit_id=*/2, base::Time::Now());
  old_visit.passages.add_passages("data 0 passage 0");
  old_visit.embeddings.push_back(FakeEmbedding());
  new_visit.passages.add_passages("data 1 passage 0");
  new_visit.embeddings.push_back(FakeEmbedding());
  EXPECT_TRUE(database->AddUrlData(old_visit));
  EXPECT_TRUE(database->AddUrlData(new_visit));

  // Close the database and open it again.
  database.reset();
  database = MakeDatabase();
  database->SetEmbedderMetadata({kEmbeddingsVersion, kEmbeddingsSize},
                                GetEncryptorInstance());

  // Only the fresh visit is found; the old one has expired.
  EXPECT_FALSE(database->GetUrlData(/*url_id=*/1).has_value());
  EXPECT_TRUE(database->GetUrlData(/*url_id=*/2).has_value());

  database.reset();
  const base::FilePath database_file =
      history_dir_.GetPath().Append(kHistoryEmbeddingsName);
  EXPECT_TRUE(base::PathExists(database_file));
}
} // namespace history_embeddings