blob: f8eac043cc5b94b779ed04b3570b40eca9befd6c [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/test/embedded_test_server/websocket_connection.h"
#include <stdint.h>
#include "base/compiler_specific.h"
#include "base/containers/extend.h"
#include "base/containers/span.h"
#include "base/containers/span_writer.h"
#include "base/functional/bind.h"
#include "base/logging.h"
#include "base/memory/scoped_refptr.h"
#include "base/numerics/byte_conversions.h"
#include "base/numerics/safe_conversions.h"
#include "base/strings/strcat.h"
#include "net/base/net_errors.h"
#include "net/socket/socket.h"
#include "net/socket/stream_socket.h"
#include "net/test/embedded_test_server/websocket_handler.h"
#include "net/test/embedded_test_server/websocket_message_assembler.h"
#include "net/websockets/websocket_frame.h"
#include "net/websockets/websocket_frame_parser.h"
#include "net/websockets/websocket_handshake_challenge.h"
namespace net::test_server {
WebSocketConnection::WebSocketConnection(std::unique_ptr<StreamSocket> socket,
                                         std::string_view sec_websocket_key,
                                         EmbeddedTestServer* server)
    : stream_socket_(std::move(socket)),
      // Ask the server to run DisconnectImmediately() on this connection when
      // the server shuts down. base::Unretained(this) is used because the
      // registration is owned by `shutdown_subscription_`, a member of this
      // object: the closure is expected to be unregistered when this object
      // is destroyed, so it never runs against a dead `this`. (This relies on
      // the subscription unregistering on destruction — confirm against
      // RegisterShutdownClosure's contract.) DisconnectImmediately() tolerates
      // an already-reset socket.
      shutdown_subscription_(server->RegisterShutdownClosure(
          base::BindOnce(&WebSocketConnection::DisconnectImmediately,
                         base::Unretained(this)))) {
  CHECK(stream_socket_);
  // Default handshake response headers. SetResponseHeader() may replace or
  // extend them before SendHandshakeResponse() serializes them.
  const std::string accept_key(sec_websocket_key);
  response_headers_.emplace_back("Upgrade", "websocket");
  response_headers_.emplace_back("Connection", "Upgrade");
  response_headers_.emplace_back("Sec-WebSocket-Accept",
                                 ComputeSecWebSocketAccept(accept_key));
}
// Tears the connection down synchronously. If the socket is already gone,
// DisconnectImmediately() only drops the handler.
WebSocketConnection::~WebSocketConnection() {
  this->DisconnectImmediately();
}
// Sets handshake response header `name` to `value`. The value of the first
// existing header with exactly the same (case-sensitive) name is overwritten;
// otherwise a new header is appended.
void WebSocketConnection::SetResponseHeader(std::string_view name,
                                            std::string_view value) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  CHECK(stream_socket_);
  auto it = response_headers_.begin();
  while (it != response_headers_.end() && it->first != name) {
    ++it;
  }
  if (it == response_headers_.end()) {
    response_headers_.emplace_back(name, value);
  } else {
    it->second = value;
  }
}
// Sends `message`, which must be valid UTF-8, as one final text frame. The
// frame is queued until the opening handshake has completed.
void WebSocketConnection::SendTextMessage(std::string_view message) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  CHECK(stream_socket_);
  CHECK(base::IsStringUTF8AllowingNoncharacters(message));
  SendInternal(CreateTextFrame(message), /*wait_for_handshake=*/true);
}
// Sends `message` as one final binary frame. The frame is queued until the
// opening handshake has completed.
void WebSocketConnection::SendBinaryMessage(base::span<const uint8_t> message) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  CHECK(stream_socket_);
  SendInternal(CreateBinaryFrame(message), /*wait_for_handshake=*/true);
}
// Starts a server-initiated closing handshake: sends a Close frame (queued
// until the opening handshake is done) and moves to kWaitingForClientClose.
//
// `code`: optional close status code. Without it, an empty Close frame is
//   sent and `message` is only validated as UTF-8, not transmitted, because a
//   close reason cannot be sent without a status code (see CreateCloseFrame).
// `message`: close reason; must be valid UTF-8.
//
// No-op if the socket has already been reset.
void WebSocketConnection::StartClosingHandshake(std::optional<uint16_t> code,
                                                std::string_view message) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!stream_socket_) {
    DVLOG(2) << "Attempted to start closing handshake, but socket is null.";
    return;
  }
  DVLOG(3) << "Starting closing handshake. Code: "
           << (code ? base::NumberToString(*code) : "none")
           << ", Message: " << message;
  scoped_refptr<IOBufferWithSize> close_frame;
  if (code) {
    close_frame = CreateCloseFrame(code, message);
  } else {
    CHECK(base::IsStringUTF8AllowingNoncharacters(message));
    close_frame = BuildWebSocketFrame(base::span<const uint8_t>(),
                                      WebSocketFrameHeader::kOpCodeClose);
  }
  SendInternal(std::move(close_frame), /*wait_for_handshake=*/true);
  // A synchronous write failure inside SendInternal() has already
  // disconnected (socket reset, `state_` == kClosed). Don't overwrite that
  // with kWaitingForClientClose, or a dead connection would look alive.
  if (stream_socket_) {
    state_ = WebSocketState::kWaitingForClientClose;
  }
}
// Answers a client Close frame with a Close frame of our own, sent without
// waiting for the handshake state to be kOpen. The connection is then
// disconnected once all pending writes have finished. Does nothing if the
// connection is already kClosed.
void WebSocketConnection::RespondToCloseFrame(std::optional<uint16_t> code,
                                              std::string_view message) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (state_ != WebSocketState::kClosed) {
    CHECK(base::IsStringUTF8AllowingNoncharacters(message));
    SendInternal(CreateCloseFrame(code, message),
                 /*wait_for_handshake=*/false);
    DisconnectAfterAnyWritesDone();
    return;
  }
  DVLOG(2) << "Attempted to respond to close frame, but connection is "
              "already closed.";
}
// Sends a Ping control frame carrying `payload`. The frame is queued until
// the opening handshake has completed.
void WebSocketConnection::SendPing(base::span<const uint8_t> payload) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  SendInternal(CreatePingFrame(payload), /*wait_for_handshake=*/true);
}
// Sends a Pong control frame carrying `payload`. The frame is queued until
// the opening handshake has completed.
void WebSocketConnection::SendPong(base::span<const uint8_t> payload) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  SendInternal(CreatePongFrame(payload), /*wait_for_handshake=*/true);
}
// Disconnects once outgoing data has been flushed. With no write in flight,
// this disconnects right away. Otherwise it sets kDisconnectingSoon and lets
// OnWriteComplete() disconnect after the write queue drains. While the socket
// is still connected, the handler is dropped in both cases, so OnReadComplete()
// stops delivering events.
void WebSocketConnection::DisconnectAfterAnyWritesDone() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!stream_socket_) {
    DVLOG(3) << "Socket is already disconnected.";
    return;
  }
  if (pending_buffer_) {
    should_disconnect_after_write_ = true;
    state_ = WebSocketState::kDisconnectingSoon;
    handler_.reset();
    return;
  }
  DisconnectImmediately();
}
// Drops the socket (if still present) and the handler right away, without
// waiting for pending writes.
void WebSocketConnection::DisconnectImmediately() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (stream_socket_) {
    // Disconnect() is deliberately not called because it does not work with
    // SSLServerSocket; destroying the socket is enough to close it.
    ResetStreamSocket();
  } else {
    DVLOG(3) << "Socket is already disconnected.";
  }
  handler_.reset();
}
// Destroys the socket, if any, and marks the connection kClosed. Pending I/O
// callbacks hold references to `this` (see PerformWrite()/Read()), so
// destroying the socket may release the last reference.
void WebSocketConnection::ResetStreamSocket() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!stream_socket_) {
    return;
  }
  stream_socket_.reset();
  state_ = WebSocketState::kClosed;
  // `this` may be deleted here.
}
// Writes `bytes` verbatim, with no WebSocket framing and without waiting for
// the handshake (e.g. the HTTP 101 response).
void WebSocketConnection::SendRaw(base::span<const uint8_t> bytes) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  auto raw_buffer = base::MakeRefCounted<IOBufferWithSize>(bytes.size());
  raw_buffer->span().copy_from(bytes);
  SendInternal(std::move(raw_buffer), /*wait_for_handshake=*/false);
}
// Starts writing `buffer` or queues it.
//
// Only one write is in flight at a time (`pending_buffer_`). Later buffers
// wait in `pending_messages_`, and OnWriteComplete() sends them in FIFO order.
// When `wait_for_handshake` is true, the buffer is also held back while the
// connection is not kOpen.
//
// If the socket has already been reset, the buffer is dropped. Writing would
// dereference a null socket, and queued data could never be flushed because
// only a completed write drains the queue.
void WebSocketConnection::SendInternal(scoped_refptr<IOBufferWithSize> buffer,
                                       bool wait_for_handshake) {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!stream_socket_) {
    DVLOG(2) << "Dropping outgoing data: socket is already disconnected.";
    return;
  }
  if ((wait_for_handshake && state_ != WebSocketState::kOpen) ||
      pending_buffer_) {
    pending_messages_.emplace(std::move(buffer));
    return;
  }
  const size_t buffer_size = buffer->size();
  pending_buffer_ =
      base::MakeRefCounted<DrainableIOBuffer>(std::move(buffer), buffer_size);
  PerformWrite();
}
void WebSocketConnection::SetHandler(
std::unique_ptr<WebSocketHandler> handler) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
handler_ = std::move(handler);
}
// Issues a socket Write() for the unwritten remainder of `pending_buffer_`.
// A synchronous result is handled inline. Otherwise OnWriteComplete() runs
// later, and the bound reference keeps `this` alive until then.
void WebSocketConnection::PerformWrite()
    VALID_CONTEXT_REQUIRED(sequence_checker_) {
  auto on_written = base::BindOnce(&WebSocketConnection::OnWriteComplete,
                                   base::WrapRefCounted(this));
  const int rv = stream_socket_->Write(
      pending_buffer_.get(), pending_buffer_->BytesRemaining(),
      std::move(on_written),
      DefineNetworkTrafficAnnotation(
          "test", "Traffic annotation for unit, browser and other tests"));
  if (rv == ERR_IO_PENDING) {
    return;
  }
  OnWriteComplete(rv);
}
// Completion handler for PerformWrite().
//
// A negative `result` disconnects immediately. Otherwise the bytes are marked
// consumed, and writing continues until `pending_buffer_` is drained. Then the
// next queued message is started; queued messages are sent without
// re-checking the handshake state. With an empty queue, a disconnect requested
// via DisconnectAfterAnyWritesDone() is carried out.
void WebSocketConnection::OnWriteComplete(int result)
    VALID_CONTEXT_REQUIRED(sequence_checker_) {
  if (result < 0) {
    DVLOG(1) << "Failed to write to WebSocket connection, error: " << result;
    DisconnectImmediately();
    return;
  }
  pending_buffer_->DidConsume(result);
  const bool partially_written = pending_buffer_->BytesRemaining() > 0;
  if (partially_written) {
    PerformWrite();
    return;
  }
  pending_buffer_.reset();
  if (pending_messages_.empty()) {
    if (should_disconnect_after_write_) {
      DisconnectImmediately();
    }
    return;
  }
  scoped_refptr<IOBufferWithSize> next = std::move(pending_messages_.front());
  pending_messages_.pop();
  SendInternal(std::move(next), /*wait_for_handshake=*/false);
}
// Starts a socket read into a fresh 4096-byte buffer. OnReadComplete()
// processes the result, inline when the read completes synchronously.
void WebSocketConnection::Read() VALID_CONTEXT_REQUIRED(sequence_checker_) {
  constexpr size_t kReadBufferSize = 4096;
  read_buffer_ = base::MakeRefCounted<IOBufferWithSize>(kReadBufferSize);
  auto on_read = base::BindOnce(&WebSocketConnection::OnReadComplete,
                                base::WrapRefCounted(this));
  const int rv = stream_socket_->Read(read_buffer_.get(), read_buffer_->size(),
                                      std::move(on_read));
  if (rv != ERR_IO_PENDING) {
    OnReadComplete(rv);
  }
}
// Handles the result of a socket read.
//
// A read error or EOF (`result` <= 0) disconnects immediately. Reads that
// arrive with no handler set are ignored. Otherwise the received bytes are
// decoded into frame chunks and assembled into frames, and each complete frame
// is passed to HandleFrame(). A chunk-level protocol error closes with 1002.
// The next Read() is issued while the socket is still connected.
void WebSocketConnection::OnReadComplete(int result)
    VALID_CONTEXT_REQUIRED(sequence_checker_) {
  if (result <= 0) {
    DVLOG(1) << "Failed to read from WebSocket connection, error: " << result;
    DisconnectImmediately();
    return;
  }
  if (!handler_) {
    DVLOG(1) << "No handler set, ignoring read.";
    return;
  }
  base::span<uint8_t> data_span =
      read_buffer_->first(static_cast<size_t>(result));
  // NOTE(review): the parser is local to this call, so any parser state for a
  // frame that straddles two reads (a split header, or a payload larger than
  // the read buffer) appears to be discarded before the next read. TODO:
  // confirm and, if so, keep the WebSocketFrameParser as a member instead.
  WebSocketFrameParser parser;
  std::vector<std::unique_ptr<WebSocketFrameChunk>> frame_chunks;
  parser.Decode(data_span, &frame_chunks);
  for (auto& chunk : frame_chunks) {
    // Handling an earlier frame can drop `handler_`: a protocol-error close
    // goes through DisconnectAfterAnyWritesDone(), or the handler may close
    // the connection itself. HandleFrame() CHECKs `handler_`, so stop here
    // instead of crashing on the next frame.
    if (!handler_) {
      DVLOG(2) << "Handler was reset while processing frames; dropping the "
                  "remaining frames.";
      return;
    }
    auto assemble_result = chunk_assembler_.HandleChunk(std::move(chunk));
    if (assemble_result.has_value()) {
      std::unique_ptr<WebSocketFrame> assembled_frame =
          std::move(assemble_result).value();
      HandleFrame(assembled_frame->header.opcode,
                  base::as_chars(assembled_frame->payload),
                  assembled_frame->header.final);
      continue;
    }
    if (assemble_result.error() == ERR_WS_PROTOCOL_ERROR) {
      DVLOG(1) << "Protocol error while handling frame.";
      StartClosingHandshake(1002, "Protocol error");
      DisconnectAfterAnyWritesDone();
      return;
    }
  }
  if (stream_socket_) {
    Read();
  }
}
// Dispatches one assembled frame to `handler_`, which must be set.
//
// Text, binary and continuation frames go through `message_assembler_`, and
// the handler sees them only when a whole message is available. Close, Ping
// and Pong frames are forwarded directly. Malformed data/close frames and
// unknown opcodes start a 1002 ("Protocol error") close and then disconnect
// after pending writes.
void WebSocketConnection::HandleFrame(WebSocketFrameHeader::OpCode opcode,
                                      base::span<const char> payload,
                                      bool is_final)
    VALID_CONTEXT_REQUIRED(sequence_checker_) {
  CHECK(handler_) << "No handler set for WebSocket connection.";
  switch (opcode) {
    case WebSocketFrameHeader::kOpCodeText:
    case WebSocketFrameHeader::kOpCodeBinary:
    case WebSocketFrameHeader::kOpCodeContinuation: {
      auto assembled =
          message_assembler_.HandleFrame(is_final, opcode, payload);
      if (!assembled.has_value()) {
        // Errors other than ERR_WS_PROTOCOL_ERROR are ignored here
        // (presumably "message not complete yet").
        if (assembled.error() == ERR_WS_PROTOCOL_ERROR) {
          StartClosingHandshake(1002, "Protocol error");
          DisconnectAfterAnyWritesDone();
        }
        return;
      }
      if (!assembled->is_text_message) {
        handler_->OnBinaryMessage(assembled->body);
        return;
      }
      handler_->OnTextMessage(base::as_string_view(assembled->body));
      return;
    }
    case WebSocketFrameHeader::kOpCodeClose: {
      auto close_info = ParseCloseFrame(payload);
      if (!close_info.error.has_value()) {
        handler_->OnClosingHandshake(close_info.code, close_info.reason);
        return;
      }
      DVLOG(1) << "Failed to parse close frame: " << close_info.error.value();
      StartClosingHandshake(1002, "Protocol error");
      DisconnectAfterAnyWritesDone();
      return;
    }
    case WebSocketFrameHeader::kOpCodePing:
      handler_->OnPing(base::as_bytes(payload));
      return;
    case WebSocketFrameHeader::kOpCodePong:
      handler_->OnPong(base::as_bytes(payload));
      return;
    default:
      DVLOG(2) << "Unknown frame opcode: " << opcode;
      StartClosingHandshake(1002, "Protocol error");
      DisconnectAfterAnyWritesDone();
      return;
  }
}
// Sends the "101 Switching Protocols" response built from `response_headers_`,
// marks the connection kOpen, starts reading frames and then notifies the
// handler (if one is still set) that the handshake is complete. No-op if the
// socket has already been reset.
void WebSocketConnection::SendHandshakeResponse() {
  DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
  if (!stream_socket_) {
    DVLOG(2) << "Stream socket is already null. Returning early.";
    return;
  }
  std::string response_text = "HTTP/1.1 101 Switching Protocols\r\n";
  for (const auto& header : response_headers_) {
    base::StrAppend(&response_text,
                    {header.first, ": ", header.second, "\r\n"});
  }
  base::StrAppend(&response_text, {"\r\n"});
  SendRaw(base::as_byte_span(response_text));
  // A synchronous write failure inside SendRaw() has already disconnected
  // (socket reset, `state_` == kClosed). Continuing would overwrite the state
  // with kOpen and call Read() on a null socket.
  if (!stream_socket_) {
    DVLOG(2) << "Connection closed while sending the handshake response.";
    return;
  }
  state_ = WebSocketState::kOpen;
  Read();
  // A nullptr check is performed because the connection may have been closed
  // within Read().
  if (handler_) {
    handler_->OnHandshakeComplete();
  } else {
    DVLOG(2)
        << "Handler is null after starting Read. Connection likely closed.";
  }
}
// Wraps the bytes of `message` in a single final text frame.
scoped_refptr<IOBufferWithSize> CreateTextFrame(std::string_view message) {
  const base::span<const uint8_t> text_bytes = base::as_byte_span(message);
  return BuildWebSocketFrame(text_bytes, WebSocketFrameHeader::kOpCodeText);
}
// Wraps `message` in a single final binary frame.
scoped_refptr<IOBufferWithSize> CreateBinaryFrame(
    base::span<const uint8_t> message) {
  const WebSocketFrameHeader::OpCode binary_op =
      WebSocketFrameHeader::kOpCodeBinary;
  return BuildWebSocketFrame(message, binary_op);
}
// Builds a Close frame.
//
// With a `code`, the payload is the status code as two big-endian bytes
// followed by the bytes of `message`. Without a code, the payload is empty and
// `message` must also be empty (CHECKed). `message` must be valid UTF-8.
scoped_refptr<IOBufferWithSize> CreateCloseFrame(std::optional<uint16_t> code,
                                                 std::string_view message) {
  DVLOG(3) << "Creating close frame with code: "
           << (code ? base::NumberToString(*code) : "none")
           << ", Message: " << message;
  CHECK(message.empty() || code);
  CHECK(base::IsStringUTF8AllowingNoncharacters(message));
  if (code.has_value()) {
    const base::span<const uint8_t> reason = base::as_byte_span(message);
    auto payload =
        base::HeapArray<uint8_t>::Uninit(sizeof(uint16_t) + reason.size());
    base::SpanWriter<uint8_t> writer{payload};
    writer.WriteU16BigEndian(*code);
    writer.Write(reason);
    return BuildWebSocketFrame(payload, WebSocketFrameHeader::kOpCodeClose);
  }
  return BuildWebSocketFrame(base::span<const uint8_t>(),
                             WebSocketFrameHeader::kOpCodeClose);
}
// Wraps `payload` in a single final Ping control frame.
scoped_refptr<IOBufferWithSize> CreatePingFrame(
    base::span<const uint8_t> payload) {
  const WebSocketFrameHeader::OpCode ping_op =
      WebSocketFrameHeader::kOpCodePing;
  return BuildWebSocketFrame(payload, ping_op);
}
// Wraps `payload` in a single final Pong control frame.
scoped_refptr<IOBufferWithSize> CreatePongFrame(
    base::span<const uint8_t> payload) {
  const WebSocketFrameHeader::OpCode pong_op =
      WebSocketFrameHeader::kOpCodePong;
  return BuildWebSocketFrame(payload, pong_op);
}
// Serializes one final frame of type `op_code` carrying `payload`. No masking
// key is passed to the header writer. The result is a new buffer sized
// exactly for header + payload: the header is written at the start and the
// payload is copied right after it.
scoped_refptr<IOBufferWithSize> BuildWebSocketFrame(
    base::span<const uint8_t> payload,
    WebSocketFrameHeader::OpCode op_code) {
  WebSocketFrameHeader frame_header(op_code);
  frame_header.final = true;
  frame_header.payload_length = payload.size();
  auto frame = base::MakeRefCounted<IOBufferWithSize>(
      GetWebSocketFrameHeaderSize(frame_header) + payload.size());
  const int header_bytes =
      WriteWebSocketFrameHeader(frame_header, nullptr, frame->span());
  frame->span()
      .subspan(base::checked_cast<size_t>(header_bytes), payload.size())
      .copy_from(payload);
  return frame;
}
} // namespace net::test_server