| # mypy: allow-subclassing-any, no-warn-return-any |
| |
| import asyncio |
| import logging |
| import os |
| import ssl |
| import sys |
| import threading |
| import traceback |
| from enum import IntEnum |
| from urllib.parse import urlparse |
| from typing import Any, Dict, List, Optional, Tuple |
| |
| # TODO(bashi): Remove import check suppressions once aioquic dependency is resolved. |
| from aioquic.buffer import Buffer # type: ignore |
| from aioquic.asyncio import QuicConnectionProtocol, serve # type: ignore |
| from aioquic.asyncio.client import connect # type: ignore |
| from aioquic.h3.connection import H3_ALPN, FrameType, H3Connection, ProtocolError, SettingsError # type: ignore |
| from aioquic.h3.events import H3Event, HeadersReceived, WebTransportStreamDataReceived, DatagramReceived, DataReceived # type: ignore |
| from aioquic.quic.configuration import QuicConfiguration # type: ignore |
| from aioquic.quic.connection import logger as quic_connection_logger # type: ignore |
| from aioquic.quic.connection import stream_is_unidirectional |
| from aioquic.quic.events import QuicEvent, ProtocolNegotiated, ConnectionTerminated, StreamReset # type: ignore |
| from aioquic.tls import SessionTicket # type: ignore |
| |
| from tools import localpaths # noqa: F401 |
| from wptserve import stash |
| from .capsule import H3Capsule, H3CapsuleDecoder, CapsuleType |
| |
| """ |
| A WebTransport over HTTP/3 server for testing. |
| |
| The server interprets the underlying protocols (WebTransport, HTTP/3 and QUIC) |
| and passes events to a particular webtransport handler. From the standpoint of |
| test authors, a webtransport handler is a Python script which contains some |
| callback functions. See handler.py for available callbacks. |
| """ |
| |
SERVER_NAME = 'webtransport-h3-server'

# Module-wide logger. WebTransportH3Server.__init__ replaces it with the
# logger it is given, when one is provided.
_logger: logging.Logger = logging.getLogger(name=__name__)
# Directory against which handler `:path`s are resolved; assigned by
# WebTransportH3Server.__init__.
_doc_root: str = ""

# aioquic's connection logger emits some INFO records on every connection
# close; only let WARNING and above through to keep logs quiet.
quic_connection_logger.setLevel(level=logging.WARNING)
| |
| |
class H3DatagramSetting(IntEnum):
    """HTTP/3 SETTINGS identifiers that advertise HTTP Datagram support."""

    # Pre-RFC identifier from draft-ietf-masque-h3-datagram-04:
    # https://datatracker.ietf.org/doc/html/draft-ietf-masque-h3-datagram-04#section-8.1
    DRAFT04 = 0xFFD277
    # SETTINGS_H3_DATAGRAM as registered by RFC 9297:
    # https://datatracker.ietf.org/doc/html/rfc9297
    RFC = 0x33
| |
| |
class H3ConnectionWithDatagram(H3Connection):
    """
    An H3Connection that accepts the RFC flavour of HTTP Datagrams as well as
    draft-04, remembers which one the peer offered, and advertises extended
    CONNECT support.
    """
    # https://datatracker.ietf.org/doc/html/rfc9220#name-iana-considerations
    ENABLE_CONNECT_PROTOCOL = 0x08

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Assigned by _validate_settings(); None until then.
        self._datagram_setting: Optional[H3DatagramSetting] = None

    def _validate_settings(self, settings: Dict[int, int]) -> None:
        # super()._validate_settings() is deliberately not called: aioquic
        # does not know the RFC datagram setting and would raise SettingsError
        # when the peer offers only that one.
        # The RFC setting wins over draft-04 when the peer offers both.
        for candidate in (H3DatagramSetting.RFC, H3DatagramSetting.DRAFT04):
            if settings.get(candidate) == 1:
                self._datagram_setting = candidate
                break

        if self._datagram_setting is None:
            raise SettingsError("HTTP Datagrams support required")

    def _get_local_settings(self) -> Dict[int, int]:
        local_settings = super()._get_local_settings()
        # Offer both datagram variants plus extended CONNECT.
        for setting_id in (H3DatagramSetting.RFC,
                           H3DatagramSetting.DRAFT04,
                           H3ConnectionWithDatagram.ENABLE_CONNECT_PROTOCOL):
            local_settings[setting_id] = 1
        return local_settings

    @property
    def datagram_setting(self) -> Optional[H3DatagramSetting]:
        """The HTTP Datagram variant chosen from the peer's settings, or None."""
        return self._datagram_setting
