blob: 26436010a0dd79ae32b1269822e1f0c4b6b825b7 [file]
// Copyright 2026 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <memory>
#include <utility>
#include <vector>
#include "base/functional/bind.h"
#include "base/functional/callback.h"
#include "base/memory/raw_ptr.h"
#include "base/run_loop.h"
#include "base/test/bind.h"
#include "base/test/gmock_callback_support.h"
#include "base/test/run_until.h"
#include "base/test/task_environment.h"
#include "base/test/test_future.h"
#include "base/types/expected.h"
#include "components/private_ai/client_impl.h"
#include "components/private_ai/common/private_ai_logger.h"
#include "components/private_ai/connection_basic.h"
#include "components/private_ai/connection_factory_impl.h"
#include "components/private_ai/connection_metrics.h"
#include "components/private_ai/connection_timeout.h"
#include "components/private_ai/connection_token_attestation.h"
#include "components/private_ai/error_code.h"
#include "components/private_ai/private_ai_common.h"
#include "components/private_ai/proto/private_ai.pb.h"
#include "components/private_ai/secure_channel.h"
#include "components/private_ai/testing/fake_secure_channel.h"
#include "components/private_ai/testing/fake_token_manager.h"
#include "services/network/test/test_network_context.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace private_ai {
namespace {

// Integration fixture that drives a real ClientImpl on top of a real
// ConnectionFactoryImpl. Only the leaf dependencies are faked: the network
// context, the token manager and the secure channel. Every FakeSecureChannel
// the stack creates is tracked in `secure_channels_` until it reports its own
// destruction.
//
// The fixture is only used by the TEST_Fs in this file, so it lives in the
// anonymous namespace to get internal linkage.
class ClientImplIntegrationTest : public testing::Test {
 public:
  void SetUp() override {
    const GURL url("wss://example.com?key=test-api-key");
    auto factory = std::make_unique<ConnectionFactoryImpl>(
        url, &test_network_context_, &logger_);
    // Attestation fetches its auth token from `token_manager_`, so each test
    // decides when (and whether) a token is handed out.
    factory->EnableTokenAttestation(&token_manager_);
    // Swap in fake secure channels whose creation/destruction callbacks keep
    // `secure_channels_` up to date. base::Unretained(this) relies on
    // TearDown() destroying the client and waiting for every channel to be
    // destroyed before the fixture itself goes away.
    factory->SetSecureChannelFactoryForTesting(base::BindLambdaForTesting(
        [this]() -> std::unique_ptr<SecureChannel::Factory> {
          return std::make_unique<FakeSecureChannelFactory>(
              base::BindRepeating(
                  &ClientImplIntegrationTest::on_secure_channel_created,
                  base::Unretained(this)),
              base::BindRepeating(
                  &ClientImplIntegrationTest::on_secure_channel_destroyed,
                  base::Unretained(this)));
        }));
    client_ = std::make_unique<ClientImpl>(std::move(factory), &logger_);
  }

  void TearDown() override {
    // Destroy the client first. Its SecureChannels may be torn down
    // asynchronously, so run tasks until all of them have reported
    // destruction through on_secure_channel_destroyed().
    client_.reset();
    ASSERT_TRUE(
        base::test::RunUntil([&]() { return secure_channels_.empty(); }));
  }

  // Bound in SetUp() as FakeSecureChannelFactory's creation callback.
  void on_secure_channel_created(FakeSecureChannel* secure_channel) {
    secure_channels_.push_back(secure_channel);
  }

  // Bound in SetUp() as FakeSecureChannelFactory's destruction callback.
  void on_secure_channel_destroyed(FakeSecureChannel* secure_channel) {
    std::erase(secure_channels_, secure_channel);
  }

  // Returns the most recently created channel that has not yet been
  // destroyed, or nullptr if there is none.
  FakeSecureChannel* last_secure_channel() {
    if (secure_channels_.empty()) {
      return nullptr;
    }
    return secure_channels_.back();
  }

