// blob: 4d6051467e1e378eeb83d70136799c5f3a428807 [file] [log] [blame]
// Copyright 2025 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/test/embedded_test_server/http_connect_proxy_handler.h"

#include <stdint.h>

#include <algorithm>
#include <memory>
#include <optional>
#include <set>
#include <string>
#include <utility>

#include "base/check_op.h"
#include "base/containers/flat_set.h"
#include "base/containers/span.h"
#include "base/containers/unique_ptr_adapters.h"
#include "base/functional/bind.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/weak_ptr.h"
#include "base/strings/strcat.h"
#include "base/task/sequenced_task_runner.h"
#include "net/base/address_list.h"
#include "net/base/host_port_pair.h"
#include "net/base/io_buffer.h"
#include "net/base/ip_address.h"
#include "net/base/net_errors.h"
#include "net/http/http_status_code.h"
#include "net/log/net_log_source.h"
#include "net/socket/stream_socket.h"
#include "net/socket/tcp_client_socket.h"
#include "net/test/embedded_test_server/http_connection.h"
#include "net/test/embedded_test_server/http_request.h"
#include "net/test/embedded_test_server/http_response.h"
#include "net/traffic_annotation/network_traffic_annotation_test_helper.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "url/origin.h"

namespace net::test_server {
// Implements a single CONNECT tunnel. Owns the socket to the client and the
// socket to the destination, and copies data between them in both directions.
// Whenever a read or write fails (or a read reports the connection closed),
// asks the owning HttpConnectProxyHandler to delete the tunnel, which destroys
// both sockets.
class HttpConnectProxyHandler::ConnectTunnel {
 public:
  // Minimum capacity of the buffer allocated for each direction of the tunnel.
  static constexpr size_t kCapacity = 32 * 1024;

  using DeleteCallback = base::OnceCallback<void(ConnectTunnel*)>;

  // `http_proxy_handler` is the object whose DeleteTunnel() is called when the
  // tunnel is done. `socket` is the connection to the client.
  ConnectTunnel(HttpConnectProxyHandler* http_proxy_handler,
                std::unique_ptr<StreamSocket> socket)
      : http_proxy_handler_(http_proxy_handler), socket_(std::move(socket)) {}

  ~ConnectTunnel() = default;

  // Tries to establish a connection to localhost on `dest_port`, and on
  // success, tells the client socket a tunnel was successfully established, and
  // starts tunnelling data between the connections.
  //
  // May result in `this` being deleted before returning, if the connection
  // attempt and the following I/O complete synchronously with an error.
  void Start(uint16_t dest_port) {
    dest_socket_ = std::make_unique<TCPClientSocket>(
        AddressList::CreateFromIPAddress(IPAddress::IPv4Localhost(), dest_port),
        /*socket_performance_watcher=*/nullptr,
        /*network_quality_estimator=*/nullptr, /*net_log=*/nullptr,
        NetLogSource());
    int result = dest_socket_->Connect(base::BindOnce(
        &ConnectTunnel::OnConnectComplete, base::Unretained(this)));
    if (result != ERR_IO_PENDING) {
      // Must be the last thing done here: `this` may be gone afterwards.
      OnConnectComplete(result);
    }
  }

 private:
  // Called with the result of connecting `dest_socket_`. On failure, writes a
  // "Bad Gateway" response to the client and then deletes `this`. On success,
  // writes the success response to the client and starts tunnelling in both
  // directions.
  void OnConnectComplete(int result) {
    // If unable to connect, write a bad gateway error to `socket_` before
    // deleting `this`.
    if (result != OK) {
      VLOG(1) << "Failed to establish tunnel connection.";
      BasicHttpResponse response;
      response.set_code(HttpStatusCode::HTTP_BAD_GATEWAY);
      response.set_reason("Bad Gateway");
      std::string response_string = response.ToResponseString();
      scoped_refptr<GrowableIOBuffer> buffer =
          base::MakeRefCounted<GrowableIOBuffer>();
      buffer->SetCapacity(response_string.size());
      buffer->span().copy_prefix_from(base::as_byte_span(response_string));
      // A null `src` makes OnWriteComplete() delete `this` once the whole
      // response has been written.
      DoWrite(/*src=*/nullptr, /*dest=*/socket_.get(), std::move(buffer),
              response_string.size());
      return;
    }

    // The first StartTunneling() call below can fail synchronously, in which
    // case DeleteTunnel() has already destroyed `this` by the time it returns.
    // Use a WeakPtr to detect that before touching any member again.
    base::WeakPtr<ConnectTunnel> weak_this = weak_factory_.GetWeakPtr();

    // Write HTTP headers to client socket to indicate the connect succeeded,
    // and then start tunnelling.
    BasicHttpResponse response;
    response.set_reason("Connection established");
    StartTunneling(/*src=*/dest_socket_.get(), /*dest=*/socket_.get(),
                   response.ToResponseString());
    if (!weak_this) {
      return;
    }

    // Start tunneling from client socket to destination immediately, no need to
    // write anything else.
    StartTunneling(/*src=*/socket_.get(), /*dest=*/dest_socket_.get());
  }

