blob: 38078d7015e9471407d081f2ec1d9d3e54c73bd4 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "chrome/browser/ai/ai_manager_keyed_service.h"

#include <memory>
#include <optional>

#include "base/command_line.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/run_loop.h"
#include "base/supports_user_data.h"
#include "base/test/mock_callback.h"
#include "base/test/scoped_command_line.h"
#include "chrome/browser/ai/ai_manager_keyed_service_factory.h"
#include "chrome/browser/ai/ai_text_session.h"
#include "chrome/browser/optimization_guide/mock_optimization_guide_keyed_service.h"
#include "chrome/browser/optimization_guide/optimization_guide_keyed_service_factory.h"
#include "chrome/test/base/chrome_render_view_host_test_harness.h"
#include "components/keyed_service/core/keyed_service.h"
#include "components/optimization_guide/core/mock_optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_model_executor.h"
#include "components/optimization_guide/core/optimization_guide_switches.h"
#include "mojo/public/cpp/bindings/remote.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom-shared.h"
#include "third_party/blink/public/mojom/ai/ai_manager.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_text_session.mojom.h"
#include "third_party/blink/public/mojom/ai/ai_text_session_info.mojom.h"
using optimization_guide::MockSession;
using optimization_guide::MockSessionWrapper;
using testing::_;
using testing::An;
using testing::AtMost;
using testing::Invoke;
using testing::NiceMock;
using testing::Return;
namespace {

// Minimal concrete `base::SupportsUserData`. Tests hand it to
// `AIManagerKeyedService::AddReceiver()` as the host object.
class MockSupportsUserData final : public base::SupportsUserData {};

// Token limits that the fixture's `MockSession` reports from
// `GetTokenLimits()`. The limits live in a function-local static, so the
// returned reference stays valid for the lifetime of the process.
const optimization_guide::TokenLimits& GetFakeTokenLimits() {
  static const optimization_guide::TokenLimits kFakeTokenLimits = {
      .max_tokens = 4096,
      .max_context_tokens = 2048,
      .max_execute_tokens = 1024,
      .max_output_tokens = 1024,
  };
  return kFakeTokenLimits;
}

}  // namespace
class AIManagerKeyedServiceTest : public ChromeRenderViewHostTestHarness {
public:
void SetUp() override {
ChromeRenderViewHostTestHarness::SetUp();
SetUpOptimizationGuide();
}
void TearDown() override {
mock_optimization_guide_keyed_service_ = nullptr;
ChromeRenderViewHostTestHarness::TearDown();
}
protected:
MockSupportsUserData* mock_host() { return &mock_host_; }
testing::NiceMock<MockOptimizationGuideKeyedService>* MockService() {
return mock_optimization_guide_keyed_service_;
}
private:
void SetUpOptimizationGuide() {
mock_optimization_guide_keyed_service_ =
static_cast<NiceMock<MockOptimizationGuideKeyedService>*>(
OptimizationGuideKeyedServiceFactory::GetInstance()
->SetTestingFactoryAndUse(
profile(),
base::BindRepeating([](content::BrowserContext* context)
-> std::unique_ptr<KeyedService> {
return std::make_unique<
NiceMock<MockOptimizationGuideKeyedService>>();
})));
ON_CALL(*mock_optimization_guide_keyed_service_, StartSession(_, _))
.WillByDefault(
[&] { return std::make_unique<MockSessionWrapper>(&session_); });
ON_CALL(session_, GetTokenLimits()).WillByDefault(GetFakeTokenLimits);
}
raw_ptr<testing::NiceMock<MockOptimizationGuideKeyedService>>
mock_optimization_guide_keyed_service_;
testing::NiceMock<MockSession> session_;
MockSupportsUserData mock_host_;
};
// Tests that involve invalid on-device model file paths should not crash when
// the associated RFH is destroyed.
TEST_F(AIManagerKeyedServiceTest, NoUAFWithInvalidOnDeviceModelPath) {
  // Override the on-device model path with one that does not exist. The switch
  // is applied through a ScopedCommandLine, which restores the process-wide
  // command line when the test body exits. Without it, the switch would leak
  // into every test that runs later in the same process.
  base::test::ScopedCommandLine scoped_command_line;
  scoped_command_line.GetProcessCommandLine()->AppendSwitchASCII(
      optimization_guide::switches::kOnDeviceModelExecutionOverride,
      "invalid-on-device-model-file-path");

  // If the eligibility check is reached, the mock service reports that the
  // feature is not enabled.
  EXPECT_CALL(*MockService(), CanCreateOnDeviceSession(_, _))
      .Times(AtMost(1))
      .WillOnce(Invoke([](optimization_guide::ModelBasedCapabilityKey feature,
                          optimization_guide::OnDeviceModelEligibilityReason*
                              on_device_model_eligibility_reason) {
        *on_device_model_eligibility_reason = optimization_guide::
            OnDeviceModelEligibilityReason::kFeatureNotEnabled;
        return false;
      }));

  // The availability callback may or may not run before the contents are
  // deleted. If it does run, it must report kNoFeatureNotEnabled.
  base::MockCallback<blink::mojom::AIManager::CanCreateTextSessionCallback>
      callback;
  EXPECT_CALL(callback, Run(_))
      .Times(AtMost(1))
      .WillOnce(Invoke([](blink::mojom::ModelAvailabilityCheckResult result) {
        EXPECT_EQ(
            result,
            blink::mojom::ModelAvailabilityCheckResult::kNoFeatureNotEnabled);
      }));

  AIManagerKeyedService* ai_manager =
      AIManagerKeyedServiceFactory::GetAIManagerKeyedService(
          main_rfh()->GetBrowserContext());
  ai_manager->CanCreateTextSession(callback.Get());

  // The callback may still be pending. Deleting the WebContents destroys the
  // associated RFH, which must not result in a use-after-free.
  DeleteContents();
  task_environment()->RunUntilIdle();
}
// Tests the `AIContextBoundObjectSet`'s behavior of managing the lifetime of
// `AITextSession`s.
TEST_F(AIManagerKeyedServiceTest, AIContextBoundObjectSet) {
  base::MockCallback<blink::mojom::AIManager::CreateTextSessionCallback>
      callback;
  base::RunLoop run_loop;
  // The test blocks on `run_loop` until this callback quits it, so the
  // callback is required to run exactly once. `WillOnce` alone states that.
  EXPECT_CALL(callback, Run(_))
      .WillOnce(Invoke([&](blink::mojom::AITextSessionInfoPtr result) {
        EXPECT_TRUE(result);
        run_loop.Quit();
      }));

  AIManagerKeyedService* ai_manager =
      AIManagerKeyedServiceFactory::GetAIManagerKeyedService(
          main_rfh()->GetBrowserContext());
  mojo::Remote<blink::mojom::AIManager> ai_manager_remote;
  mojo::Remote<blink::mojom::AITextSession> session_remote;
  ai_manager->AddReceiver(ai_manager_remote.BindNewPipeAndPassReceiver(),
                          mock_host());

  // Initially the `AIContextBoundObjectSet` is empty.
  base::WeakPtr<AIContextBoundObjectSet> context_bound_objects =
      AIContextBoundObjectSet::GetFromContext(mock_host())
          ->GetWeakPtrForTesting();
  ASSERT_TRUE(context_bound_objects);
  ASSERT_EQ(0u, context_bound_objects->GetSizeForTesting());

  // After creating one `AITextSession`, the `AIContextBoundObjectSet` contains
  // 1 element.
  ai_manager_remote->CreateTextSession(
      session_remote.BindNewPipeAndPassReceiver(), nullptr, std::nullopt,
      callback.Get());
  run_loop.Run();
  // Check validity first. Dereferencing an invalidated WeakPtr would crash the
  // test instead of reporting a failure.
  ASSERT_TRUE(context_bound_objects);
  ASSERT_EQ(1u, context_bound_objects->GetSizeForTesting());

  // After resetting the session, the `AIContextBoundObjectSet` becomes empty
  // again. It should then be removed from the context, which invalidates the
  // weak pointer.
  session_remote.reset();
  task_environment()->RunUntilIdle();
  ASSERT_FALSE(context_bound_objects);
}