blob: 5c72b1ae9c2b5593792d102ad2d2d67b32ebc6bf [file] [log] [blame]
import asyncio
import logging
import os
import ssl
import threading
import traceback
from urllib.parse import urlparse
from typing import Any, Dict, List, Optional, Tuple
# TODO(bashi): Remove import check suppressions once aioquic dependency is resolved.
from aioquic.asyncio import QuicConnectionProtocol, serve # type: ignore
from aioquic.asyncio.client import connect # type: ignore
from aioquic.h3.connection import H3_ALPN, FrameType, H3Connection # type: ignore
from aioquic.h3.events import H3Event, HeadersReceived, WebTransportStreamDataReceived, DatagramReceived # type: ignore
from aioquic.quic.configuration import QuicConfiguration # type: ignore
from aioquic.quic.connection import stream_is_unidirectional # type: ignore
from aioquic.quic.events import QuicEvent, ProtocolNegotiated # type: ignore
from aioquic.tls import SessionTicket # type: ignore
from aioquic.quic.packet import QuicErrorCode # type: ignore
from tools.wptserve.wptserve import stash # type: ignore
"""
A WebTransport over HTTP/3 server for testing.
The server interprets the underlying protocols (WebTransport, HTTP/3 and QUIC)
and passes events to a particular webtransport handler. From the standpoint of
test authors, a webtransport handler is a Python script which contains some
callback functions. See handler.py for available callbacks.
"""
# Value sent in the `server` response header.
SERVER_NAME = 'webtransport-h3-server'
# Module-wide logger. WebTransportH3Server.__init__ replaces it when a logger
# is passed in.
_logger: logging.Logger = logging.getLogger(__name__)
# Directory that handler paths are resolved against. Set by
# WebTransportH3Server.__init__.
_doc_root: str = ""
class WebTransportH3Protocol(QuicConnectionProtocol):
    """
    QUIC connection protocol that runs HTTP/3 on top of the connection and
    forwards WebTransport events to a handler script.

    The handler is created from the `:path` of an extended CONNECT request
    and, once created, receives the stream and datagram events of this
    connection.
    """
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Set once a handler script has been loaded for a CONNECT request.
        self._handler: Optional[Any] = None
        # Created when the QUIC layer reports ProtocolNegotiated.
        self._http: Optional[H3Connection] = None

    def quic_event_received(self, event: QuicEvent) -> None:
        """
        Pass a QUIC event through the HTTP/3 layer and dispatch every
        resulting HTTP/3 event.

        :param event: The QUIC event to process.
        """
        if isinstance(event, ProtocolNegotiated):
            self._http = H3Connection(self._quic, enable_webtransport=True)

        if self._http is not None:
            for http_event in self._http.handle_event(event):
                self._h3_event_received(http_event)

    def _h3_event_received(self, event: H3Event) -> None:
        """
        Handle one HTTP/3 event: start the WebTransport handshake for an
        extended CONNECT request, reject any other request with 400, and
        forward stream data and datagrams to the handler if there is one.

        :param event: The HTTP/3 event to process.
        """
        if isinstance(event, HeadersReceived):
            # Convert from List[Tuple[bytes, bytes]] to Dict[bytes, bytes].
            # Only the last header will be kept when there are duplicate
            # headers.
            headers = dict(event.headers)

            method = headers.get(b":method")
            protocol = headers.get(b":protocol")
            if method == b"CONNECT" and protocol == b"webtransport":
                self._handshake_webtransport(event, headers)
            else:
                self._send_error_response(event.stream_id, 400)

        if self._handler is not None:
            if isinstance(event, WebTransportStreamDataReceived):
                self._handler.stream_data_received(
                    stream_id=event.stream_id,
                    data=event.data,
                    stream_ended=event.stream_ended)
            elif isinstance(event, DatagramReceived):
                self._handler.datagram_received(data=event.data)

    def _send_error_response(self, stream_id: int, status_code: int) -> None:
        """
        Send a header-only response with the given status and end the stream.

        :param stream_id: The request stream to respond on.
        :param status_code: The HTTP status code to send.
        """
        assert self._http is not None
        # HTTP/3 requires pseudo-header fields to come before regular
        # fields, so `:status` must be first.
        headers = [(b":status", str(status_code).encode()),
                   (b"server", SERVER_NAME.encode())]
        self._http.send_headers(stream_id=stream_id,
                                headers=headers,
                                end_stream=True)

    def _handshake_webtransport(self, event: HeadersReceived,
                                request_headers: Dict[bytes, bytes]) -> None:
        """
        Answer an extended CONNECT request for WebTransport.

        Responds with 400 when `:path` is missing and with 404 when the
        handler file cannot be opened. Otherwise the handler's
        `connect_received` callback may add response headers (including
        `:status`); `session_established` is invoked only when the status is
        left unset or is 200.

        :param event: The HeadersReceived event carrying the request.
        :param request_headers: The request headers as a dictionary.
        """
        assert self._http is not None
        path = request_headers.get(b":path")
        if path is None:
            # `:path` must be provided.
            self._send_error_response(event.stream_id, 400)
            return

        # Create a handler using `:path`.
        try:
            self._handler = self._create_event_handler(
                session_id=event.stream_id,
                path=path,
                request_headers=event.headers)
        except IOError:
            self._send_error_response(event.stream_id, 404)
            return

        response_headers = [
            (b"server", SERVER_NAME.encode()),
        ]
        self._handler.connect_received(response_headers=response_headers)

        # The first `:status` supplied by the handler decides the outcome.
        status_code = next(
            (value for name, value in response_headers
             if name == b":status"),
            None)
        if not status_code:
            response_headers.append((b":status", b"200"))

        # Pseudo-header fields must precede regular fields in HTTP/3. The
        # handler may have appended `:status` after `server`, so emit the
        # pseudo-headers first while keeping the relative order otherwise.
        pseudo_headers = [h for h in response_headers
                          if h[0].startswith(b":")]
        regular_headers = [h for h in response_headers
                           if not h[0].startswith(b":")]
        self._http.send_headers(stream_id=event.stream_id,
                                headers=pseudo_headers + regular_headers)

        if status_code is None or status_code == b"200":
            self._handler.session_established()

    def _create_event_handler(self, session_id: int, path: bytes,
                              request_headers: List[Tuple[bytes, bytes]]) -> Any:
        """
        Load the handler script named by `path` and wrap its callbacks.

        :param session_id: The stream ID of the CONNECT request.
        :param path: The `:path` value of the request.
        :param request_headers: The raw request headers.
        :raises IOError: If the handler file cannot be opened.
        """
        parsed = urlparse(path.decode())
        # NOTE(review): the path comes from the request and `..` segments are
        # not filtered, and the file is run with exec(). This looks acceptable
        # only because it is a test server serving trusted handler scripts -
        # confirm before exposing it to untrusted clients.
        file_path = os.path.join(_doc_root, parsed.path.lstrip("/"))
        callbacks = {"__file__": file_path}
        with open(file_path) as f:
            exec(compile(f.read(), path, "exec"), callbacks)
        session = WebTransportSession(self, session_id, request_headers)
        return WebTransportEventHandler(session, callbacks)
