blob: fda0d0760d9d3a97d17d0b1d707752ded2cc44e9 [file] [log] [blame]
// Copyright 2024 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/test/embedded_test_server/create_websocket_handler.h"
#include "base/base64.h"
#include "base/functional/bind.h"
#include "base/memory/scoped_refptr.h"
#include "base/strings/string_util.h"
#include "base/test/bind.h"
#include "base/time/time.h"
#include "base/types/expected.h"
#include "net/base/host_port_pair.h"
#include "net/base/url_util.h"
#include "net/http/http_status_code.h"
#include "net/test/embedded_test_server/embedded_test_server.h"
#include "net/test/embedded_test_server/http_request.h"
#include "net/test/embedded_test_server/http_response.h"
#include "net/test/embedded_test_server/websocket_connection.h"
namespace net::test_server {
namespace {
// Helper function to strip the query part of the URL
// Helper function to strip the query part of the URL.
//
// Returns the prefix of |url| that precedes the first '?'. If |url| contains
// no '?', find() yields npos and substr() clamps the length to the end of the
// view, so the input is returned unchanged. The result is a view into |url|'s
// storage and must not outlive it.
std::string_view StripQuery(std::string_view url) {
  return url.substr(0, url.find('?'));
}
// Builds a BasicHttpResponse with status |code| and body |content|, and logs
// both at VLOG level 3. Callers wrap the result in base::unexpected() to
// report a rejected upgrade request.
std::unique_ptr<HttpResponse> MakeErrorResponse(HttpStatusCode code,
                                                std::string_view content) {
  std::unique_ptr<BasicHttpResponse> basic_response =
      std::make_unique<BasicHttpResponse>();
  basic_response->set_code(code);
  basic_response->set_content(content);
  VLOG(3) << "Error response created. Code: " << static_cast<int>(code)
          << ", Content: " << content;
  // Implicit upcast from unique_ptr<BasicHttpResponse> to the base type.
  return basic_response;
}
// Validates |request| as a WebSocket opening handshake (RFC 6455, section
// 4.2.1) and, if valid, upgrades |connection| to a WebSocket.
//
// Returns:
//  - UpgradeResult::kNotHandled if the request path, with any query string
//    stripped, is not |handle_path|. This leaves the request for other
//    handlers.
//  - An HTTP 400 response (as the unexpected value) if the method or any
//    required handshake header is missing or invalid.
//  - UpgradeResult::kUpgraded after the following steps: the socket is taken
//    from |connection| and wrapped in a WebSocketConnection; the handler built
//    by |websocket_handler_creator| is notified via OnHandshake() and
//    installed; the handshake response is sent.
EmbeddedTestServer::UpgradeResultOrHttpResponse HandleWebSocketUpgrade(
    std::string_view handle_path,
    WebSocketHandlerCreator websocket_handler_creator,
    EmbeddedTestServer* server,
    const HttpRequest& request,
    HttpConnection* connection) {
  VLOG(3) << "Handling WebSocket upgrade for path: " << handle_path;
  std::string_view request_path = StripQuery(request.relative_url);
  if (request_path != handle_path) {
    return UpgradeResult::kNotHandled;
  }
  if (request.method != METHOD_GET) {
    return base::unexpected(
        MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
                          "Invalid request method. Expected GET."));
  }
  // TODO(crbug.com/40812029): Check that the HTTP version is 1.1
  // See https://datatracker.ietf.org/doc/html/rfc6455#section-4.2.1