| |
| |
class WebTransportH3Protocol(QuicConnectionProtocol):
    """
    QUIC protocol that runs HTTP/3 and forwards WebTransport session events
    to a handler script.

    One session is tracked per connection: a later extended CONNECT request
    overwrites `_session_stream_id` (and `_handler` once its script loads).
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Event handler built from the script named by the CONNECT `:path`.
        self._handler: Optional[Any] = None
        # HTTP/3 layer; created when the ALPN protocol is negotiated.
        self._http: Optional[H3ConnectionWithDatagram] = None
        # Stream ID of the extended CONNECT request (the session stream).
        self._session_stream_id: Optional[int] = None
        # (code, reason) from a received CLOSE_WEBTRANSPORT_SESSION capsule.
        self._close_info: Optional[Tuple[int, bytes]] = None
        self._capsule_decoder_for_session_stream: H3CapsuleDecoder =\
            H3CapsuleDecoder()
        # Guards the handler's session_closed callback so it runs at most once.
        self._allow_calling_session_closed = True
        # Whether datagrams may be delivered to / sent by the handler.
        self._allow_datagrams = False

    def quic_event_received(self, event: QuicEvent) -> None:
        """Feed a QUIC event through the HTTP/3 layer and notify the handler."""
        if isinstance(event, ProtocolNegotiated):
            self._http = H3ConnectionWithDatagram(
                self._quic, enable_webtransport=True)
            # NOTE(review): H3ConnectionWithDatagram.__init__ resets
            # datagram_setting to None, so this comparison is always true
            # here and datagrams are allowed before the peer's settings are
            # known, even for draft-04 peers. Kept as-is because
            # WebTransportSession.send_datagram relies on this flag.
            if self._http.datagram_setting != H3DatagramSetting.DRAFT04:
                self._allow_datagrams = True

        if self._http is not None:
            for http_event in self._http.handle_event(event):
                self._h3_event_received(http_event)

        if isinstance(event, ConnectionTerminated):
            self._call_session_closed(close_info=None, abruptly=True)
        if isinstance(event, StreamReset):
            if self._handler:
                self._handler.stream_reset(event.stream_id, event.error_code)

    def _h3_event_received(self, event: H3Event) -> None:
        """Handle one HTTP/3 event produced by the H3 connection."""
        if isinstance(event, HeadersReceived):
            # Convert from List[Tuple[bytes, bytes]] to Dict[bytes, bytes];
            # with duplicate header names only the last value is kept.
            headers: Dict[bytes, bytes] = dict(event.headers)

            method = headers.get(b":method")
            protocol = headers.get(b":protocol")
            origin = headers.get(b"origin")
            # Accept any Origin but the client must send it.
            if method == b"CONNECT" and protocol == b"webtransport" and origin:
                self._session_stream_id = event.stream_id
                self._handshake_webtransport(event, headers)
            else:
                status_code = 404 if origin else 403
                self._send_error_response(event.stream_id, status_code)

        if isinstance(event, DataReceived) and\
           self._session_stream_id == event.stream_id:
            if self._http and not self._http.datagram_setting and\
               len(event.data) > 0:
                raise ProtocolError('Unexpected data on the session stream')
            self._receive_data_on_session_stream(
                event.data, event.stream_ended)
        elif self._handler is not None:
            if isinstance(event, WebTransportStreamDataReceived):
                self._handler.stream_data_received(
                    stream_id=event.stream_id,
                    data=event.data,
                    stream_ended=event.stream_ended)
            elif isinstance(event, DatagramReceived):
                if self._allow_datagrams:
                    self._handler.datagram_received(data=event.data)

    def _receive_data_on_session_stream(self, data: bytes, fin: bool) -> None:
        """
        Decode capsules from data received on the session stream.

        :param data: Bytes received on the session stream (may be empty).
        :param fin: True if the peer finished the session stream.
        :raises ProtocolError: on a capsule received after
            CLOSE_WEBTRANSPORT_SESSION or an unimplemented capsule type.
        """
        assert self._http is not None
        decoder = self._capsule_decoder_for_session_stream
        if len(data) > 0:
            decoder.append(data)
        if fin:
            decoder.final()
        for capsule in decoder:
            if self._close_info is not None:
                raise ProtocolError((
                    "Receiving a capsule with type = {} after receiving " +
                    "CLOSE_WEBTRANSPORT_SESSION").format(capsule.type))
            assert self._http.datagram_setting is not None
            if self._http.datagram_setting == H3DatagramSetting.RFC:
                self._receive_h3_datagram_rfc_capsule_data(
                    capsule=capsule, fin=fin)
            elif self._http.datagram_setting == H3DatagramSetting.DRAFT04:
                self._receive_h3_datagram_draft04_capsule_data(
                    capsule=capsule, fin=fin)

        # The FIN may arrive in a later DATA event than the
        # CLOSE_WEBTRANSPORT_SESSION capsule. Report the clean close once the
        # stream is finished; _call_session_closed() ignores repeated calls.
        if fin and self._close_info is not None:
            self._call_session_closed(self._close_info, abruptly=False)

    def _receive_h3_datagram_rfc_capsule_data(self, capsule: H3Capsule, fin: bool) -> None:
        """Handle a session-stream capsule when RFC HTTP Datagrams are in use."""
        if capsule.type == CapsuleType.DATAGRAM_RFC:
            raise ProtocolError(
                f"Unimplemented capsule type: {capsule.type}")
        elif capsule.type == CapsuleType.CLOSE_WEBTRANSPORT_SESSION:
            self._set_close_info_and_may_close_session(
                data=capsule.data, fin=fin)
        else:
            # Ignore unknown capsules.
            return

    def _receive_h3_datagram_draft04_capsule_data(
            self, capsule: H3Capsule, fin: bool) -> None:
        """Handle a session-stream capsule when draft-04 datagrams are in use."""
        if capsule.type in {CapsuleType.DATAGRAM_DRAFT04,
                            CapsuleType.REGISTER_DATAGRAM_CONTEXT_DRAFT04,
                            CapsuleType.CLOSE_DATAGRAM_CONTEXT_DRAFT04}:
            raise ProtocolError(
                f"Unimplemented capsule type: {capsule.type}")
        if capsule.type in {CapsuleType.REGISTER_DATAGRAM_NO_CONTEXT_DRAFT04,
                            CapsuleType.CLOSE_WEBTRANSPORT_SESSION}:
            # We'll handle this case below.
            pass
        else:
            # We should ignore unknown capsules.
            return

        if capsule.type == CapsuleType.REGISTER_DATAGRAM_NO_CONTEXT_DRAFT04:
            buffer = Buffer(data=capsule.data)
            format_type = buffer.pull_uint_var()
            # https://ietf-wg-webtrans.github.io/draft-ietf-webtrans-http3/draft-ietf-webtrans-http3.html#name-datagram-format-type
            WEBTRANPORT_FORMAT_TYPE = 0xff7c00
            if format_type != WEBTRANPORT_FORMAT_TYPE:
                raise ProtocolError(
                    "Unexpected datagram format type: {}".format(
                        format_type))
            self._allow_datagrams = True
        elif capsule.type == CapsuleType.CLOSE_WEBTRANSPORT_SESSION:
            self._set_close_info_and_may_close_session(
                data=capsule.data, fin=fin)

    def _set_close_info_and_may_close_session(
            self, data: bytes, fin: bool) -> None:
        """
        Record (code, reason) from a CLOSE_WEBTRANSPORT_SESSION payload and,
        if the stream is already finished, report the session as closed.
        """
        buffer = Buffer(data=data)
        code = buffer.pull_uint32()
        # 4 bytes for the uint32.
        reason = buffer.pull_bytes(len(data) - 4)
        # TODO(bashi): Make sure `reason` is a UTF-8 text.
        self._close_info = (code, reason)
        if fin:
            self._call_session_closed(self._close_info, abruptly=False)

    def _send_error_response(self, stream_id: int, status_code: int) -> None:
        """Send a header-only response with `status_code` and end the stream."""
        assert self._http is not None
        headers = [(b":status", str(status_code).encode()),
                   (b"server", SERVER_NAME.encode())]
        self._http.send_headers(stream_id=stream_id,
                                headers=headers,
                                end_stream=True)

    def _handshake_webtransport(self, event: HeadersReceived,
                                request_headers: Dict[bytes, bytes]) -> None:
        """
        Answer an extended CONNECT request.

        Loads the handler named by `:path`, lets it adjust the response
        headers, sends the response and, on a 200 status, notifies the
        handler that the session is established.
        """
        assert self._http is not None
        path = request_headers.get(b":path")
        if path is None:
            # `:path` must be provided.
            self._send_error_response(event.stream_id, 400)
            return

        # Create a handler using `:path`.
        try:
            self._handler = self._create_event_handler(
                session_id=event.stream_id,
                path=path,
                request_headers=event.headers)
        except OSError:
            self._send_error_response(event.stream_id, 404)
            return

        response_headers = [
            (b"server", SERVER_NAME.encode()),
            (b"sec-webtransport-http3-draft", b"draft02"),
        ]
        self._handler.connect_received(response_headers=response_headers)

        # HTTP/3 requires pseudo-header fields such as :status to precede
        # regular fields, so move a handler-provided :status to the front,
        # or default to 200 when the handler did not set one.
        status_code = next(
            (value for name, value in response_headers if name == b":status"),
            None)
        if status_code is None:
            response_headers.insert(0, (b":status", b"200"))
        else:
            response_headers.remove((b":status", status_code))
            response_headers.insert(0, (b":status", status_code))
        self._http.send_headers(stream_id=event.stream_id,
                                headers=response_headers)

        if status_code is None or status_code == b"200":
            self._handler.session_established()

    def _create_event_handler(self, session_id: int, path: bytes,
                              request_headers: List[Tuple[bytes, bytes]]) -> Any:
        """
        Load the handler script named by `path` and wrap it in an event
        handler bound to a new WebTransportSession.

        :raises OSError: if the script cannot be read, including when `path`
            resolves outside the document root (FileNotFoundError).
        """
        parsed = urlparse(path.decode())
        file_path = os.path.join(_doc_root, parsed.path.lstrip("/"))

        # `:path` is supplied by the client; refuse anything (e.g. "..") that
        # escapes the document root before executing it.
        root = os.path.abspath(_doc_root)
        try:
            inside_root = os.path.commonpath(
                [root, os.path.abspath(file_path)]) == root
        except ValueError:
            # e.g. paths on different drives on Windows.
            inside_root = False
        if not inside_root:
            raise FileNotFoundError(
                f"Handler path is outside the document root: {file_path}")

        callbacks = {"__file__": file_path}
        # Read bytes so compile() applies PEP 263 / UTF-8 source decoding
        # rather than the platform's locale encoding.
        with open(file_path, "rb") as f:
            source = f.read()
        # Executes the handler script with full privileges; only scripts
        # under the document root are accepted (checked above).
        exec(compile(source, file_path, "exec"), callbacks)
        session = WebTransportSession(self, session_id, request_headers)
        return WebTransportEventHandler(session, callbacks)

    def _call_session_closed(
            self, close_info: Optional[Tuple[int, bytes]],
            abruptly: bool) -> None:
        """Invoke the handler's session_closed callback at most once."""
        allow_calling_session_closed = self._allow_calling_session_closed
        self._allow_calling_session_closed = False
        if self._handler and allow_calling_session_closed:
            self._handler.session_closed(close_info, abruptly)