  // Starts reading from `src` and writing that data to `dest`. If
  // `initial_data` is provided, writes that to `dest` before reading from
  // `src`. Since a CONNECT proxy passes data in both directions, this needs to
  // be called twice, flipping `src` and `dest` between calls, to fully set up
  // the tunnelling.
  //
  // May delete `this` before returning if an operation fails synchronously.
  void StartTunneling(StreamSocket* src,
                      StreamSocket* dest,
                      std::string initial_data = {}) {
    scoped_refptr<GrowableIOBuffer> buffer =
        base::MakeRefCounted<GrowableIOBuffer>();
    // Make sure the buffer can hold all of `initial_data`, even if it is larger
    // than the default capacity.
    buffer->SetCapacity(std::max(kCapacity, initial_data.size()));
    if (!initial_data.empty()) {
      // Start with a write, if `initial_data` is provided.
      buffer->span().copy_prefix_from(base::as_byte_span(initial_data));
      DoWrite(src, dest, std::move(buffer), initial_data.size());
      return;
    }
    DoRead(src, dest, std::move(buffer));
  }

  // Try to read data from `src`. Once data is read, write it all to `dest`, and
  // then repeat the process, until an error is encountered. Uses
  // GrowableIOBuffer because it can track an offset that indicates how much
  // data has been written. DrainableIOBuffer can do the same, but requires a
  // nested IOBuffer, so is a little more complicated to use.
  //
  // NOTE(review): the callbacks bind `this` with base::Unretained(); this
  // presumably relies on the sockets, which `this` owns, not invoking pending
  // callbacks once destroyed - confirm against the socket contract.
  void DoRead(StreamSocket* src,
              StreamSocket* dest,
              scoped_refptr<GrowableIOBuffer> buffer) {
    int result =
        src->Read(buffer.get(), buffer->size(),
                  base::BindOnce(&ConnectTunnel::OnReadComplete,
                                 base::Unretained(this), src, dest, buffer));
    if (result == ERR_IO_PENDING) {
      return;
    }
    OnReadComplete(src, dest, std::move(buffer), result);
  }

  // Called when a read from `src` completes. On error, tears down the socket.
  // Otherwise, starts writing the data to `dest`, and will start reading from
  // `src` again, once all data is written.
  void OnReadComplete(StreamSocket* src,
                      StreamSocket* dest,
                      scoped_refptr<GrowableIOBuffer> buffer,
                      int result) {
    CHECK_NE(result, ERR_IO_PENDING);
    if (result <= 0) {
      // On error / close, close both sockets - this behavior is good enough,
      // since the client side (Chrome's network stack) only closes the write
      // pipe when it's done reading, and since this code doesn't read from the
      // destination pipe (likely to another EmbeddedTestServer) while there's
      // data in the buffer to write to the client pipe, all data will be
      // written before the EmbeddedTestServer closing the pipe will be
      // observed.
      http_proxy_handler_->DeleteTunnel(this);
      // `this` has been deleted; do not touch members past this point.
      return;
    }
    DoWrite(src, dest, std::move(buffer), result);
  }

  // Writes `remaining_bytes` from `buffer` to `dest`. Once all data has been
  // written, will start reading from `src` again. If `src` is nullptr, writes
  // data to `dest`, and destroys the `ConnectTunnel` once everything has been
  // written.
  void DoWrite(StreamSocket* src,
               StreamSocket* dest,
               scoped_refptr<GrowableIOBuffer> buffer,
               int remaining_bytes) {
    CHECK_GE(remaining_bytes, 0);
    int result = dest->Write(
        buffer.get(), remaining_bytes,
        base::BindOnce(&ConnectTunnel::OnWriteComplete, base::Unretained(this),
                       src, dest, buffer, remaining_bytes),
        TRAFFIC_ANNOTATION_FOR_TESTS);
    if (result == ERR_IO_PENDING) {
      return;
    }
    OnWriteComplete(src, dest, std::move(buffer), remaining_bytes, result);
  }