  auto host_header = request.headers.find("Host");
  if (host_header == request.headers.end()) {
    VLOG(1) << "Host header is missing.";
    return base::unexpected(MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
                                              "Host header is missing."));
  }
  // NOTE(review): HostPortPair::FromString() appears to expect "host:port".
  // A Host header without an explicit port may therefore parse to an empty
  // host and be rejected below. Confirm if port-less Host headers must be
  // accepted.
  HostPortPair host_port = HostPortPair::FromString(host_header->second);
  if (!IsCanonicalizedHostCompliant(host_port.host())) {
    VLOG(1) << "Host header is invalid: " << host_port.host();
    return base::unexpected(MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
                                              "Host header is invalid."));
  }

  // The "missing" and "wrong value" cases are checked separately so that the
  // header value is only read when the iterator is valid. Dereferencing
  // end() is undefined behavior.
  auto upgrade_header = request.headers.find("Upgrade");
  if (upgrade_header == request.headers.end()) {
    VLOG(1) << "Upgrade header is missing.";
    return base::unexpected(
        MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
                          "Upgrade header is missing or invalid."));
  }
  if (!base::EqualsCaseInsensitiveASCII(upgrade_header->second, "websocket")) {
    VLOG(1) << "Upgrade header is missing or invalid: "
            << upgrade_header->second;
    return base::unexpected(
        MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
                          "Upgrade header is missing or invalid."));
  }

  auto connection_header = request.headers.find("Connection");
  if (connection_header == request.headers.end()) {
    VLOG(1) << "Connection header is missing.";
    return base::unexpected(MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
                                              "Connection header is missing."));
  }
  // Connection is a comma-separated token list (e.g. "keep-alive, Upgrade").
  // "Upgrade" may appear at any position and in any letter case.
  auto tokens =
      base::SplitStringPiece(connection_header->second, ",",
                             base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY);
  if (!std::ranges::any_of(tokens, [](std::string_view token) {
        return base::EqualsCaseInsensitiveASCII(token, "Upgrade");
      })) {
    VLOG(1) << "Connection header does not contain 'Upgrade'. Tokens: "
            << connection_header->second;
    return base::unexpected(
        MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
                          "Connection header does not contain 'Upgrade'."));
  }

  auto websocket_version_header = request.headers.find("Sec-WebSocket-Version");
  if (websocket_version_header == request.headers.end()) {
    VLOG(1) << "Sec-WebSocket-Version header is missing.";
    return base::unexpected(MakeErrorResponse(
        HttpStatusCode::HTTP_BAD_REQUEST, "Sec-WebSocket-Version must be 13."));
  }
  if (websocket_version_header->second != "13") {
    VLOG(1) << "Invalid or missing Sec-WebSocket-Version: "
            << websocket_version_header->second;
    return base::unexpected(MakeErrorResponse(
        HttpStatusCode::HTTP_BAD_REQUEST, "Sec-WebSocket-Version must be 13."));
  }

  auto sec_websocket_key_iter = request.headers.find("Sec-WebSocket-Key");
  if (sec_websocket_key_iter == request.headers.end()) {
    VLOG(1) << "Sec-WebSocket-Key header is missing.";
    return base::unexpected(
        MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST,
                          "Sec-WebSocket-Key header is missing."));
  }
  // RFC 6455 requires the key to be a base64-encoded 16-byte value.
  auto decoded = base::Base64Decode(sec_websocket_key_iter->second);
  if (!decoded || decoded->size() != 16) {
    VLOG(1) << "Sec-WebSocket-Key is invalid or has incorrect length.";
    return base::unexpected(MakeErrorResponse(
        HttpStatusCode::HTTP_BAD_REQUEST,
        "Sec-WebSocket-Key is invalid or has incorrect length."));
  }

  // From here on the request is a valid handshake: take ownership of the raw
  // socket away from the HTTP connection and hand it to the WebSocket layer.
  std::unique_ptr<StreamSocket> socket = connection->TakeSocket();
  CHECK(socket);
  auto websocket_connection = base::MakeRefCounted<WebSocketConnection>(
      std::move(socket), sec_websocket_key_iter->second, server);
  auto handler = websocket_handler_creator.Run(websocket_connection);
  // Fail loudly rather than dereferencing a null handler below.
  CHECK(handler);
  handler->OnHandshake(request);
  websocket_connection->SetHandler(std::move(handler));
  websocket_connection->SendHandshakeResponse();
  return UpgradeResult::kUpgraded;
}
} // namespace
// Returns an upgrade-request callback that performs the WebSocket opening
// handshake. It applies to requests whose path, ignoring any query string,
// equals |handle_path|. For each accepted request, |websocket_handler_creator|
// is run to build the per-connection handler. Requests for other paths yield
// UpgradeResult::kNotHandled.
//
// |handle_path| is copied into an owned std::string. The returned callback is
// stored and run later, so binding the caller's string_view directly would
// dangle whenever the caller passed a temporary string.
//
// The callback keeps |server| as a raw pointer. NOTE(review): this presumes
// the callback is never run after |server| is destroyed (e.g. because the
// server owns its registered upgrade handlers). Confirm with the
// EmbeddedTestServer registration path.
EmbeddedTestServer::HandleUpgradeRequestCallback CreateWebSocketHandler(
    std::string_view handle_path,
    WebSocketHandlerCreator websocket_handler_creator,
    EmbeddedTestServer* server) {
  return base::BindRepeating(&HandleWebSocketUpgrade, std::string(handle_path),
                             std::move(websocket_handler_creator), server);
}
} // namespace net::test_server