blob: b3b994b33df6d5bd69772724f4484af924e8719f [file] [log] [blame]
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/contextual_tasks/contextual_tasks_context_service.h"
#include "base/test/metrics/histogram_tester.h"
#include "base/test/scoped_feature_list.h"
#include "base/test/test_future.h"
#include "chrome/browser/contextual_tasks/contextual_tasks_context_service_factory.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h"
#include "chrome/browser/page_content_annotations/page_content_extraction_service_factory.h"
#include "chrome/browser/passage_embeddings/page_embeddings_service.h"
#include "chrome/browser/passage_embeddings/page_embeddings_service_factory.h"
#include "chrome/browser/profiles/profile.h"
#include "chrome/browser/ui/browser_commands.h"
#include "chrome/test/base/in_process_browser_test.h"
#include "chrome/test/base/ui_test_utils.h"
#include "components/contextual_tasks/public/features.h"
#include "components/optimization_guide/proto/features/common_quality_data.pb.h"
#include "components/passage_embeddings/passage_embeddings_features.h"
#include "components/passage_embeddings/passage_embeddings_test_util.h"
#include "content/public/test/browser_test.h"
#include "content/public/test/browser_test_utils.h"
#include "net/dns/mock_host_resolver.h"
#include "testing/gmock/include/gmock/gmock.h"
namespace contextual_tasks {
using ::testing::_;
using ::testing::Return;
class FakeEmbedderMetadataProvider
: public passage_embeddings::EmbedderMetadataProvider {
public:
FakeEmbedderMetadataProvider() = default;
~FakeEmbedderMetadataProvider() override = default;
// passage_embeddings::EmbedderMetadataProvider:
void AddObserver(
passage_embeddings::EmbedderMetadataObserver* observer) override {
observer_list_.AddObserver(observer);
}
void RemoveObserver(
passage_embeddings::EmbedderMetadataObserver* observer) override {
observer_list_.RemoveObserver(observer);
}
void NotifyObservers() {
observer_list_.Notify(
&passage_embeddings::EmbedderMetadataObserver::EmbedderMetadataUpdated,
passage_embeddings::EmbedderMetadata(1, 768));
}
private:
base::ObserverList<passage_embeddings::EmbedderMetadataObserver>
observer_list_;
};
// TestEmbedder variant whose outcome is controlled by set_status(). While the
// status is kSuccess it defers to TestEmbedder; for any other status it runs
// the callback immediately with no embeddings and that status.
class FakeEmbedder : public passage_embeddings::TestEmbedder {
 public:
  FakeEmbedder() = default;
  ~FakeEmbedder() override = default;

  // passage_embeddings::TestEmbedder:
  passage_embeddings::Embedder::TaskId ComputePassagesEmbeddings(
      passage_embeddings::PassagePriority priority,
      std::vector<std::string> passages,
      ComputePassagesEmbeddingsCallback callback) override {
    const bool simulate_failure =
        status_ != passage_embeddings::ComputeEmbeddingsStatus::kSuccess;
    if (simulate_failure) {
      // Report the configured failure synchronously: the original passages,
      // an empty embeddings list, task id 0 and the failure status.
      std::move(callback).Run(passages, {}, 0, status_);
      return 0;
    }

    // Success path: let the base test embedder answer the callback.
    passage_embeddings::TestEmbedder::ComputePassagesEmbeddings(
        priority, passages, std::move(callback));
    return 0;
  }

  // Sets the status that later ComputePassagesEmbeddings() calls will report.
  void set_status(passage_embeddings::ComputeEmbeddingsStatus status) {
    status_ = status;
  }

