blob: 20a440d4c728963c7e8388677ba7ca7d426206be [file]
// Copyright 2026 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "components/private_ai/connection_proxy.h"
#include <utility>
#include "base/base64url.h"
#include "base/check.h"
#include "base/logging.h"
#include "base/strings/strcat.h"
#include "base/strings/string_util.h"
#include "base/task/sequenced_task_runner.h"
#include "components/private_ai/common/base64_utils.h"
#include "components/private_ai/common/private_ai_logger.h"
#include "components/private_ai/phosphor/token_manager.h"
#include "content/public/browser/network_service_instance.h"
#include "net/http/http_request_headers.h"
#include "net/proxy_resolution/proxy_config.h"
#include "services/cert_verifier/public/mojom/cert_verifier_service_factory.mojom.h"
#include "services/network/public/mojom/network_service.mojom.h"
#include "services/network/public/mojom/proxy_config.mojom.h"
#include "url/origin.h"
namespace private_ai {
namespace internal {
// Builds the custom proxy configuration used to reach `proxy_url`.
//
// The proxy rules are parsed from the serialized origin of `proxy_url`, and
// `auth_token` is attached to the CONNECT tunnel headers as a "PrivateToken"
// Authorization header. Both `auth_token.token` and
// `auth_token.encoded_extensions` are passed through
// ConvertBase64toBase64Url() with padding omitted.
//
// Returns nullptr, after reporting the problem through `logger`, when either
// conversion fails. `logger` must not be null.
network::mojom::CustomProxyConfigPtr CreateCustomProxyConfig(
    const GURL& proxy_url,
    const phosphor::BlindSignedAuthToken& auth_token,
    PrivateAiLogger* logger) {
  network::mojom::CustomProxyConfigPtr proxy_config =
      network::mojom::CustomProxyConfig::New();

  // The proxy rule parser expects exactly "[scheme://]host[:port]". A trailing
  // slash, such as the one GURL::spec() produces, makes the rule fail to parse
  // or be ignored, so the serialized origin is used instead.
  const std::string proxy_rule = url::Origin::Create(proxy_url).Serialize();
  proxy_config->rules.ParseFromString(proxy_rule);

  const std::optional<std::string> url_safe_token = ConvertBase64toBase64Url(
      auth_token.token, base::Base64UrlEncodePolicy::OMIT_PADDING);
  if (!url_safe_token.has_value()) {
    logger->LogError(FROM_HERE, "Invalid base64 encoding in private token.");
    return nullptr;
  }

  const std::optional<std::string> url_safe_extensions =
      ConvertBase64toBase64Url(auth_token.encoded_extensions,
                               base::Base64UrlEncodePolicy::OMIT_PADDING);
  if (!url_safe_extensions.has_value()) {
    logger->LogError(FROM_HERE,
                     "Invalid base64 encoding in private token extensions.");
    return nullptr;
  }

  // NOTE(review): the two parameters below are separated by a space only,
  // whereas RFC 9110 auth-params are comma-separated. Confirm this is the
  // format the proxy expects.
  const std::string authorization_value =
      base::StrCat({"PrivateToken token=\"", *url_safe_token,
                    "\" extensions=\"", *url_safe_extensions, "\""});

  net::HttpRequestHeaders tunnel_headers;
  tunnel_headers.SetHeader(net::HttpRequestHeaders::kAuthorization,
                           authorization_value);
  proxy_config->connect_tunnel_headers = tunnel_headers;

  proxy_config->should_override_existing_config = true;
  proxy_config->allow_non_idempotent_methods = true;
  return proxy_config;
}
} // namespace internal
ConnectionProxy::PendingRequest::PendingRequest(proto::PrivateAiRequest request,
base::TimeDelta timeout,
OnRequestCallback callback)
: request(std::move(request)),
timeout(timeout),
callback(std::move(callback)) {}
// The destructor, move constructor and move assignment operator of
// PendingRequest all use the compiler-generated implementations; they are
// defined here, out of line, rather than in the header.
ConnectionProxy::PendingRequest::~PendingRequest() = default;
ConnectionProxy::PendingRequest::PendingRequest(PendingRequest&&) = default;
ConnectionProxy::PendingRequest& ConnectionProxy::PendingRequest::operator=(
PendingRequest&&) = default;
// Stores the collaborators, validates them, and schedules the fetch of the
// proxy auth token.
//
// `proxy_url` must be valid, and `logger`, `token_manager`, `network_service`,
// `inner_connection_factory` and `on_disconnect` must all be non-null; each is
// enforced with a CHECK.
//
// The token fetch is not started inline: FetchToken() is posted to the current
// default sequenced task runner, bound to a weak pointer to this object.
ConnectionProxy::ConnectionProxy(
    const GURL& proxy_url,
    PrivateAiLogger* logger,
    phosphor::TokenManager* token_manager,
    network::mojom::NetworkService* network_service,
    InnerConnectionFactory inner_connection_factory,
    base::OnceCallback<void(ErrorCode)> on_disconnect)
    : proxy_url_(proxy_url),
      logger_(logger),
      token_manager_(token_manager),
      network_service_(network_service),
      inner_connection_factory_(std::move(inner_connection_factory)),
      on_disconnect_(std::move(on_disconnect)) {
  // Validate the stored members rather than the (partly moved-from)
  // parameters.
  CHECK(proxy_url_.is_valid());
  CHECK(logger_);
  CHECK(token_manager_);
  CHECK(network_service_);
  CHECK(inner_connection_factory_);
  CHECK(on_disconnect_);

  auto fetch_token_task =
      base::BindOnce(&ConnectionProxy::FetchToken, weak_factory_.GetWeakPtr());
  base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
      FROM_HERE, std::move(fetch_token_task));
}
// Uses the compiler-generated destructor, defined out of line.
ConnectionProxy::~ConnectionProxy() = default;
// Sends `request` through the proxied connection, or parks it until that
// connection exists.
//
// - Initialization finished and an inner connection exists: the call is
//   forwarded to the inner connection unchanged.
// - Initialization finished but there is no inner connection: `callback` is
//   run immediately with ErrorCode::kError.
// - Still initializing: the request is appended to `pending_requests_`, which
//   OnProxyToken() forwards once the inner connection has been created.
void ConnectionProxy::Send(proto::PrivateAiRequest request,
                           base::TimeDelta timeout,
                           OnRequestCallback callback) {
  if (!is_initializing_) {
    if (inner_connection_) {
      inner_connection_->Send(std::move(request), timeout,
                              std::move(callback));
      return;
    }
    // Initialization failed or connection disconnected.
    std::move(callback).Run(base::unexpected(ErrorCode::kError));
    return;
  }

  // The proxy token has not arrived yet; keep the request for later.
  pending_requests_.emplace_back(std::move(request), timeout,
                                 std::move(callback));
}
// Tears the proxy down and fails every queued request with `error`.
//
// In order, this:
//   1. drops the disconnect callback without running it,
//   2. marks initialization as finished,
//   3. runs the callback of every queued request with `error`,
//   4. forwards the notification to the inner connection, if one exists,
//   5. clears the raw pointers to the token manager, network service and
//      logger, and
//   6. invalidates all weak pointers, so tasks and callbacks bound to them
//      (FetchToken(), OnProxyToken()) no longer reach this object.
void ConnectionProxy::OnDestroy(ErrorCode error) {
  on_disconnect_.Reset();

  // The token callback is bound to a weak pointer that is invalidated below,
  // so OnProxyToken() can never clear this flag once teardown has started.
  // Clear it here; otherwise a Send() issued during or after teardown while
  // still "initializing" would be queued forever and its callback never run.
  is_initializing_ = false;

  // Detach the queue before running callbacks so that `pending_requests_` is
  // no longer being iterated if a callback re-enters this object.
  auto pending_requests = std::move(pending_requests_);
  for (auto& pending : pending_requests) {
    std::move(pending.callback).Run(base::unexpected(error));
  }

  if (inner_connection_) {
    inner_connection_->OnDestroy(error);
  }

  token_manager_ = nullptr;
  network_service_ = nullptr;
  logger_ = nullptr;
  weak_factory_.InvalidateWeakPtrsAndDoom();
}
// Reports `error_code` through the disconnect callback, if it is still set.
// Does nothing when the callback is null, e.g. after OnDestroy() has reset it.
void ConnectionProxy::CallOnDisconnect(ErrorCode error_code) {
  if (!on_disconnect_) {
    return;
  }
  std::move(on_disconnect_).Run(error_code);
}
// Asks the token manager for a proxy auth token. The result is delivered to
// OnProxyToken() through a callback bound to a weak pointer to this object.
void ConnectionProxy::FetchToken() {
  auto on_token_ready = base::BindOnce(&ConnectionProxy::OnProxyToken,
                                       weak_factory_.GetWeakPtr());
  token_manager_->GetAuthTokenForProxy(std::move(on_token_ready));
}
// Completes initialization once the token manager has answered.
//
// On success, creates a network context configured with the custom proxy
// config built from `auth_token`, creates the inner connection on top of that
// context, and forwards every request queued by Send() in the order it was
// queued.
//
// On failure (no token, or a token from which no proxy config can be built),
// reports ErrorCode::kError through CallOnDisconnect() and creates neither the
// network context nor the inner connection. Subsequent Send() calls then fail
// immediately because `is_initializing_` is false and there is no inner
// connection.
void ConnectionProxy::OnProxyToken(
    std::optional<phosphor::BlindSignedAuthToken> auth_token) {
  is_initializing_ = false;
  if (!auth_token) {
    logger_->LogError(FROM_HERE, "Failed to get auth token for proxy.");
    CallOnDisconnect(ErrorCode::kError);
    return;
  }
  logger_->LogInfo(FROM_HERE, "Got auth token for proxy. Connecting to " +
                                  proxy_url_.spec());

  // CreateCustomProxyConfig() returns nullptr (and logs the reason) when the
  // token cannot be encoded. Creating the context with a null proxy config
  // would produce a context that has no custom proxy configured, so treat
  // this as a failed initialization, exactly like a missing token.
  network::mojom::CustomProxyConfigPtr proxy_config =
      internal::CreateCustomProxyConfig(proxy_url_, *auth_token, logger_);
  if (!proxy_config) {
    CallOnDisconnect(ErrorCode::kError);
    return;
  }

  auto context_params = network::mojom::NetworkContextParams::New();
  context_params->cert_verifier_params = content::GetCertVerifierParams(
      cert_verifier::mojom::CertVerifierCreationParams::New());
  context_params->initial_custom_proxy_config = std::move(proxy_config);
  network_service_->CreateNetworkContext(
      proxied_context_.BindNewPipeAndPassReceiver(), std::move(context_params));

  inner_connection_ =
      std::move(inner_connection_factory_).Run(proxied_context_.get());

  // Detach the queue before forwarding, as OnDestroy() does, so the member is
  // not being iterated if a Send() below re-enters this object.
  auto pending_requests = std::move(pending_requests_);
  pending_requests_.clear();
  for (auto& pending : pending_requests) {
    inner_connection_->Send(std::move(pending.request), pending.timeout,
                            std::move(pending.callback));
  }
}
} // namespace private_ai