 protected:
  // Declared first so it is destroyed last, after every member that may post
  // tasks. MOCK_TIME lets tests fast-forward through request timeouts.
  base::test::TaskEnvironment task_environment_{
      base::test::TaskEnvironment::TimeSource::MOCK_TIME};
  PrivateAiLogger logger_;
  network::TestNetworkContext test_network_context_;
  FakeTokenManager token_manager_;
  // Non-owning pointers to the channels that are currently alive, in creation
  // order.
  std::vector<raw_ptr<FakeSecureChannel>> secure_channels_;
  std::unique_ptr<ClientImpl> client_;
};

}  // namespace
TEST_F(ClientImplIntegrationTest, FullStackSuccess) {
  base::test::TestFuture<base::expected<std::string, ErrorCode>> future;
  client_->SendTextRequest(proto::FeatureName::FEATURE_NAME_UNSPECIFIED,
                           "hello", future.GetCallback(), /*options=*/{});

  // 1. Hand out the pending auth token so attestation can start.
  token_manager_.RunPendingCallbacks();

  // 2. A secure channel now exists.
  FakeSecureChannel* const channel = last_secure_channel();
  ASSERT_TRUE(channel);

  // 3. The attestation request is written first, and the text request follows
  //    right behind it without waiting for an attestation response.
  const auto& requests = channel->written_requests();
  ASSERT_EQ(requests.size(), 2u);
  EXPECT_TRUE(requests[0].has_anonymous_token_request());
  EXPECT_EQ(requests[0].feature_name(),
            proto::FeatureName::FEATURE_NAME_CHROME_CLIENT_ATTESTATION);
  EXPECT_TRUE(requests[1].has_generate_content_request());

  // 4. Answer the text request with a single "world" part.
  proto::PrivateAiResponse text_response;
  text_response.set_request_id(requests[1].request_id());
  auto* const part = text_response.mutable_generate_content_response()
                         ->add_candidates()
                         ->mutable_content()
                         ->add_parts();
  part->set_text("world");
  channel->send_back_response(text_response);

  // 5. The client surfaces the response text to the caller.
  const auto& result = future.Get();
  ASSERT_TRUE(result.has_value());
  EXPECT_EQ(result.value(), "world");
}
TEST_F(ClientImplIntegrationTest, AttestationFailure) {
  base::test::TestFuture<base::expected<std::string, ErrorCode>> future;
  client_->SendTextRequest(proto::FeatureName::FEATURE_NAME_UNSPECIFIED,
                           "hello", future.GetCallback(), /*options=*/{});

  // Attestation is waiting on an auth token; make that fetch fail.
  token_manager_.RespondToGetAuthToken(std::nullopt);

  // Without a token the request cannot be attested and is rejected.
  const auto& result = future.Get();
  ASSERT_FALSE(result.has_value());
  EXPECT_EQ(result.error(), ErrorCode::kClientAttestationFailed);
}
TEST_F(ClientImplIntegrationTest, Timeout) {
  base::test::TestFuture<base::expected<std::string, ErrorCode>> future;
  client_->SendTextRequest(proto::FeatureName::FEATURE_NAME_UNSPECIFIED,
                           "hello", future.GetCallback(),
                           {.timeout = base::Seconds(5)});

  // Release the auth token so attestation can proceed.
  token_manager_.RunPendingCallbacks();
  FakeSecureChannel* const channel = last_secure_channel();
  ASSERT_TRUE(channel);

  // Both the attestation request and the text request are written right
  // away.
  {
    const auto& requests = channel->written_requests();
    ASSERT_EQ(requests.size(), 2u);
    EXPECT_TRUE(requests[0].has_anonymous_token_request());
    EXPECT_EQ(requests[0].feature_name(),
              proto::FeatureName::FEATURE_NAME_CHROME_CLIENT_ATTESTATION);
    EXPECT_TRUE(requests[1].has_generate_content_request());
  }

  // Never answer the text request; move mock time past its 5s deadline.
  task_environment_.FastForwardBy(base::Seconds(6));

  // The request's deadline fires and reports kTimeout.
  const auto& result = future.Get();
  ASSERT_FALSE(result.has_value());
  EXPECT_EQ(result.error(), ErrorCode::kTimeout);
}
TEST_F(ClientImplIntegrationTest, ConcurrentRequestsDuringAttestation) {
  base::test::TestFuture<base::expected<std::string, ErrorCode>> future1;
  base::test::TestFuture<base::expected<std::string, ErrorCode>> future2;
  client_->SendTextRequest(proto::FeatureName::FEATURE_NAME_UNSPECIFIED,
                           "request1", future1.GetCallback(), /*options=*/{});
  client_->SendTextRequest(proto::FeatureName::FEATURE_NAME_UNSPECIFIED,
                           "request2", future2.GetCallback(), /*options=*/{});

  // Returns the text of the first part of the first content of a written
  // request, or an empty string if the request carries no such part. Guarding
  // the indices turns a malformed request into an EXPECT failure instead of
  // an out-of-range repeated-field access.
  auto first_part_text = [](const auto& request) -> std::string {
    if (!request.has_generate_content_request()) {
      return std::string();
    }
    const auto& generate = request.generate_content_request();
    if (generate.contents_size() == 0 ||
        generate.contents(0).parts_size() == 0) {
      return std::string();
    }
    return generate.contents(0).parts(0).text();
  };

  // Builds a successful single-part text response addressed to `request_id`.
  auto make_text_response = [](auto request_id, const std::string& text) {
    proto::PrivateAiResponse response;
    response.set_request_id(request_id);
    response.mutable_generate_content_response()
        ->add_candidates()
        ->mutable_content()
        ->add_parts()
        ->set_text(text);
    return response;
  };

  // 1. Attestation starts; both requests were queued before it completed.
  token_manager_.RunPendingCallbacks();
  auto* channel = last_secure_channel();
  ASSERT_TRUE(channel);

  // 2. The attestation request is written first, followed by both text
  //    requests in submission order on the same channel.
  ASSERT_EQ(channel->written_requests().size(), 3u);
  EXPECT_TRUE(channel->written_requests()[0].has_anonymous_token_request());
  EXPECT_EQ(channel->written_requests()[0].feature_name(),
            proto::FeatureName::FEATURE_NAME_CHROME_CLIENT_ATTESTATION);
  EXPECT_EQ(first_part_text(channel->written_requests()[1]), "request1");
  EXPECT_EQ(first_part_text(channel->written_requests()[2]), "request2");

  // 3. Answer each text request; responses are matched by request id.
  proto::PrivateAiResponse response1 = make_text_response(
      channel->written_requests()[1].request_id(), "response1");
  channel->send_back_response(response1);
  proto::PrivateAiResponse response2 = make_text_response(
      channel->written_requests()[2].request_id(), "response2");
  channel->send_back_response(response2);

  // 4. Verify both results. Check has_value() before value(): value() on an
  //    error CHECK-fails, which would abort the whole test binary instead of
  //    reporting a test failure.
  const auto& result1 = future1.Get();
  ASSERT_TRUE(result1.has_value());
  EXPECT_EQ(result1.value(), "response1");
  const auto& result2 = future2.Get();
  ASSERT_TRUE(result2.has_value());
  EXPECT_EQ(result2.value(), "response2");
}
TEST_F(ClientImplIntegrationTest, DisconnectDuringAttestation) {
  base::test::TestFuture<base::expected<std::string, ErrorCode>> future;
  client_->SendTextRequest(proto::FeatureName::FEATURE_NAME_UNSPECIFIED,
                           "hello", future.GetCallback(), /*options=*/{});

  // 1. Attestation starts and a secure channel is created.
  token_manager_.RunPendingCallbacks();
  FakeSecureChannel* const channel = last_secure_channel();
  ASSERT_TRUE(channel);

  // 2. The channel fails with a network error before the attestation request
  //    has been answered.
  channel->send_back_error(ErrorCode::kNetworkError);

  // 3. The pending request is rejected synchronously. Because the failure
  //    happened before the first successful response, the client is expected
  //    to report kClientAttestationFailed rather than the raw kNetworkError.
  ASSERT_TRUE(future.IsReady());
  const auto& result = future.Get();
  ASSERT_FALSE(result.has_value());
  EXPECT_EQ(result.error(), ErrorCode::kClientAttestationFailed);
}
TEST_F(ClientImplIntegrationTest, ClientDestroyedDuringAttestation) {
  base::test::TestFuture<base::expected<std::string, ErrorCode>> future;
  client_->SendTextRequest(proto::FeatureName::FEATURE_NAME_UNSPECIFIED,
                           "hello", future.GetCallback(), /*options=*/{});

  // Let attestation begin, then tear the client down before any response
  // arrives.
  token_manager_.RunPendingCallbacks();
  client_.reset();

  // Destroying the client resolves the outstanding request immediately, with
  // no further task running, and reports kDestroyed.
  ASSERT_TRUE(future.IsReady());
  const auto& result = future.Get();
  ASSERT_FALSE(result.has_value());
  EXPECT_EQ(result.error(), ErrorCode::kDestroyed);
}
TEST_F(ClientImplIntegrationTest, AttestationTimedOut) {
  base::test::TestFuture<base::expected<std::string, ErrorCode>> future;
  client_->SendTextRequest(proto::FeatureName::FEATURE_NAME_UNSPECIFIED,
                           "hello", future.GetCallback(),
                           {.timeout = base::Seconds(5)});

  // Start attestation, but never answer anything on the resulting channel.
  token_manager_.RunPendingCallbacks();
  ASSERT_TRUE(last_secure_channel());

  // Advance mock time well past the 5s request deadline.
  task_environment_.FastForwardBy(base::Seconds(10));

  // The request's own deadline fires, so the caller sees kTimeout.
  ASSERT_TRUE(future.IsReady());
  const auto& result = future.Get();
  ASSERT_FALSE(result.has_value());
  EXPECT_EQ(result.error(), ErrorCode::kTimeout);
}
} // namespace private_ai