| |
| |
class WebTransportSession:
    """
    A WebTransport session, as handed to handler-script callbacks.

    Wraps the owning WebTransportH3Protocol and its HTTP/3 connection and
    exposes operations on the session's streams and datagrams.
    """

    def __init__(self, protocol: WebTransportH3Protocol, session_id: int,
                 request_headers: List[Tuple[bytes, bytes]]) -> None:
        self.session_id = session_id
        self.request_headers = request_headers

        self._protocol: WebTransportH3Protocol = protocol
        self._http: H3Connection = protocol._http

        # Use a shared default path for all handlers so that different
        # WebTransport sessions can access the same store easily.
        self._stash_path = '/webtransport/handlers'
        self._stash: Optional[stash.Stash] = None
        self._dict_for_handlers: Dict[str, Any] = {}

    @property
    def stash(self) -> stash.Stash:
        """A Stash object for storing cross-session state (created lazily)."""
        if self._stash is None:
            address, authkey = stash.load_env_config()  # type: ignore
            self._stash = stash.Stash(self._stash_path, address, authkey)  # type: ignore
        return self._stash

    @property
    def dict_for_handlers(self) -> Dict[str, Any]:
        """A dictionary to which handlers can attach arbitrary data."""
        return self._dict_for_handlers

    def stream_is_unidirectional(self, stream_id: int) -> bool:
        """Return True if the stream is unidirectional."""
        return stream_is_unidirectional(stream_id)

    def close(self, close_info: Optional[Tuple[int, bytes]]) -> None:
        """
        Close the session.

        :param close_info: The (code, reason) to send in a
            CLOSE_WEBTRANSPORT_SESSION capsule, or None to just finish the
            session stream.
        """
        # The server initiated the close, so the handler's session_closed
        # callback is suppressed from now on.
        self._protocol._allow_calling_session_closed = False
        assert self._protocol._session_stream_id is not None
        session_stream_id = self._protocol._session_stream_id
        if close_info is not None:
            code = close_info[0]
            reason = close_info[1]
            buffer = Buffer(capacity=len(reason) + 4)
            buffer.push_uint32(code)
            buffer.push_bytes(reason)
            capsule =\
                H3Capsule(CapsuleType.CLOSE_WEBTRANSPORT_SESSION, buffer.data)
            self._http.send_data(
                session_stream_id, capsule.encode(), end_stream=False)

        self._http.send_data(session_stream_id, b'', end_stream=True)
        # TODO(yutakahirano): Reset all other streams.
        # TODO(yutakahirano): Reject future stream open requests
        # We need to wait for the stream data to arrive at the client, and then
        # we need to close the connection. At this moment we're relying on the
        # client's behavior.
        # TODO(yutakahirano): Implement the above.

    def create_unidirectional_stream(self) -> int:
        """
        Create a unidirectional WebTransport stream and return the stream ID.
        """
        return self._http.create_webtransport_stream(
            session_id=self.session_id, is_unidirectional=True)

    def create_bidirectional_stream(self) -> int:
        """
        Create a bidirectional WebTransport stream and return the stream ID.
        """
        stream_id = self._http.create_webtransport_stream(
            session_id=self.session_id, is_unidirectional=False)
        # TODO(bashi): Remove this workaround when aioquic supports receiving
        # data on server-initiated bidirectional streams.
        stream = self._http._get_or_create_stream(stream_id)
        assert stream.frame_type is None
        assert stream.session_id is None
        stream.frame_type = FrameType.WEBTRANSPORT_STREAM
        stream.session_id = self.session_id
        return stream_id

    def send_stream_data(self,
                         stream_id: int,
                         data: bytes,
                         end_stream: bool = False) -> None:
        """
        Send data on the specific stream.

        :param stream_id: The stream ID on which to send the data.
        :param data: The data to send.
        :param end_stream: If set to True, the stream will be closed.
        """
        self._http._quic.send_stream_data(stream_id=stream_id,
                                          data=data,
                                          end_stream=end_stream)

    def send_datagram(self, data: bytes) -> None:
        """
        Send data using a datagram frame.

        The datagram is dropped (with a warning) if datagrams are not
        currently allowed on this connection.

        :param data: The data to send.
        """
        if not self._protocol._allow_datagrams:
            _logger.warning(
                "Sending a datagram while that's not allowed - discarding it")
            return
        flow_id = self.session_id
        if self._http.datagram_setting is not None:
            # We must have a WebTransport Session ID at this point because
            # an extended CONNECT request is already received.
            assert self._protocol._session_stream_id is not None
            # TODO(yutakahirano): Make sure if this is the correct logic.
            # Chrome always uses 0 for the initial stream and the initial flow
            # ID, we cannot check the correctness with it.
            flow_id = self._protocol._session_stream_id // 4
        self._http.send_datagram(flow_id=flow_id, data=data)

    def stop_stream(self, stream_id: int, code: int) -> None:
        """
        Send a STOP_SENDING frame to the given stream.

        :param stream_id: The stream to stop.
        :param code: the reason of the error.
        """
        self._http._quic.stop_stream(stream_id, code)

    def reset_stream(self, stream_id: int, code: int) -> None:
        """
        Send a RESET_STREAM frame to the given stream.

        :param stream_id: The stream to reset.
        :param code: the reason of the error.
        """
        self._http._quic.reset_stream(stream_id, code)