  // Called once data has been written to `dest` or there was a write error. On
  // error, tears down `this`. Otherwise, keeps writing until all data has been
  // written, and then starts reading from `src` again.
  void OnWriteComplete(StreamSocket* src,
                       StreamSocket* dest,
                       scoped_refptr<GrowableIOBuffer> buffer,
                       int remaining_bytes,
                       int result) {
    CHECK_NE(result, ERR_IO_PENDING);
    if (result < 0) {
      // See OnReadComplete() for explanation on why this is ok to do.
      http_proxy_handler_->DeleteTunnel(this);
      // `this` has been deleted; do not touch members past this point.
      return;
    }

    // Advance past the bytes that were just written; a short write leaves the
    // rest of the data to be sent by another DoWrite() call.
    CHECK_LE(result, remaining_bytes);
    buffer->DidConsume(result);
    remaining_bytes -= result;
    if (remaining_bytes > 0) {
      DoWrite(src, dest, std::move(buffer), remaining_bytes);
      return;
    }

    // `src` will be nullptr when writing a connect error. In that case, once
    // everything has been written, delete `this` to close `socket_`.
    if (!src) {
      http_proxy_handler_->DeleteTunnel(this);
      return;
    }

    // Everything was written: rewind the buffer and reuse it for the next read.
    buffer->set_offset(0);
    DoRead(src, dest, std::move(buffer));
  }

  // The handler whose DeleteTunnel() is called to destroy `this`.
  raw_ptr<HttpConnectProxyHandler> http_proxy_handler_;

  // The socket to the client (Chrome's network stack).
  std::unique_ptr<StreamSocket> socket_;

  // The socket to the server (typically another EmbeddedTestServer instance).
  std::unique_ptr<TCPClientSocket> dest_socket_;

  // Used by OnConnectComplete() to detect `this` being deleted by a nested
  // call. Kept as the last member so it is destroyed first.
  base::WeakPtrFactory<ConnectTunnel> weak_factory_{this};
};
// Copies every entry of `proxied_destinations` into `proxied_destinations_`,
// the container HandleProxyRequest() checks each CONNECT destination against.
HttpConnectProxyHandler::HttpConnectProxyHandler(
base::span<const HostPortPair> proxied_destinations)
: proxied_destinations_(proxied_destinations.begin(),
proxied_destinations.end()) {}
// Defaulted. NOTE(review): presumably this also destroys any tunnels still held
// in `connect_tunnels_` - the member is declared in the header; confirm there.
HttpConnectProxyHandler::~HttpConnectProxyHandler() = default;
// Handles a CONNECT request received on `connection`. If the requested
// destination parses and is one of `proxied_destinations_`, takes the socket
// from `connection`, hands it to a new ConnectTunnel stored in
// `connect_tunnels_`, starts the tunnel and returns true. Otherwise returns
// false without taking the socket.
bool HttpConnectProxyHandler::HandleProxyRequest(HttpConnection& connection,
                                                 const HttpRequest& request) {
  // This class only supports HTTP/1.x.
  CHECK_EQ(connection.protocol(), HttpConnection::Protocol::kHttp1);
  CHECK_EQ(request.method, METHOD_CONNECT);

  // For CONNECT requests, `relative_url` is actually a host and port.
  const HostPortPair dest = HostPortPair::FromString(request.relative_url);
  if (dest.IsEmpty()) {
    ADD_FAILURE() << "Invalid CONNECT destination: " << request.relative_url;
    // Returning false leaves the socket with `connection`.
    // NOTE(review): presumably the caller then writes an HTTP error message to
    // the socket - confirm against the header / caller.
    return false;
  }

  if (!proxied_destinations_.contains(dest)) {
    // Not a destination this proxy serves. As above, returning false leaves
    // the socket with `connection` (presumably so the caller can write an HTTP
    // error message to it - confirm against the header / caller).
    return false;
  }

  auto tunnel = std::make_unique<ConnectTunnel>(this, connection.TakeSocket());
  auto tunnel_it = connect_tunnels_.insert(std::move(tunnel)).first;
  // Start() can end up calling DeleteTunnel() before it returns, which erases
  // the entry `tunnel_it` refers to, so `tunnel_it` must not be used after
  // this call.
  (*tunnel_it)->Start(dest.port());
  return true;
}
// Removes `tunnel` from `connect_tunnels_`. CHECK-fails if `tunnel` is not
// currently stored there. Called by ConnectTunnel to have itself destroyed, so
// `tunnel` must not be used once this returns.
void HttpConnectProxyHandler::DeleteTunnel(ConnectTunnel* tunnel) {
  auto entry = connect_tunnels_.find(tunnel);
  // Every tunnel asking to be deleted must still be registered.
  CHECK(entry != connect_tunnels_.end());
  connect_tunnels_.erase(entry);
}
} // namespace net::test_server