blob: 4309a25e387316f80f22cef4b60bb2bd45994454 [file] [log] [blame]
import asyncio
import functools
import json
import logging
import sys
from collections import defaultdict
from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Mapping, MutableMapping
from urllib.parse import urljoin, urlparse
import websockets
# Logger shared by this module; Transport logs every outgoing ("→") and
# incoming ("←") WebSocket message at DEBUG level.
logger = logging.getLogger("webdriver.bidi")
class BidiException(Exception):
    """Error reported by the remote end in response to a BiDi command.

    :param err: Error code string (the ``error`` field of the response).
    :param msg: Human-readable error message (the ``message`` field).
    :param stack: Optional remote-end stacktrace (the ``stacktrace`` field).
    """

    def __init__(self, err: str, msg: str, stack: Optional[str] = None):
        # Forward the fields to Exception so ``args`` is always the full
        # (err, msg, stack) triple, even when ``stack`` is passed by keyword;
        # this keeps copying/pickling of the exception lossless.
        super().__init__(err, msg, stack)
        self.err = err
        self.msg = msg
        self.stack = stack

    def __str__(self) -> str:
        """Return ``"<err> (<msg>)"``, followed by the remote stacktrace if any."""
        message = f"{self.err} ({self.msg})"
        if self.stack is not None:
            message += f"\n\nRemote-end stacktrace:\n\n{self.stack}"
        return message
def get_running_loop() -> asyncio.AbstractEventLoop:
    """Return the event loop that BiDi I/O should be scheduled on.

    On Python 3.7+ this is the currently running loop. Older interpreters
    lack ``asyncio.get_running_loop``, so fall back to
    ``asyncio.get_event_loop``, which (unlike the 3.7+ call) will create a
    loop if there isn't one; running the tests on 3.7+ is expected to expose
    any resulting behaviour difference.
    """
    has_running_loop_api = sys.version_info >= (3, 7)
    if not has_running_loop_api:
        return asyncio.get_event_loop()
    return asyncio.get_running_loop()
