| // Copyright 2024 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
#include "net/test/embedded_test_server/create_websocket_handler.h"

#include <string>

#include "base/base64.h"
#include "base/functional/bind.h"
#include "base/memory/scoped_refptr.h"
#include "base/strings/string_util.h"
#include "base/test/bind.h"
#include "base/time/time.h"
#include "base/types/expected.h"
#include "net/base/host_port_pair.h"
#include "net/base/url_util.h"
#include "net/http/http_status_code.h"
#include "net/test/embedded_test_server/embedded_test_server.h"
#include "net/test/embedded_test_server/http_request.h"
#include "net/test/embedded_test_server/http_response.h"
#include "net/test/embedded_test_server/websocket_connection.h"
| |
| namespace net::test_server { |
| |
| namespace { |
| |
// Returns the portion of `url` that precedes the first '?'. When `url` has no
// '?', the whole input is returned. The result is a view into the same
// storage as `url`; nothing is copied.
std::string_view StripQuery(std::string_view url) {
  // find() yields npos when there is no '?', and substr(0, npos) is the
  // entire view, so both cases collapse into a single expression.
  return url.substr(0, url.find('?'));
}
| |
| std::unique_ptr<HttpResponse> MakeErrorResponse(HttpStatusCode code, |
| std::string_view content) { |
| auto error_response = std::make_unique<BasicHttpResponse>(); |
| error_response->set_code(code); |
| error_response->set_content(content); |
| VLOG(3) << "Error response created. Code: " << static_cast<int>(code) |
| << ", Content: " << content; |
| return error_response; |
| } |
| |
| EmbeddedTestServer::UpgradeResultOrHttpResponse HandleWebSocketUpgrade( |
| std::string_view handle_path, |
| WebSocketHandlerCreator websocket_handler_creator, |
| EmbeddedTestServer* server, |
| const HttpRequest& request, |
| HttpConnection* connection) { |
| VLOG(3) << "Handling WebSocket upgrade for path: " << handle_path; |
| |
| std::string_view request_path = StripQuery(request.relative_url); |
| |
| if (request_path != handle_path) { |
| return UpgradeResult::kNotHandled; |
| } |
| |
| if (request.method != METHOD_GET) { |
| return base::unexpected( |
| MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST, |
| "Invalid request method. Expected GET.")); |
| } |
| |
| // TODO(crbug.com/40812029): Check that the HTTP version is 1.1 |
| // See https://datatracker.ietf.org/doc/html/rfc6455#section-4.2.1 |
| |
| auto host_header = request.headers.find("Host"); |
| if (host_header == request.headers.end()) { |
| VLOG(1) << "Host header is missing."; |
| return base::unexpected(MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST, |
| "Host header is missing.")); |
| } |
| |
| HostPortPair host_port = HostPortPair::FromString(host_header->second); |
| if (!IsCanonicalizedHostCompliant(host_port.host())) { |
| VLOG(1) << "Host header is invalid: " << host_port.host(); |
| return base::unexpected(MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST, |
| "Host header is invalid.")); |
| } |
| |
| auto upgrade_header = request.headers.find("Upgrade"); |
| if (upgrade_header == request.headers.end() || |
| !base::EqualsCaseInsensitiveASCII(upgrade_header->second, "websocket")) { |
| VLOG(1) << "Upgrade header is missing or invalid: " |
| << upgrade_header->second; |
| return base::unexpected( |
| MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST, |
| "Upgrade header is missing or invalid.")); |
| } |
| |
| auto connection_header = request.headers.find("Connection"); |
| if (connection_header == request.headers.end()) { |
| VLOG(1) << "Connection header is missing."; |
| return base::unexpected(MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST, |
| "Connection header is missing.")); |
| } |
| |
| auto tokens = |
| base::SplitStringPiece(connection_header->second, ",", |
| base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY); |
| if (!std::ranges::any_of(tokens, [](std::string_view token) { |
| return base::EqualsCaseInsensitiveASCII(token, "Upgrade"); |
| })) { |
| VLOG(1) << "Connection header does not contain 'Upgrade'. Tokens: " |
| << connection_header->second; |
| return base::unexpected( |
| MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST, |
| "Connection header does not contain 'Upgrade'.")); |
| } |
| |
| auto websocket_version_header = request.headers.find("Sec-WebSocket-Version"); |
| if (websocket_version_header == request.headers.end() || |
| websocket_version_header->second != "13") { |
| VLOG(1) << "Invalid or missing Sec-WebSocket-Version: " |
| << websocket_version_header->second; |
| return base::unexpected(MakeErrorResponse( |
| HttpStatusCode::HTTP_BAD_REQUEST, "Sec-WebSocket-Version must be 13.")); |
| } |
| |
| auto sec_websocket_key_iter = request.headers.find("Sec-WebSocket-Key"); |
| if (sec_websocket_key_iter == request.headers.end()) { |
| VLOG(1) << "Sec-WebSocket-Key header is missing."; |
| return base::unexpected( |
| MakeErrorResponse(HttpStatusCode::HTTP_BAD_REQUEST, |
| "Sec-WebSocket-Key header is missing.")); |
| } |
| |
| auto decoded = base::Base64Decode(sec_websocket_key_iter->second); |
| if (!decoded || decoded->size() != 16) { |
| VLOG(1) << "Sec-WebSocket-Key is invalid or has incorrect length."; |
| return base::unexpected(MakeErrorResponse( |
| HttpStatusCode::HTTP_BAD_REQUEST, |
| "Sec-WebSocket-Key is invalid or has incorrect length.")); |
| } |
| |
| std::unique_ptr<StreamSocket> socket = connection->TakeSocket(); |
| CHECK(socket); |
| |
| auto websocket_connection = base::MakeRefCounted<WebSocketConnection>( |
| std::move(socket), sec_websocket_key_iter->second, server); |
| |
| auto handler = websocket_handler_creator.Run(websocket_connection); |
| handler->OnHandshake(request); |
| websocket_connection->SetHandler(std::move(handler)); |
| websocket_connection->SendHandshakeResponse(); |
| return UpgradeResult::kUpgraded; |
| } |
| |
| } // namespace |
| |
| EmbeddedTestServer::HandleUpgradeRequestCallback CreateWebSocketHandler( |
| std::string_view handle_path, |
| WebSocketHandlerCreator websocket_handler_creator, |
| EmbeddedTestServer* server) { |
| // Note: The callback registered in ControllableHttpResponse will not be |
| // called after the server has been destroyed. This guarantees that the |
| // EmbeddedTestServer pointer remains valid for the lifetime of the |
| // ControllableHttpResponse instance. |
| return base::BindRepeating(&HandleWebSocketUpgrade, handle_path, |
| websocket_handler_creator, server); |
| } |
| |
| } // namespace net::test_server |