| |
| |
class WebTransportEventHandler:
    """
    Dispatches session events to the callbacks defined by a handler script.

    Callbacks the script does not define are skipped. Exceptions raised by a
    callback are logged and their traceback printed; they do not propagate.
    """

    def __init__(self, session: WebTransportSession,
                 callbacks: Dict[str, Any]) -> None:
        self._session = session
        self._callbacks = callbacks

    def _run_callback(self, callback_name: str,
                      *args: Any, **kwargs: Any) -> None:
        if callback_name not in self._callbacks:
            return
        try:
            self._callbacks[callback_name](*args, **kwargs)
        except Exception as e:
            # Keep handler failures out of the protocol's event processing;
            # report them instead. (Logger.warn is a deprecated alias.)
            _logger.warning(str(e))
            traceback.print_exc()

    def connect_received(self, response_headers: List[Tuple[bytes,
                                                            bytes]]) -> None:
        """Let the handler inspect the request and edit `response_headers`."""
        self._run_callback("connect_received", self._session.request_headers,
                           response_headers)

    def session_established(self) -> None:
        self._run_callback("session_established", self._session)

    def stream_data_received(self, stream_id: int, data: bytes,
                             stream_ended: bool) -> None:
        self._run_callback("stream_data_received", self._session, stream_id,
                           data, stream_ended)

    def datagram_received(self, data: bytes) -> None:
        self._run_callback("datagram_received", self._session, data)

    def session_closed(
            self,
            close_info: Optional[Tuple[int, bytes]],
            abruptly: bool) -> None:
        self._run_callback(
            "session_closed", self._session, close_info, abruptly=abruptly)

    def stream_reset(self, stream_id: int, error_code: int) -> None:
        self._run_callback(
            "stream_reset", self._session, stream_id, error_code)