class BidiSession:
    """A WebDriver BiDi session.

    This is the main representation of a BiDi session and provides the
    interface for running commands in the session, and for attaching
    event handlers to the session. For example:

        async def on_log(method, params):
            print(params)

        session = BidiSession.bidi_only("ws://localhost:4445", capabilities)
        session.add_event_listener("log.entryAdded", on_log)
        await session.start()
        await session.session.subscribe(events=["log.entryAdded"])

        # Do some stuff with the session

        await session.end()

    If the session id is provided it's assumed that the underlying
    WebDriver session was already created, and the WebSocket URL was
    taken from the new session response. If no session id is provided, it's
    assumed that a BiDi-only session should be created when start() is called.

    It can also be used as an asynchronous context manager (``async with``),
    with the WebSocket transport implicitly being created when the context is
    entered, and closed when the context is exited.

    :param websocket_url: WebSocket URL on which to connect to the session.
        Either just the server URL (e.g. ``ws://localhost:4445``), in which
        case ``/session`` or ``/session/<session_id>`` is appended, or the
        full URL already including that path.
    :param session_id: String id of an existing HTTP session, or None to
        create a BiDi-only session in start().
    :param capabilities: Capabilities response of the existing session; must
        be None when session_id is None.
    :param requested_capabilities: Dictionary representing the capabilities
        request, used when creating a BiDi-only session.
    :param loop: Event loop to use; defaults to the running loop.
    :raises ValueError: if the URL path doesn't match the session id, or if
        capabilities are given without a session id.
    """

    def __init__(self,
                 websocket_url: str,
                 session_id: Optional[str] = None,
                 capabilities: Optional[Mapping[str, Any]] = None,
                 requested_capabilities: Optional[Mapping[str, Any]] = None,
                 loop: Optional[asyncio.AbstractEventLoop] = None):
        self.transport: Optional[Transport] = None

        websocket_url = self._build_websocket_url(websocket_url, session_id)

        if session_id is None and capabilities is not None:
            raise ValueError("Tried to create BiDi-only session with existing capabilities")

        self.websocket_url = websocket_url
        self.requested_capabilities = requested_capabilities
        self.capabilities = capabilities
        self.session_id = session_id

        # Id of the most recently sent command; incremented per command.
        self.command_id = 0
        # Futures for commands that were sent but not yet answered, by id.
        self.pending_commands: MutableMapping[int, "asyncio.Future[Any]"] = {}
        # Event name (or None for the default handlers) -> listeners.
        self.event_listeners: MutableMapping[Optional[str], List[Callable[[str, Mapping[str, Any]], Any]]] = defaultdict(list)

        # Modules.
        # For each module, have a property representing that module
        self.session = Session(self)

        if loop is None:
            loop = get_running_loop()
        self.loop = loop

    @staticmethod
    def _build_websocket_url(websocket_url: str, session_id: Optional[str]) -> str:
        """Return the full WebSocket URL for the session.

        The full URL looks like ``ws://<host>:<port>/session`` when creating a
        session and ``ws://<host>:<port>/session/<session_id>`` when
        connecting to an existing one. To be user friendly, accept either that
        full URL or just the server URL.

        :raises ValueError: if the URL already has a path that isn't the
            expected one.
        """
        expected_path = "/session" if session_id is None else f"/session/{session_id}"
        path = urlparse(websocket_url).path
        if path in ("", "/"):
            # Only the server URL was given; replace the (empty) path.
            return urljoin(websocket_url, expected_path)
        if path != expected_path:
            if session_id is not None:
                raise ValueError(f"WebSocket URL {websocket_url} doesn't match session id")
            raise ValueError(f"WebSocket URL {websocket_url} doesn't match session url")
        return websocket_url

    @classmethod
    def from_http(cls,
                  session_id: str,
                  capabilities: Mapping[str, Any],
                  loop: Optional[asyncio.AbstractEventLoop] = None) -> "BidiSession":
        """Create a BiDi session from an existing HTTP session

        :param session_id: String id of the session
        :param capabilities: Capabilities returned in the New Session HTTP response.
        :param loop: Event loop to use; defaults to the running loop.
        :raises ValueError: if the capabilities lack a string ``webSocketUrl``.
        """
        websocket_url = capabilities.get("webSocketUrl")
        if websocket_url is None:
            raise ValueError("No webSocketUrl found in capabilities")
        if not isinstance(websocket_url, str):
            raise ValueError("webSocketUrl is not a string")
        return cls(websocket_url, session_id=session_id, capabilities=capabilities, loop=loop)

    @classmethod
    def bidi_only(cls,
                  websocket_url: str,
                  requested_capabilities: Optional[Mapping[str, Any]],
                  loop: Optional[asyncio.AbstractEventLoop] = None) -> "BidiSession":
        """Create a BiDi session where there is no existing HTTP session

        :param websocket_url: URL to the WebSocket server listening for BiDi connections
        :param requested_capabilities: Capabilities request for establishing the session.
        :param loop: Event loop to use; defaults to the running loop.
        """
        return cls(websocket_url, requested_capabilities=requested_capabilities, loop=loop)

    async def __aenter__(self) -> "BidiSession":
        await self.start()
        return self

    async def __aexit__(self, *args: Any) -> None:
        await self.end()

    async def start(self) -> None:
        """Connect to the WebDriver BiDi remote via WebSockets.

        For a BiDi-only session (no session id yet) this then sends
        session.new and stores the returned session id and capabilities.
        """
        self.transport = Transport(self.websocket_url, self.on_message, loop=self.loop)
        # Connect before sending session.new: until the transport is started,
        # outgoing messages are only buffered, so awaiting the response first
        # would never complete.
        await self.transport.start()

        if self.session_id is None:
            # Command wrappers take the parameters by keyword.
            self.session_id, self.capabilities = await self.session.new(
                capabilities=self.requested_capabilities)

    async def send_command(self, method: str, params: Mapping[str, Any]) -> Awaitable[Mapping[str, Any]]:
        """Send a command to the remote server.

        :param method: Full BiDi command name, e.g. ``session.subscribe``.
        :param params: Command parameters.
        :returns: A future resolved by on_message with the response's
            ``result``, or failed with BidiException for an error response.
        """
        # this isn't threadsafe
        self.command_id += 1
        command_id = self.command_id
        body = {
            "id": command_id,
            "method": method,
            "params": params
        }

        assert self.transport is not None
        assert command_id not in self.pending_commands
        future = self.loop.create_future()
        self.pending_commands[command_id] = future
        try:
            await self.transport.send(body)
        except BaseException:
            # A command that was never sent can never be answered; don't
            # leave its future behind in the pending table.
            self.pending_commands.pop(command_id, None)
            raise
        return future

    async def on_message(self, data: Mapping[str, Any]) -> None:
        """Handle a message from the remote server.

        Messages with an ``id`` resolve the matching pending command future.
        Events (``method`` and ``params``) are passed to the listeners
        registered for that method or, if there are none, to the default
        (``None``) listeners; listeners are awaited one after another.

        :raises ValueError: for a response to an unknown command id, or a
            message that is neither a command response nor an event.
        """
        if "id" in data:
            # This is a command response or error. Pop it so completed
            # commands don't accumulate in the pending table.
            future = self.pending_commands.pop(data["id"], None)
            if future is None:
                raise ValueError(f"No pending command with id {data['id']}")
            if future.done():
                # The caller stopped waiting (e.g. its task was cancelled);
                # resolving the future now would raise InvalidStateError and
                # stop the message reader.
                return
            if "result" in data:
                future.set_result(data["result"])
            elif "error" in data and "message" in data:
                assert isinstance(data["error"], str)
                assert isinstance(data["message"], str)
                future.set_exception(BidiException(data["error"],
                                                   data["message"],
                                                   data.get("stacktrace")))
            else:
                error = ValueError(f"Unexpected message: {data!r}")
                # Fail the command as well, so its caller doesn't wait forever.
                future.set_exception(error)
                raise error
        elif "method" in data and "params" in data:
            # This is an event
            method = data["method"]
            listeners = self.event_listeners.get(method) or self.event_listeners.get(None, [])
            # Iterate over a copy in case a listener registers another one.
            for listener in list(listeners):
                await listener(method, data["params"])
        else:
            raise ValueError(f"Unexpected message: {data!r}")

    async def end(self) -> None:
        """Close websocket connection."""
        assert self.transport is not None
        await self.transport.end()

    def add_event_listener(self,
                           name: Optional[str],
                           fn: Callable[[str, Mapping[str, Any]], Awaitable[Any]]) -> None:
        """Add a listener for the event with a given name.

        If name is None, the listener is called for all events that have no
        listener registered for their own name.

        :param name: Name of event to listen for or None to register a default handler
        :param fn: Async callback function that receives the event name and
            the event's ``params``
        """
        self.event_listeners[name].append(fn)