class WebTransportSession:
    """
    The object handler callbacks use to act on one WebTransport session:
    opening streams, sending stream data and datagrams, closing the
    connection, and reaching the shared stash.
    """
    def __init__(self, protocol: WebTransportH3Protocol, session_id: int,
                 request_headers: List[Tuple[bytes, bytes]]) -> None:
        self._protocol: WebTransportH3Protocol = protocol
        self._http: H3Connection = protocol._http
        self.session_id = session_id
        self.request_headers = request_headers

        # Every handler uses the same stash path, so that separate
        # WebTransport sessions can read and write one common store.
        self._stash_path = '/webtransport/handlers'
        # Built on first access of the `stash` property.
        self._stash: Optional[stash.Stash] = None

    @property
    def stash(self) -> stash.Stash:
        """The Stash object used to keep state across sessions."""
        store = self._stash
        if store is None:
            address, authkey = stash.load_env_config()
            store = stash.Stash(self._stash_path, address, authkey)
            self._stash = store
        return store

    def stream_is_unidirectional(self, stream_id: int) -> bool:
        """Tell whether `stream_id` identifies a unidirectional stream."""
        # Inside the method body the bare name refers to the module-level
        # function imported from aioquic, not to this method.
        unidirectional = stream_is_unidirectional(stream_id)
        return unidirectional

    def close(self,
              error_code: int = QuicErrorCode.NO_ERROR,
              reason_phrase: str = "") -> None:
        """
        Close the session by closing the underlying QUIC connection.

        :param error_code: An error code indicating why the session is
                           being closed.
        :param reason_phrase: A human readable explanation of why the
                              session is being closed.
        """
        quic = self._http._quic
        quic.close(error_code=error_code, reason_phrase=reason_phrase)
        self._protocol.transmit()

    def create_unidirectional_stream(self) -> int:
        """
        Open a unidirectional WebTransport stream.

        :return: The ID of the new stream.
        """
        new_stream_id = self._http.create_webtransport_stream(
            session_id=self.session_id, is_unidirectional=True)
        return new_stream_id

    def create_bidirectional_stream(self) -> int:
        """
        Open a bidirectional WebTransport stream.

        :return: The ID of the new stream.
        """
        http = self._http
        new_stream_id = http.create_webtransport_stream(
            session_id=self.session_id, is_unidirectional=False)

        # TODO(bashi): Remove this workaround when aioquic supports receiving
        # data on server-initiated bidirectional streams.
        h3_stream = http._get_or_create_stream(new_stream_id)
        assert h3_stream.frame_type is None
        assert h3_stream.session_id is None
        h3_stream.frame_type = FrameType.WEBTRANSPORT_STREAM
        h3_stream.session_id = self.session_id

        return new_stream_id

    def send_stream_data(self,
                         stream_id: int,
                         data: bytes,
                         end_stream: bool = False) -> None:
        """
        Write data to a stream.

        :param stream_id: The stream ID on which to send the data.
        :param data: The data to send.
        :param end_stream: If set to True, the stream will be closed.
        """
        quic = self._http._quic
        quic.send_stream_data(stream_id=stream_id,
                              data=data,
                              end_stream=end_stream)

    def send_datagram(self, data: bytes) -> None:
        """
        Send data in a datagram frame, using the session ID as the flow ID.

        :param data: The data to send.
        """
        flow_id = self.session_id
        self._http.send_datagram(flow_id=flow_id, data=data)