| |
| |
class SessionTicketStore:
    """
    In-memory map from ticket bytes to SessionTicket objects.

    Its `add` and `pop` methods are passed to aioquic's serve() as the
    session ticket handler and fetcher.
    """

    def __init__(self) -> None:
        self.tickets: Dict[bytes, SessionTicket] = dict()

    def add(self, ticket: SessionTicket) -> None:
        key = ticket.ticket
        self.tickets[key] = ticket

    def pop(self, label: bytes) -> Optional[SessionTicket]:
        try:
            return self.tickets.pop(label)
        except KeyError:
            return None
| |
| |
class WebTransportH3Server:
    """
    A WebTransport over HTTP/3 server for testing.

    The server's asyncio event loop runs on a dedicated daemon thread.

    :param host: Host from which to serve.
    :param port: Port from which to serve.
    :param doc_root: Document root for serving handlers.
    :param cert_path: Path to certificate file to use.
    :param key_path: Path to key file to use.
    :param logger: a Logger object for this server.
    """

    def __init__(self, host: str, port: int, doc_root: str, cert_path: str,
                 key_path: str, logger: Optional[logging.Logger]) -> None:
        self.host = host
        self.port = port
        self.doc_root = doc_root
        self.cert_path = cert_path
        self.key_path = key_path
        self.started = False
        # Handler lookup and logging are module-level, so publish the
        # configuration there.
        global _doc_root
        _doc_root = self.doc_root
        global _logger
        if logger is not None:
            _logger = logger

    def start(self) -> None:
        """Start the server on a background daemon thread."""
        self.server_thread = threading.Thread(
            target=self._start_on_server_thread, daemon=True)
        self.server_thread.start()
        self.started = True

    def _start_on_server_thread(self) -> None:
        # Optional TLS secrets log, requested via the SSLKEYLOGFILE variable.
        secrets_log_file = None
        if "SSLKEYLOGFILE" in os.environ:
            try:
                secrets_log_file = open(os.environ["SSLKEYLOGFILE"], "a")
            except OSError as e:
                _logger.warning(str(e))

        try:
            configuration = QuicConfiguration(
                alpn_protocols=H3_ALPN,
                is_client=False,
                max_datagram_frame_size=65536,
                secrets_log_file=secrets_log_file,
            )

            _logger.info("Starting WebTransport over HTTP/3 server on %s:%s",
                         self.host, self.port)

            configuration.load_cert_chain(self.cert_path, self.key_path)

            ticket_store = SessionTicketStore()

            # On Windows, the default event loop is ProactorEventLoop but it
            # doesn't seem to work when aioquic detects a connection loss.
            # Use SelectorEventLoop to work around the problem.
            if sys.platform == "win32":
                asyncio.set_event_loop_policy(
                    asyncio.WindowsSelectorEventLoopPolicy())
            self.loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self.loop)

            self.loop.run_until_complete(
                serve(
                    self.host,
                    self.port,
                    configuration=configuration,
                    create_protocol=WebTransportH3Protocol,
                    session_ticket_fetcher=ticket_store.pop,
                    session_ticket_handler=ticket_store.add,
                ))
            self.loop.run_forever()
        finally:
            # run_forever() returns once stop() halts the loop (or startup
            # failed); close the secrets log instead of leaking the handle.
            if secrets_log_file is not None:
                secrets_log_file.close()

    def stop(self) -> None:
        """Stop the server and wait for its thread to exit."""
        if self.started:
            asyncio.run_coroutine_threadsafe(self._stop_on_server_thread(),
                                             self.loop)
            self.server_thread.join()
            _logger.info("Stopped WebTransport over HTTP/3 server on %s:%s",
                         self.host, self.port)
            self.started = False

    async def _stop_on_server_thread(self) -> None:
        self.loop.stop()