class Transport:
    """Low level message handler for the WebSockets connection.

    Outgoing messages are JSON-encoded. Messages sent before start() has
    opened the connection are buffered and flushed once it is open. Incoming
    text messages are JSON-decoded and passed to ``msg_handler`` one at a
    time, in the order they are read.

    :param url: WebSocket URL to connect to.
    :param msg_handler: Coroutine function awaited with each decoded message.
    :param loop: Event loop on which the reader task is created; defaults to
        the running loop.
    """

    def __init__(self, url: str,
                 msg_handler: Callable[[Mapping[str, Any]], Coroutine[Any, Any, None]],
                 loop: Optional[asyncio.AbstractEventLoop] = None):
        self.url = url
        self.connection: Optional[websockets.WebSocketClientProtocol] = None
        self.msg_handler = msg_handler
        # Messages queued by send() while there is no open connection.
        self.send_buf: List[Mapping[str, Any]] = []
        if loop is None:
            loop = get_running_loop()
        self.loop = loop
        self.read_message_task: Optional[asyncio.Task[Any]] = None

    async def start(self) -> None:
        """Open the connection, start reading messages and flush the send buffer."""
        self.connection = await websockets.client.connect(self.url)
        self.read_message_task = self.loop.create_task(self.read_messages())
        # Take the buffered messages out of send_buf before sending them, so
        # each one is sent exactly once and isn't retained (or re-sent by a
        # later start()) afterwards.
        pending, self.send_buf = self.send_buf, []
        for msg in pending:
            await self._send(self.connection, msg)

    async def send(self, data: Mapping[str, Any]) -> None:
        """Send ``data`` now if connected, otherwise buffer it for start()."""
        if self.connection is not None:
            await self._send(self.connection, data)
        else:
            self.send_buf.append(data)

    @staticmethod
    async def _send(connection: websockets.WebSocketClientProtocol, data: Mapping[str, Any]) -> None:
        """JSON-encode ``data``, log it at DEBUG level and send it."""
        msg = json.dumps(data)
        logger.debug("→ %s", msg)
        await connection.send(msg)

    async def handle(self, msg: str) -> None:
        """Log and JSON-decode one incoming message, then await the handler."""
        logger.debug("← %s", msg)
        data = json.loads(msg)
        await self.msg_handler(data)

    async def end(self) -> None:
        """Close the connection if one is open.

        Afterwards send() buffers messages again, as it does before start().
        """
        if self.connection:
            await self.connection.close()
            self.connection = None

    async def read_messages(self) -> None:
        """Handle incoming messages until the connection's iterator ends.

        :raises ValueError: if a binary (non-text) message is received.
        """
        assert self.connection is not None
        async for msg in self.connection:
            if not isinstance(msg, str):
                raise ValueError("Got a binary message")
            await self.handle(msg)
