blob: 78370bbab70b1bb3e4ed96f7ff82e0e3ffd5b9aa [file] [log] [blame] [edit]
/*
* Copyright (C) 2024 Apple Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
* THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
* BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
* THE POSSIBILITY OF SUCH DAMAGE.
*/
#import "config.h"
#import "NetworkConnection.h"
#import "HTTPServer.h"
#import <pal/spi/cocoa/NetworkSPI.h>
#import <wtf/BlockPtr.h>
#import <wtf/SHA1.h>
#import <wtf/SoftLinking.h>
#import <wtf/StdLibExtras.h>
#import <wtf/ThreadSafeRefCounted.h>
#import <wtf/cocoa/VectorCocoa.h>
#import <wtf/darwin/DispatchExtras.h>
#import <wtf/text/Base64.h>
#import <wtf/text/StringToIntegerConversion.h>
#if HAVE(WEB_TRANSPORT)
SOFT_LINK_FRAMEWORK(Network)
SOFT_LINK_MAY_FAIL(Network, nw_webtransport_metadata_set_local_draining, void, (nw_protocol_metadata_t metadata), (metadata))
#define nw_webtransport_metadata_set_local_draining softLinknw_webtransport_metadata_set_local_draining
#endif
namespace TestWebKitAPI {
static Vector<uint8_t> vectorFromData(dispatch_data_t content)
{
ASSERT(content);
__block Vector<uint8_t> request;
dispatch_data_apply(content, ^bool(dispatch_data_t, size_t, const void* buffer, size_t size) {
request.append(std::span { static_cast<const uint8_t*>(buffer), size });
return true;
});
return request;
}
static OSObjectPtr<dispatch_data_t> dataFromString(String&& s)
{
auto impl = s.releaseImpl();
ASSERT(impl->is8Bit());
auto characters = impl->span8();
return adoptOSObject(dispatch_data_create(characters.data(), characters.size(), mainDispatchQueueSingleton(), ^{
(void)impl;
}));
}
void Connection::receiveBytes(CompletionHandler<void(Vector<uint8_t>&&)>&& completionHandler, size_t minimumSize) const
{
nw_connection_receive(m_connection.get(), minimumSize, std::numeric_limits<uint32_t>::max(), makeBlockPtr([connection = *this, completionHandler = WTF::move(completionHandler)](dispatch_data_t content, nw_content_context_t, bool, nw_error_t error) mutable {
if (error || !content)
return completionHandler({ });
completionHandler(vectorFromData(content));
}).get());
}
void Connection::receiveHTTPRequest(CompletionHandler<void(Vector<char>&&)>&& completionHandler, Vector<char>&& buffer) const
{
receiveBytes([connection = *this, completionHandler = WTF::move(completionHandler), buffer = WTF::move(buffer)](Vector<uint8_t>&& bytes) mutable {
buffer.appendVector(WTF::move(bytes));
if (size_t doubleNewlineIndex = find(buffer.span(), "\r\n\r\n"_span); doubleNewlineIndex != notFound) {
if (size_t contentLengthBeginIndex = find(buffer.span(), "Content-Length"_span); contentLengthBeginIndex != notFound) {
size_t contentLength = parseIntegerAllowingTrailingJunk<int>(buffer.span().subspan(contentLengthBeginIndex + strlen("Content-Length: "))).value_or(0);
size_t headerLength = doubleNewlineIndex + strlen("\r\n\r\n");
if (buffer.size() - headerLength < contentLength)
return connection.receiveHTTPRequest(WTF::move(completionHandler), WTF::move(buffer));
}
completionHandler(WTF::move(buffer));
} else
connection.receiveHTTPRequest(WTF::move(completionHandler), WTF::move(buffer));
});
}
ReceiveHTTPRequestOperation Connection::awaitableReceiveHTTPRequest() const
{
return { *this };
}
void ReceiveHTTPRequestOperation::await_suspend(std::coroutine_handle<> handle)
{
m_connection.receiveHTTPRequest([this, handle](Vector<char>&& result) mutable {
m_result = WTF::move(result);
handle();
});
}
ReceiveBytesOperation Connection::awaitableReceiveBytes() const
{
return { *this };
}
void ReceiveBytesOperation::await_suspend(std::coroutine_handle<> handle)
{
m_connection.receiveBytes([this, handle](Vector<uint8_t>&& result) mutable {
m_result = WTF::move(result);
handle();
});
}
void SendOperation::await_suspend(std::coroutine_handle<> handle)
{
m_connection.send(WTF::move(m_data), [handle] (bool) mutable {
handle();
});
}
SendOperation Connection::awaitableSend(Vector<uint8_t>&& message)
{
return { makeDispatchData(WTF::move(message)), *this };
}
SendOperation Connection::awaitableSend(String&& message)
{
return { dataFromString(WTF::move(message)), *this };
}
SendOperation Connection::awaitableSend(OSObjectPtr<dispatch_data_t>&& data)
{
return { WTF::move(data), *this };
}
void Connection::send(String&& message, CompletionHandler<void()>&& completionHandler) const
{
send(dataFromString(WTF::move(message)), [completionHandler = WTF::move(completionHandler)] (bool) mutable {
if (completionHandler)
completionHandler();
});
}
void Connection::send(Vector<uint8_t>&& message, CompletionHandler<void()>&& completionHandler) const
{
send(makeDispatchData(WTF::move(message)), [completionHandler = WTF::move(completionHandler)] (bool) mutable {
if (completionHandler)
completionHandler();
});
}
void Connection::sendAndReportError(Vector<uint8_t>&& message, CompletionHandler<void(bool)>&& completionHandler) const
{
send(makeDispatchData(WTF::move(message)), WTF::move(completionHandler));
}
void Connection::send(OSObjectPtr<dispatch_data_t>&& message, CompletionHandler<void(bool)>&& completionHandler) const
{
nw_connection_send(m_connection.get(), message.get(), NW_CONNECTION_DEFAULT_MESSAGE_CONTEXT, true, makeBlockPtr([completionHandler = WTF::move(completionHandler)](nw_error_t error) mutable {
if (completionHandler)
completionHandler(!!error);
}).get());
}
void Connection::webSocketHandshake(CompletionHandler<void()>&& connectionHandler)
{
receiveHTTPRequest([connection = Connection(*this), connectionHandler = WTF::move(connectionHandler)] (Vector<char>&& request) mutable {
auto webSocketAcceptValue = [] (const Vector<char>& request) {
constexpr auto keyHeaderField = "Sec-WebSocket-Key: "_s;
size_t keyBegin = find(request.span(), keyHeaderField.span());
ASSERT(keyBegin != notFound);
auto keySpan = request.subspan(keyBegin + keyHeaderField.length());
size_t keyEnd = find(keySpan, "\r\n"_span);
ASSERT(keyEnd != notFound);
constexpr auto webSocketKeyGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"_s;
SHA1 sha1;
sha1.addBytes(byteCast<uint8_t>(keySpan.first(keyEnd)));
sha1.addBytes(webSocketKeyGUID.span());
SHA1::Digest hash;
sha1.computeHash(hash);
return base64EncodeToString(hash);
};
connection.send(HTTPResponse(101, {
{ "Upgrade"_s, "websocket"_s },
{ "Connection"_s, "Upgrade"_s },
{ "Sec-WebSocket-Accept"_s, webSocketAcceptValue(request) }
}).serialize(HTTPResponse::IncludeContentLength::No), WTF::move(connectionHandler));
});
}
void Connection::terminate(CompletionHandler<void()>&& completionHandler)
{
nw_connection_set_state_changed_handler(m_connection.get(), makeBlockPtr([completionHandler = WTF::move(completionHandler)] (nw_connection_state_t state, nw_error_t error) mutable {
ASSERT_UNUSED(error, !error);
if (state == nw_connection_state_cancelled && completionHandler)
completionHandler();
}).get());
nw_connection_cancel(m_connection.get());
}
#if HAVE(WEB_TRANSPORT)
// FIXME: This shouldn't need to be thread safe.
// Make it non-thread-safe once rdar://161905206 is resolved.
struct ConnectionGroup::Data : public ThreadSafeRefCounted<ConnectionGroup::Data> {
static Ref<Data> create(nw_connection_group_t group) { return adoptRef(*new Data(group)); }
Data(nw_connection_group_t group)
: group(group) { }
RetainPtr<nw_connection_group_t> group;
CompletionHandler<void(Connection)> connectionHandler;
Vector<Connection> connections;
CompletionHandler<void()> failureCompletionHandler;
};
ConnectionGroup::ConnectionGroup(nw_connection_group_t group)
: m_data(Data::create(group)) { }
ConnectionGroup::~ConnectionGroup() = default;
ConnectionGroup::ConnectionGroup(const ConnectionGroup&) = default;
void ConnectionGroup::markAsFailed()
{
if (auto handler = std::exchange(m_data->failureCompletionHandler, nullptr))
handler();
}
Awaitable<void> ConnectionGroup::awaitableFailure()
{
co_return co_await AwaitableFromCompletionHandler<void> { [data = m_data] (auto completionHandler) {
data->failureCompletionHandler = WTF::move(completionHandler);
} };
}
Connection ConnectionGroup::createWebTransportConnection(ConnectionType type) const
{
RetainPtr webtransportOptions = adoptNS(nw_webtransport_create_options());
ASSERT(webtransportOptions);
nw_webtransport_options_set_is_unidirectional(webtransportOptions.get(), type == ConnectionType::Unidirectional);
nw_webtransport_options_set_is_datagram(webtransportOptions.get(), type == ConnectionType::Datagram);
RetainPtr connection = adoptNS(nw_connection_group_extract_connection(m_data->group.get(), nil, webtransportOptions.get()));
ASSERT(connection);
nw_connection_set_queue(connection.get(), mainDispatchQueueSingleton());
nw_connection_start(connection.get());
return Connection(connection.get());
}
void ConnectionGroup::cancel()
{
nw_connection_group_cancel(m_data->group.get());
}
void ReceiveIncomingConnectionOperation::await_suspend(std::coroutine_handle<> handle)
{
m_group.receiveIncomingConnection([this, handle](Connection result) mutable {
m_result = WTF::move(result);
handle();
});
}
ReceiveIncomingConnectionOperation ConnectionGroup::receiveIncomingConnection() const
{
return { *this };
}
void ConnectionGroup::receiveIncomingConnection(CompletionHandler<void(Connection)>&& connectionHandler)
{
m_data->connectionHandler = WTF::move(connectionHandler);
}
void ConnectionGroup::receiveIncomingConnection(Connection connection)
{
m_data->connections.append(connection);
nw_connection_set_state_changed_handler(connection.m_connection.get(), [connection, data = m_data] (nw_connection_state_t state, nw_error_t error) mutable {
if (state != nw_connection_state_ready)
return;
data->connectionHandler(connection);
});
nw_connection_set_queue(connection.m_connection.get(), mainDispatchQueueSingleton());
nw_connection_start(connection.m_connection.get());
}
void ConnectionGroup::drainWebTransportSession()
{
RetainPtr metadata = nw_connection_group_copy_protocol_metadata(m_data->group.get(), adoptNS(nw_protocol_copy_webtransport_definition()).get());
if (metadata && canLoadnw_webtransport_metadata_set_local_draining())
nw_webtransport_metadata_set_local_draining(metadata.get());
}
#endif // HAVE(WEB_TRANSPORT)
} // namespace TestWebKitAPI