class WebTransportEventHandler:
    """
    Dispatches WebTransport events to the callbacks defined by a handler
    script.

    A callback the script does not define is skipped. An exception raised by
    a callback is logged and its traceback printed; it is not propagated.
    """
    def __init__(self, session: WebTransportSession,
                 callbacks: Dict[str, Any]) -> None:
        """
        :param session: The session passed to the callbacks.
        :param callbacks: Mapping from callback name to callable.
        """
        self._session = session
        self._callbacks = callbacks

    def _run_callback(self, callback_name: str,
                      *args: Any, **kwargs: Any) -> None:
        """
        Invoke the named callback with the given arguments, if it exists.

        :param callback_name: Key to look up in the callbacks mapping.
        """
        if callback_name not in self._callbacks:
            return
        try:
            self._callbacks[callback_name](*args, **kwargs)
        except Exception as e:
            # Logger.warn() is a deprecated alias; use warning().
            _logger.warning(str(e))
            traceback.print_exc()

    def connect_received(self, response_headers: List[Tuple[bytes,
                                                            bytes]]) -> None:
        """
        Run `connect_received` with the request headers and the mutable
        response header list.

        :param response_headers: Response headers the callback may extend.
        """
        self._run_callback("connect_received", self._session.request_headers,
                           response_headers)

    def session_established(self) -> None:
        """Run `session_established` with the session."""
        self._run_callback("session_established", self._session)

    def stream_data_received(self, stream_id: int, data: bytes,
                             stream_ended: bool) -> None:
        """
        Run `stream_data_received` with the session and the stream event data.

        :param stream_id: The stream the data arrived on.
        :param data: The received bytes.
        :param stream_ended: Whether the event reported the end of the stream.
        """
        self._run_callback("stream_data_received", self._session, stream_id,
                           data, stream_ended)

    def datagram_received(self, data: bytes) -> None:
        """
        Run `datagram_received` with the session and the datagram payload.

        :param data: The received bytes.
        """
        self._run_callback("datagram_received", self._session, data)
class SessionTicketStore:
    """
    Keeps session tickets in a plain in-memory dictionary.
    """
    def __init__(self) -> None:
        # Maps the `ticket` bytes of each stored ticket to the ticket object.
        self.tickets: Dict[bytes, SessionTicket] = dict()

    def add(self, ticket: SessionTicket) -> None:
        """Store `ticket`, replacing any entry with the same ticket bytes."""
        label = ticket.ticket
        self.tickets[label] = ticket

    def pop(self, label: bytes) -> Optional[SessionTicket]:
        """Remove and return the ticket stored under `label`, or None."""
        try:
            return self.tickets.pop(label)
        except KeyError:
            return None