class command:
    """Decorator for implementing bidi commands

    Implementing a command involves specifying a plain (synchronous) function
    that builds the parameters to the command. The decorator arranges those
    parameters to be turned into a send_command call, using the class
    and method names to determine the method in the call.

    The decorated method becomes a coroutine function. It doesn't return a
    future, but awaits the actual response and returns it. In some cases it
    can be useful to post-process this response before returning it to the
    client. This can be done by specifying a second decorated method like
    @command_name.result. That method will then be called once the
    result of the original command is known, and the return value of
    the method used as the response of the command.

    So for an example, if we had a command test.testMethod, which
    returned a result which we want to convert to a TestData type,
    the implementation might look like:

        class Test(BidiModule):
            @command
            def test_method(self, test_data=None):
                return {"testData": test_data}

            @test_method.result
            def convert_test_method_result(self, result):
                return TestData(**result)

    and it would be called as ``await module.test_method(test_data=...)``.
    """

    def __init__(self, fn: Callable[..., Mapping[str, Any]]):
        self.params_fn = fn
        self.result_fn: Optional[Callable[..., Any]] = None

    def result(self, fn: Callable[[Any, Any], Any]) -> Callable[[Any, Any], Any]:
        """Register ``fn(module, result)`` to post-process the command's result.

        Returns ``fn`` unchanged so the decorated name stays usable.
        """
        self.result_fn = fn
        return fn

    def __set_name__(self, owner: Any, name: str) -> None:
        # This is called when the class is created, i.e. after the whole class
        # body (including any @<name>.result registration) has executed.
        # see https://docs.python.org/3/reference/datamodel.html#object.__set_name__
        params_fn = self.params_fn
        result_fn = self.result_fn

        # Convert the classname and the method name to a bidi command name,
        # e.g. Session.subscribe -> "session.subscribe"; with a vendor prefix,
        # Test.test_method -> "moz:test.testMethod". This only depends on the
        # owner class and the attribute name, so compute it once here.
        mod_name = owner.__name__.lower()
        if hasattr(owner, "prefix"):
            mod_name = f"{owner.prefix}:{mod_name}"
        cmd_name = f"{mod_name}.{to_camelcase(name)}"

        @functools.wraps(params_fn)
        async def inner(self: Any, *args: Any, **kwargs: Any) -> Any:
            params = params_fn(self, *args, **kwargs)
            future = await self.session.send_command(cmd_name, params)
            result = await future
            if result_fn is not None:
                # Convert the result if we have a conversion function defined
                result = result_fn(self, result)
            return result

        # Overwrite the method on the owner class with the wrapper
        setattr(owner, name, inner)

    def __call__(self, *args: Any, **kwargs: Any) -> Awaitable[Any]:
        # This isn't really used (__set_name__ replaces the attribute with
        # ``inner``), but mypy doesn't understand __set_name__
        pass