| |
| |
def server_is_running(host: str, port: int, timeout: float) -> bool:
    """
    Check the WebTransport over HTTP/3 server is running at the given `host` and
    `port`.

    :param timeout: Seconds to wait for the connection attempt.
    :return: True if a QUIC connection could be made and pinged in time.
    """
    # Use a private event loop: asyncio.get_event_loop() is deprecated when no
    # current loop is set (and raises on newer Pythons), and closing our own
    # loop avoids leaking it. The caller's current loop is left untouched.
    loop = asyncio.new_event_loop()
    try:
        return loop.run_until_complete(
            _connect_server_with_timeout(host, port, timeout))
    finally:
        loop.close()
| |
| |
async def _connect_server_with_timeout(host: str, port: int, timeout: float) -> bool:
    """
    Try to connect to `host`:`port` within `timeout` seconds.

    :return: True on success; False on timeout or a connection-level error,
        so callers polling for readiness always get a bool.
    """
    try:
        await asyncio.wait_for(_connect_to_server(host, port), timeout=timeout)
    except asyncio.TimeoutError:
        _logger.warning("Failed to connect WebTransport over HTTP/3 server")
        return False
    except OSError as e:
        # Covers ConnectionError (an OSError subclass), such as a terminated
        # QUIC handshake, which previously escaped this bool-returning helper.
        _logger.warning(
            "Failed to connect WebTransport over HTTP/3 server: %s", e)
        return False
    return True
| |
| |
async def _connect_to_server(host: str, port: int) -> None:
    """
    Open a client QUIC connection to `host`:`port` (HTTP/3 ALPN, certificate
    verification disabled) and send a single PING over it.
    """
    client_configuration = QuicConfiguration(
        is_client=True,
        alpn_protocols=H3_ALPN,
        verify_mode=ssl.CERT_NONE,
    )
    async with connect(host, port,
                       configuration=client_configuration) as client_protocol:
        await client_protocol.ping()