class WebTransportH3Server:
    """
    A WebTransport over HTTP/3 server for testing.

    The server runs an asyncio event loop on a daemon thread.

    :param host: Host from which to serve.
    :param port: Port from which to serve.
    :param doc_root: Document root for serving handlers.
    :param cert_path: Path to certificate file to use.
    :param key_path: Path to key file to use.
    :param logger: a Logger object for this server. When None, the module
                   logger is left unchanged.
    """
    def __init__(self, host: str, port: int, doc_root: str, cert_path: str,
                 key_path: str,
                 logger: Optional[logging.Logger] = None) -> None:
        self.host = host
        self.port = port
        self.doc_root = doc_root
        self.cert_path = cert_path
        self.key_path = key_path
        self.started = False
        # The protocol classes read these module globals, so the most
        # recently constructed server decides their values.
        global _doc_root
        _doc_root = self.doc_root
        global _logger
        if logger is not None:
            _logger = logger

    def start(self) -> None:
        """Start the server."""
        # Create the loop here rather than on the server thread, so that
        # stop() can never observe a missing `self.loop`.
        self.loop = asyncio.new_event_loop()
        self.server_thread = threading.Thread(
            target=self._start_on_server_thread, daemon=True)
        self.server_thread.start()
        self.started = True

    def _start_on_server_thread(self) -> None:
        """
        Thread target: bind the QUIC server and run the loop until stopped.

        The server endpoint and the event loop are closed when this returns
        or raises.
        """
        server = None
        try:
            configuration = QuicConfiguration(
                alpn_protocols=H3_ALPN,
                is_client=False,
                max_datagram_frame_size=65536,
            )

            _logger.info("Starting WebTransport over HTTP/3 server on %s:%s",
                         self.host, self.port)

            configuration.load_cert_chain(self.cert_path, self.key_path)

            ticket_store = SessionTicketStore()

            asyncio.set_event_loop(self.loop)
            server = self.loop.run_until_complete(
                serve(
                    self.host,
                    self.port,
                    configuration=configuration,
                    create_protocol=WebTransportH3Protocol,
                    session_ticket_fetcher=ticket_store.pop,
                    session_ticket_handler=ticket_store.add,
                ))
            self.loop.run_forever()
        finally:
            # Release the UDP endpoint and the loop instead of leaking them.
            if server is not None:
                server.close()
            self.loop.close()

    def stop(self) -> None:
        """Stop the server and wait for the server thread to exit."""
        if self.started:
            stop_coro = self._stop_on_server_thread()
            try:
                asyncio.run_coroutine_threadsafe(stop_coro, self.loop)
            except RuntimeError:
                # The loop is already closed because the server thread has
                # exited; there is nothing left to stop.
                stop_coro.close()
            self.server_thread.join()
            _logger.info("Stopped WebTransport over HTTP/3 server on %s:%s",
                         self.host, self.port)
        self.started = False

    async def _stop_on_server_thread(self) -> None:
        """Stop the event loop; scheduled onto the loop by stop()."""
        self.loop.stop()
def server_is_running(host: str, port: int, timeout: float) -> bool:
    """
    Check the WebTransport over HTTP/3 server is running at the given `host`
    and `port`.

    :param host: Host to connect to.
    :param port: Port to connect to.
    :param timeout: Maximum time in seconds to wait for the connection.
    :return: The result of the connection attempt.
    """
    # Use a private loop. asyncio.get_event_loop() is deprecated when no loop
    # is running and raises on threads that have no current loop.
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(
            _connect_server_with_timeout(host, port, timeout))
    finally:
        try:
            # Gives callbacks that are still pending a chance to run before
            # the loop goes away.
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            loop.close()
async def _connect_server_with_timeout(host: str, port: int, timeout: float) -> bool:
    """
    Try to connect to the server, giving up after `timeout` seconds.

    :param host: Host to connect to.
    :param port: Port to connect to.
    :param timeout: Maximum time in seconds to wait.
    :return: True if the connection attempt completed, False if it timed out
             or failed with a connection/OS error.
    """
    try:
        await asyncio.wait_for(_connect_to_server(host, port), timeout=timeout)
    except asyncio.TimeoutError:
        _logger.warning("Failed to connect WebTransport over HTTP/3 server")
        return False
    except OSError as e:
        # Covers ConnectionError and address resolution failures: the probe
        # answers "not running" rather than raising.
        _logger.warning(
            "Failed to connect WebTransport over HTTP/3 server: %s", e)
        return False
    return True
async def _connect_to_server(host: str, port: int) -> None:
    """
    Open a QUIC client connection to `host`:`port` and send one ping.

    :param host: Host to connect to.
    :param port: Port to connect to.
    """
    # Certificate verification is turned off (ssl.CERT_NONE).
    client_settings = dict(
        alpn_protocols=H3_ALPN,
        is_client=True,
        verify_mode=ssl.CERT_NONE,
    )
    client_config = QuicConfiguration(**client_settings)

    async with connect(host, port, configuration=client_config) as client:
        await client.ping()