def to_camelcase(name: str) -> str:
    """Convert a python style method name foo_bar to a BiDi command name fooBar.

    The first word is lower-cased. Each following word gets an upper-case
    first character and a lower-cased remainder. Unlike ``str.title`` this
    does not upper-case letters that follow digits, so ``foo_v2beta``
    becomes ``fooV2beta`` rather than ``fooV2Beta``.
    """
    # str.split always returns at least one element, so this unpacking is safe.
    first, *rest = name.split("_")
    return first.lower() + "".join(word[:1].upper() + word[1:].lower() for word in rest)
class BidiModule:
    """Base class for a BiDi protocol module.

    Subclasses group the commands of one module (see ``command``). Each
    instance keeps a reference to the owning session, whose send_command is
    used by those commands.
    """

    def __init__(self, session: "BidiSession") -> None:
        self.session: "BidiSession" = session
class Session(BidiModule):
    """Commands of the BiDi ``session`` module."""

    @command
    def new(self, capabilities: Optional[Mapping[str, Any]]) -> Mapping[str, Mapping[str, Any]]:
        """Build the parameters for session.new.

        :param capabilities: Capabilities request. None is sent as an empty
            map, since the WebDriver BiDi spec defines this field as a map
            rather than allowing null.
        """
        return {"capabilities": capabilities if capabilities is not None else {}}

    @new.result
    def _new(self, result: Mapping[str, Any]) -> Any:
        """Return ``(session_id, capabilities)`` from a session.new result."""
        # The BiDi result names this field "sessionId" (camelCase, like the
        # command names); keep accepting the "session_id" key this code
        # previously read.
        session_id = result.get("sessionId", result.get("session_id"))
        return session_id, result.get("capabilities", {})

    @staticmethod
    def _subscription_params(events: Optional[List[str]],
                             contexts: Optional[List[str]]) -> Mapping[str, Any]:
        """Build the parameters shared by session.subscribe and session.unsubscribe."""
        params: MutableMapping[str, Any] = {"events": events if events is not None else []}
        if contexts is not None:
            params["contexts"] = contexts
        return params

    @command
    def subscribe(self,
                  events: Optional[List[str]] = None,
                  contexts: Optional[List[str]] = None) -> Mapping[str, Any]:
        """Build the parameters for session.subscribe.

        :param events: Event names to subscribe to (None sends an empty list).
        :param contexts: Optional contexts to restrict the subscription to.
        """
        return self._subscription_params(events, contexts)

    @command
    def unsubscribe(self,
                    events: Optional[List[str]] = None,
                    contexts: Optional[List[str]] = None) -> Mapping[str, Any]:
        """Build the parameters for session.unsubscribe.

        :param events: Event names to unsubscribe from (None sends an empty list).
        :param contexts: Optional contexts to restrict the unsubscription to.
        """
        return self._subscription_params(events, contexts)
class Test(BidiModule):
    """Very temporary module that does nothing, except demonstrate a vendor
    prefix and provide a way to work with Gecko's current skeleton
    implementation.

    Its single command is sent as ``moz:test.testMethod``.
    """

    # Vendor prefix that ``command`` prepends to the module name.
    prefix = "moz"

    @command
    def test_method(self, **kwargs: Any) -> Mapping[str, Any]:
        """Use every keyword argument, unchanged, as a command parameter."""
        params = dict(kwargs)
        return params