| import argparse |
| import asyncio |
| import json |
| import logging |
| import os |
| import pickle |
| import ssl |
| import time |
| from collections import deque |
| from typing import Callable, Deque, Dict, List, Optional, Union, cast |
| from urllib.parse import urlparse |
| |
| import wsproto |
| import wsproto.events |
| |
| import aioquic |
| from aioquic.asyncio.client import connect |
| from aioquic.asyncio.protocol import QuicConnectionProtocol |
| from aioquic.h0.connection import H0_ALPN, H0Connection |
| from aioquic.h3.connection import H3_ALPN, H3Connection |
| from aioquic.h3.events import ( |
| DataReceived, |
| H3Event, |
| HeadersReceived, |
| PushPromiseReceived, |
| ) |
| from aioquic.quic.configuration import QuicConfiguration |
| from aioquic.quic.events import QuicEvent |
| from aioquic.quic.logger import QuicLogger |
| from aioquic.tls import SessionTicket |
| |
| try: |
| import uvloop |
| except ImportError: |
| uvloop = None |
| |
| logger = logging.getLogger("client") |
| |
| HttpConnection = Union[H0Connection, H3Connection] |
| |
| USER_AGENT = "aioquic/" + aioquic.__version__ |
| |
| |
| class URL: |
| def __init__(self, url: str) -> None: |
| parsed = urlparse(url) |
| |
| self.authority = parsed.netloc |
| self.full_path = parsed.path |
| if parsed.query: |
| self.full_path += "?" + parsed.query |
| self.scheme = parsed.scheme |
| |
| |
| class HttpRequest: |
| def __init__( |
| self, method: str, url: URL, content: bytes = b"", headers: Dict = {} |
| ) -> None: |
| self.content = content |
| self.headers = headers |
| self.method = method |
| self.url = url |
| |
| |
| class WebSocket: |
| def __init__( |
| self, http: HttpConnection, stream_id: int, transmit: Callable[[], None] |
| ) -> None: |
| self.http = http |
| self.queue: asyncio.Queue[str] = asyncio.Queue() |
| self.stream_id = stream_id |
| self.subprotocol: Optional[str] = None |
| self.transmit = transmit |
| self.websocket = wsproto.Connection(wsproto.ConnectionType.CLIENT) |
| |
| async def close(self, code=1000, reason="") -> None: |
| """ |
| Perform the closing handshake. |
| """ |
| data = self.websocket.send( |
| wsproto.events.CloseConnection(code=code, reason=reason) |
| ) |
| self.http.send_data(stream_id=self.stream_id, data=data, end_stream=True) |
| self.transmit() |
| |
| async def recv(self) -> str: |
| """ |
| Receive the next message. |
| """ |
| return await self.queue.get() |
| |
| async def send(self, message: str) -> None: |
| """ |
| Send a message. |
| """ |
| assert isinstance(message, str) |
| |
| data = self.websocket.send(wsproto.events.TextMessage(data=message)) |
| self.http.send_data(stream_id=self.stream_id, data=data, end_stream=False) |
| self.transmit() |
| |
| def http_event_received(self, event: H3Event) -> None: |
| if isinstance(event, HeadersReceived): |
| for header, value in event.headers: |
| if header == b"sec-websocket-protocol": |
| self.subprotocol = value.decode() |
| elif isinstance(event, DataReceived): |
| self.websocket.receive_data(event.data) |
| |
| for ws_event in self.websocket.events(): |
| self.websocket_event_received(ws_event) |
| |
| def websocket_event_received(self, event: wsproto.events.Event) -> None: |
| if isinstance(event, wsproto.events.TextMessage): |
| self.queue.put_nowait(event.data) |
| |
| |
| class HttpClient(QuicConnectionProtocol): |
| def __init__(self, *args, **kwargs) -> None: |
| super().__init__(*args, **kwargs) |
| |
| self.pushes: Dict[int, Deque[H3Event]] = {} |
| self._http: Optional[HttpConnection] = None |
| self._request_events: Dict[int, Deque[H3Event]] = {} |
| self._request_waiter: Dict[int, asyncio.Future[Deque[H3Event]]] = {} |
| self._websockets: Dict[int, WebSocket] = {} |
| |
| if self._quic.configuration.alpn_protocols[0].startswith("hq-"): |
| self._http = H0Connection(self._quic) |
| else: |
| self._http = H3Connection(self._quic) |
| |
| async def get(self, url: str, headers: Dict = {}) -> Deque[H3Event]: |
| """ |
| Perform a GET request. |
| """ |
| return await self._request( |
| HttpRequest(method="GET", url=URL(url), headers=headers) |
| ) |
| |
| async def post(self, url: str, data: bytes, headers: Dict = {}) -> Deque[H3Event]: |
| """ |
| Perform a POST request. |
| """ |
| return await self._request( |
| HttpRequest(method="POST", url=URL(url), content=data, headers=headers) |
| ) |
| |
| async def websocket(self, url: str, subprotocols: List[str] = []) -> WebSocket: |
| """ |
| Open a WebSocket. |
| """ |
| request = HttpRequest(method="CONNECT", url=URL(url)) |
| stream_id = self._quic.get_next_available_stream_id() |
| websocket = WebSocket( |
| http=self._http, stream_id=stream_id, transmit=self.transmit |
| ) |
| |
| self._websockets[stream_id] = websocket |
| |
| headers = [ |
| (b":method", b"CONNECT"), |
| (b":scheme", b"https"), |
| (b":authority", request.url.authority.encode()), |
| (b":path", request.url.full_path.encode()), |
| (b":protocol", b"websocket"), |
| (b"user-agent", USER_AGENT.encode()), |
| (b"sec-websocket-version", b"13"), |
| ] |
| if subprotocols: |
| headers.append( |
| (b"sec-websocket-protocol", ", ".join(subprotocols).encode()) |
| ) |
| self._http.send_headers(stream_id=stream_id, headers=headers) |
| |
| self.transmit() |
| |
| return websocket |
| |
| def http_event_received(self, event: H3Event) -> None: |
| if isinstance(event, (HeadersReceived, DataReceived)): |
| stream_id = event.stream_id |
| if stream_id in self._request_events: |
| # http |
| self._request_events[event.stream_id].append(event) |
| if event.stream_ended: |
| request_waiter = self._request_waiter.pop(stream_id) |
| request_waiter.set_result(self._request_events.pop(stream_id)) |
| |
| elif stream_id in self._websockets: |
| # websocket |
| websocket = self._websockets[stream_id] |
| websocket.http_event_received(event) |
| |
| elif event.push_id in self.pushes: |
| # push |
| self.pushes[event.push_id].append(event) |
| |
| elif isinstance(event, PushPromiseReceived): |
| self.pushes[event.push_id] = deque() |
| self.pushes[event.push_id].append(event) |
| |
| def quic_event_received(self, event: QuicEvent) -> None: |
| # Â pass event to the HTTP layer |
| if self._http is not None: |
| for http_event in self._http.handle_event(event): |
| self.http_event_received(http_event) |
| |
| async def _request(self, request: HttpRequest): |
| stream_id = self._quic.get_next_available_stream_id() |
| self._http.send_headers( |
| stream_id=stream_id, |
| headers=[ |
| (b":method", request.method.encode()), |
| (b":scheme", request.url.scheme.encode()), |
| (b":authority", request.url.authority.encode()), |
| (b":path", request.url.full_path.encode()), |
| (b"user-agent", USER_AGENT.encode()), |
| ] |
| + [(k.encode(), v.encode()) for (k, v) in request.headers.items()], |
| ) |
| self._http.send_data(stream_id=stream_id, data=request.content, end_stream=True) |
| |
| waiter = self._loop.create_future() |
| self._request_events[stream_id] = deque() |
| self._request_waiter[stream_id] = waiter |
| self.transmit() |
| |
| return await asyncio.shield(waiter) |
| |
| |
| async def perform_http_request( |
| client: HttpClient, url: str, data: str, include: bool, output_dir: Optional[str], |
| ) -> None: |
| # perform request |
| start = time.time() |
| if data is not None: |
| http_events = await client.post( |
| url, |
| data=data.encode(), |
| headers={"content-type": "application/x-www-form-urlencoded"}, |
| ) |
| else: |
| http_events = await client.get(url) |
| elapsed = time.time() - start |
| |
| # print speed |
| octets = 0 |
| for http_event in http_events: |
| if isinstance(http_event, DataReceived): |
| octets += len(http_event.data) |
| logger.info( |
| "Received %d bytes in %.1f s (%.3f Mbps)" |
| % (octets, elapsed, octets * 8 / elapsed / 1000000) |
| ) |
| |
| # output response |
| if output_dir is not None: |
| output_path = os.path.join( |
| output_dir, os.path.basename(urlparse(url).path) or "index.html" |
| ) |
| with open(output_path, "wb") as output_file: |
| for http_event in http_events: |
| if isinstance(http_event, HeadersReceived) and include: |
| headers = b"" |
| for k, v in http_event.headers: |
| headers += k + b": " + v + b"\r\n" |
| if headers: |
| output_file.write(headers + b"\r\n") |
| elif isinstance(http_event, DataReceived): |
| output_file.write(http_event.data) |
| |
| |
| def save_session_ticket(ticket: SessionTicket) -> None: |
| """ |
| Callback which is invoked by the TLS engine when a new session ticket |
| is received. |
| """ |
| logger.info("New session ticket received") |
| if args.session_ticket: |
| with open(args.session_ticket, "wb") as fp: |
| pickle.dump(ticket, fp) |
| |
| |
| async def run( |
| configuration: QuicConfiguration, |
| urls: List[str], |
| data: str, |
| include: bool, |
| output_dir: Optional[str], |
| local_port: int, |
| ) -> None: |
| # parse URL |
| parsed = urlparse(urls[0]) |
| assert parsed.scheme in ( |
| "https", |
| "wss", |
| ), "Only https:// or wss:// URLs are supported." |
| if ":" in parsed.netloc: |
| host, port_str = parsed.netloc.split(":") |
| port = int(port_str) |
| else: |
| host = parsed.netloc |
| port = 443 |
| |
| async with connect( |
| host, |
| port, |
| configuration=configuration, |
| create_protocol=HttpClient, |
| session_ticket_handler=save_session_ticket, |
| local_port=local_port, |
| ) as client: |
| client = cast(HttpClient, client) |
| |
| if parsed.scheme == "wss": |
| ws = await client.websocket(urls[0], subprotocols=["chat", "superchat"]) |
| |
| # send some messages and receive reply |
| for i in range(2): |
| message = "Hello {}, WebSocket!".format(i) |
| print("> " + message) |
| await ws.send(message) |
| |
| message = await ws.recv() |
| print("< " + message) |
| |
| await ws.close() |
| else: |
| # perform request |
| coros = [ |
| perform_http_request( |
| client=client, |
| url=url, |
| data=data, |
| include=include, |
| output_dir=output_dir, |
| ) |
| for url in urls |
| ] |
| await asyncio.gather(*coros) |
| |
| |
| if __name__ == "__main__": |
| defaults = QuicConfiguration(is_client=True) |
| |
| parser = argparse.ArgumentParser(description="HTTP/3 client") |
| parser.add_argument( |
| "url", type=str, nargs="+", help="the URL to query (must be HTTPS)" |
| ) |
| parser.add_argument( |
| "--ca-certs", type=str, help="load CA certificates from the specified file" |
| ) |
| parser.add_argument( |
| "-d", "--data", type=str, help="send the specified data in a POST request" |
| ) |
| parser.add_argument( |
| "-i", |
| "--include", |
| action="store_true", |
| help="include the HTTP response headers in the output", |
| ) |
| parser.add_argument( |
| "--max-data", |
| type=int, |
| help="connection-wide flow control limit (default: %d)" % defaults.max_data, |
| ) |
| parser.add_argument( |
| "--max-stream-data", |
| type=int, |
| help="per-stream flow control limit (default: %d)" % defaults.max_stream_data, |
| ) |
| parser.add_argument( |
| "-k", |
| "--insecure", |
| action="store_true", |
| help="do not validate server certificate", |
| ) |
| parser.add_argument("--legacy-http", action="store_true", help="use HTTP/0.9") |
| parser.add_argument( |
| "--output-dir", type=str, help="write downloaded files to this directory", |
| ) |
| parser.add_argument( |
| "-q", "--quic-log", type=str, help="log QUIC events to a file in QLOG format" |
| ) |
| parser.add_argument( |
| "-l", |
| "--secrets-log", |
| type=str, |
| help="log secrets to a file, for use with Wireshark", |
| ) |
| parser.add_argument( |
| "-s", |
| "--session-ticket", |
| type=str, |
| help="read and write session ticket from the specified file", |
| ) |
| parser.add_argument( |
| "-v", "--verbose", action="store_true", help="increase logging verbosity" |
| ) |
| parser.add_argument( |
| "--local-port", type=int, default=0, help="local port to bind for connections", |
| ) |
| |
| args = parser.parse_args() |
| |
| logging.basicConfig( |
| format="%(asctime)s %(levelname)s %(name)s %(message)s", |
| level=logging.DEBUG if args.verbose else logging.INFO, |
| ) |
| |
| if args.output_dir is not None and not os.path.isdir(args.output_dir): |
| raise Exception("%s is not a directory" % args.output_dir) |
| |
| # prepare configuration |
| configuration = QuicConfiguration( |
| is_client=True, alpn_protocols=H0_ALPN if args.legacy_http else H3_ALPN |
| ) |
| if args.ca_certs: |
| configuration.load_verify_locations(args.ca_certs) |
| if args.insecure: |
| configuration.verify_mode = ssl.CERT_NONE |
| if args.max_data: |
| configuration.max_data = args.max_data |
| if args.max_stream_data: |
| configuration.max_stream_data = args.max_stream_data |
| if args.quic_log: |
| configuration.quic_logger = QuicLogger() |
| if args.secrets_log: |
| configuration.secrets_log_file = open(args.secrets_log, "a") |
| if args.session_ticket: |
| try: |
| with open(args.session_ticket, "rb") as fp: |
| configuration.session_ticket = pickle.load(fp) |
| except FileNotFoundError: |
| pass |
| |
| if uvloop is not None: |
| uvloop.install() |
| loop = asyncio.get_event_loop() |
| try: |
| loop.run_until_complete( |
| run( |
| configuration=configuration, |
| urls=args.url, |
| data=args.data, |
| include=args.include, |
| output_dir=args.output_dir, |
| local_port=args.local_port, |
| ) |
| ) |
| finally: |
| if configuration.quic_logger is not None: |
| with open(args.quic_log, "w") as logger_fp: |
| json.dump(configuration.quic_logger.to_dict(), logger_fp, indent=4) |