 private:
  // Outcome reported by ComputePassagesEmbeddings(); starts as kSuccess.
  passage_embeddings::ComputeEmbeddingsStatus status_ =
      passage_embeddings::ComputeEmbeddingsStatus::kSuccess;
};
// gMock subclass of passage_embeddings::PageEmbeddingsService that lets tests
// script the result of GetEmbeddings().
class MockPageEmbeddingsService
    : public passage_embeddings::PageEmbeddingsService {
 public:
  // Forwards `page_content_extraction_service` to the PageEmbeddingsService
  // constructor. Marked `explicit` so a raw extraction-service pointer can
  // never be implicitly converted into a MockPageEmbeddingsService.
  explicit MockPageEmbeddingsService(
      page_content_annotations::PageContentExtractionService*
          page_content_extraction_service)
      : PageEmbeddingsService(page_content_extraction_service) {}
  ~MockPageEmbeddingsService() override = default;

  // Mocked so each test can decide which passage embeddings a tab reports.
  MOCK_METHOD(std::vector<passage_embeddings::PassageEmbedding>,
              GetEmbeddings,
              (content::WebContents * web_contents),
              (const override));
};
class ContextualTasksContextServiceTest : public InProcessBrowserTest {
public:
ContextualTasksContextServiceTest() { InitializeFeatureList(); }
~ContextualTasksContextServiceTest() override {
scoped_feature_list_.Reset();
}
virtual void InitializeFeatureList() {
scoped_feature_list_.InitWithFeaturesAndParameters(
{{kContextualTasksContext,
{{{"ContextualTasksContextOnlyUseTitles", "false"}}}},
{passage_embeddings::kPassageEmbedder, {}}},
/*disabled_features=*/{});
}
void SetUpOnMainThread() override {
host_resolver()->AddRule("*", "127.0.0.1");
InProcessBrowserTest::SetUpOnMainThread();
embedded_test_server()->ServeFilesFromSourceDirectory(
"chrome/test/data/optimization_guide");
ASSERT_TRUE(embedded_test_server()->Start());
}
void TearDown() override {
InProcessBrowserTest::TearDown();
}
void SetUpBrowserContextKeyedServices(
content::BrowserContext* browser_context) override {
passage_embeddings::PageEmbeddingsServiceFactory::GetInstance()
->SetTestingFactoryAndUse(
browser_context,
base::BindRepeating([](content::BrowserContext* browser_context)
-> std::unique_ptr<KeyedService> {
return std::make_unique<
testing::NiceMock<MockPageEmbeddingsService>>(
page_content_annotations::
PageContentExtractionServiceFactory::GetForProfile(
Profile::FromBrowserContext(browser_context)));
}));
ContextualTasksContextServiceFactory::GetInstance()
->SetTestingFactoryAndUse(
browser_context,
base::BindRepeating(
[](passage_embeddings::EmbedderMetadataProvider*
embedder_metadata_provider,
passage_embeddings::Embedder* embedder,
content::BrowserContext* context)
-> std::unique_ptr<KeyedService> {
Profile* profile = Profile::FromBrowserContext(context);
return std::make_unique<ContextualTasksContextService>(
profile,
passage_embeddings::PageEmbeddingsServiceFactory::
GetForProfile(profile),
embedder_metadata_provider, embedder,
OptimizationGuideKeyedServiceFactory::GetForProfile(
profile));
},
&embedder_metadata_provider_, &embedder_));
}
ContextualTasksContextService* service() {
return ContextualTasksContextServiceFactory::GetForProfile(
browser()->profile());
}
MockPageEmbeddingsService* page_embeddings_service() {
return static_cast<MockPageEmbeddingsService*>(
passage_embeddings::PageEmbeddingsServiceFactory::GetForProfile(
browser()->profile()));
}
void NotifyEmbedderMetadata() {
embedder_metadata_provider_.NotifyObservers();
}
void UpdateEmbedderStatus(
passage_embeddings::ComputeEmbeddingsStatus status) {
embedder_.set_status(status);
}
passage_embeddings::Embedding CreateFakeEmbedding(float value) {
constexpr size_t kMockPassageWordCount = 10;
passage_embeddings::Embedding embedding(std::vector<float>(
passage_embeddings::kEmbeddingsModelOutputSize, value));
embedding.Normalize();
embedding.SetPassageWordCount(kMockPassageWordCount);
return embedding;
}
void NavigateToValidURL() {
// Navigate to a valid URL.
content::WebContents* web_contents =
browser()->tab_strip_model()->GetActiveWebContents();
GURL url(embedded_test_server()->GetURL("a.test",
"/optimization_guide/hello.html"));
content::NavigateToURLBlockUntilNavigationsComplete(web_contents, url, 1);
}
protected:
base::test::ScopedFeatureList scoped_feature_list_;
private:
FakeEmbedderMetadataProvider embedder_metadata_provider_;
FakeEmbedder embedder_;
};
// Queries without ever calling NotifyEmbedderMetadata(). Expects an empty
// result, no tab-count or latency samples, and a single kEmbedderNotAvailable
// status sample.
IN_PROC_BROWSER_TEST_F(ContextualTasksContextServiceTest, NoEmbedder) {
  base::HistogramTester histograms;
  NavigateToValidURL();

  base::test::TestFuture<std::vector<content::WebContents*>> relevant_tabs;
  service()->GetRelevantTabsForQuery("some text", relevant_tabs.GetCallback());
  EXPECT_TRUE(relevant_tabs.Get().empty());

  histograms.ExpectTotalCount("ContextualTasks.Context.RelevantTabsCount", 0);
  histograms.ExpectTotalCount(
      "ContextualTasks.Context.ContextCalculationLatency", 0);
  histograms.ExpectUniqueSample(
      "ContextualTasks.Context.ContextDeterminationStatus",
      ContextDeterminationStatus::kEmbedderNotAvailable, 1);
}
// Embedder metadata is published, but the fake embedder is set to report
// kExecutionFailure. Expects an empty result, no tab-count or latency samples,
// and a single kQueryEmbeddingFailed status sample.
IN_PROC_BROWSER_TEST_F(ContextualTasksContextServiceTest, EmbedderFailed) {
  base::HistogramTester histograms;
  NavigateToValidURL();
  NotifyEmbedderMetadata();
  UpdateEmbedderStatus(
      passage_embeddings::ComputeEmbeddingsStatus::kExecutionFailure);

  base::test::TestFuture<std::vector<content::WebContents*>> relevant_tabs;
  service()->GetRelevantTabsForQuery("some text", relevant_tabs.GetCallback());
  EXPECT_TRUE(relevant_tabs.Get().empty());

  histograms.ExpectTotalCount("ContextualTasks.Context.RelevantTabsCount", 0);
  histograms.ExpectTotalCount(
      "ContextualTasks.Context.ContextCalculationLatency", 0);
  histograms.ExpectUniqueSample(
      "ContextualTasks.Context.ContextDeterminationStatus",
      ContextDeterminationStatus::kQueryEmbeddingFailed, 1);
}
// Embedder metadata is published and the embedder is left at its default
// status, but no GetEmbeddings() result is scripted on the mock. Expects an
// empty result, one tab-count sample of 0, one latency sample, and a single
// kSuccess status sample.
IN_PROC_BROWSER_TEST_F(ContextualTasksContextServiceTest,
                       SuccessQueryNoPageEmbeddings) {
  base::HistogramTester histograms;
  NavigateToValidURL();
  NotifyEmbedderMetadata();

  base::test::TestFuture<std::vector<content::WebContents*>> relevant_tabs;
  service()->GetRelevantTabsForQuery("some text", relevant_tabs.GetCallback());
  EXPECT_TRUE(relevant_tabs.Get().empty());

  histograms.ExpectUniqueSample("ContextualTasks.Context.RelevantTabsCount", 0,
                                1);
  histograms.ExpectTotalCount(
      "ContextualTasks.Context.ContextCalculationLatency", 1);
  histograms.ExpectUniqueSample(
      "ContextualTasks.Context.ContextDeterminationStatus",
      ContextDeterminationStatus::kSuccess, 1);
}
// The mock returns three page-content passages for the tab. Expects exactly
// one relevant tab, one tab-count sample of 1, one latency sample, and a
// single kSuccess status sample.
IN_PROC_BROWSER_TEST_F(ContextualTasksContextServiceTest, Success) {
  base::HistogramTester histograms;
  NavigateToValidURL();
  NotifyEmbedderMetadata();

  std::vector<passage_embeddings::PassageEmbedding> tab_embeddings = {
      // Not match.
      {std::make_pair("passage 1",
                      passage_embeddings::PassageType::kPageContent),
       CreateFakeEmbedding(0.1f)},
      // Match - active tab is added.
      {std::make_pair("passage 2",
                      passage_embeddings::PassageType::kPageContent),
       CreateFakeEmbedding(1.0f)},
      // Match - should be skipped.
      {std::make_pair("passage 3",
                      passage_embeddings::PassageType::kPageContent),
       CreateFakeEmbedding(1.0f)}};
  EXPECT_CALL(*page_embeddings_service(), GetEmbeddings(_))
      .WillOnce(Return(tab_embeddings));

  base::test::TestFuture<std::vector<content::WebContents*>> relevant_tabs;
  service()->GetRelevantTabsForQuery("some text", relevant_tabs.GetCallback());
  EXPECT_EQ(1u, relevant_tabs.Get().size());

  histograms.ExpectUniqueSample("ContextualTasks.Context.RelevantTabsCount", 1,
                                1);
  histograms.ExpectTotalCount(
      "ContextualTasks.Context.ContextCalculationLatency", 1);
  histograms.ExpectUniqueSample(
      "ContextualTasks.Context.ContextDeterminationStatus",
      ContextDeterminationStatus::kSuccess, 1);
}
// Runs the query without calling NavigateToValidURL() first. Expects
// GetEmbeddings() never to be called, an empty result, one tab-count sample
// of 0 and one latency sample.
IN_PROC_BROWSER_TEST_F(ContextualTasksContextServiceTest, SkipsNonHttp) {
  base::HistogramTester histograms;
  NotifyEmbedderMetadata();

  EXPECT_CALL(*page_embeddings_service(), GetEmbeddings(_)).Times(0);

  base::test::TestFuture<std::vector<content::WebContents*>> relevant_tabs;
  service()->GetRelevantTabsForQuery("some text", relevant_tabs.GetCallback());
  EXPECT_TRUE(relevant_tabs.Get().empty());

  histograms.ExpectUniqueSample("ContextualTasks.Context.RelevantTabsCount", 0,
                                1);
  histograms.ExpectTotalCount(
      "ContextualTasks.Context.ContextCalculationLatency", 1);
}
// Variant of ContextualTasksContextServiceTest whose InitializeFeatureList()
// sets the "ContextualTasksContextOnlyUseTitles" param to "true".
class ContextualTasksContextServiceTitlesOnlyTest
    : public ContextualTasksContextServiceTest {
 public:
  // Enables kContextualTasksContext with
  // "ContextualTasksContextOnlyUseTitles" = "true", plus
  // passage_embeddings::kPassageEmbedder; no features are disabled.
  void InitializeFeatureList() override {
    scoped_feature_list_.InitWithFeaturesAndParameters(
        {{kContextualTasksContext,
          {{{"ContextualTasksContextOnlyUseTitles", "true"}}}},
         {passage_embeddings::kPassageEmbedder, {}}},
        /*disabled_features=*/{});
  }
};
// Titles-only fixture: the mock returns two page-content passages and one
// title passage for the tab. Expects exactly one relevant tab.
IN_PROC_BROWSER_TEST_F(ContextualTasksContextServiceTitlesOnlyTest, Success) {
  NotifyEmbedderMetadata();
  NavigateToValidURL();

  std::vector<passage_embeddings::PassageEmbedding> tab_embeddings = {
      // Not match.
      {std::make_pair("passage 1",
                      passage_embeddings::PassageType::kPageContent),
       CreateFakeEmbedding(0.1f)},
      // Not added - page content skipped.
      {std::make_pair("passage 2",
                      passage_embeddings::PassageType::kPageContent),
       CreateFakeEmbedding(1.0f)},
      // Added - title passage matches.
      {std::make_pair("passage 3", passage_embeddings::PassageType::kTitle),
       CreateFakeEmbedding(1.0f)}};
  EXPECT_CALL(*page_embeddings_service(), GetEmbeddings(_))
      .WillOnce(Return(tab_embeddings));

  base::test::TestFuture<std::vector<content::WebContents*>> relevant_tabs;
  service()->GetRelevantTabsForQuery("some text", relevant_tabs.GetCallback());
  EXPECT_EQ(1u, relevant_tabs.Get().size());
}
} // namespace contextual_tasks