Start creating a proper client for WebDriver BiDi (#28381)

Add a webdriver.bidi.client module with a BidiSession class that
implements only the BiDi part of the protocol. An ordinary Session may
have an associated BidiSession if the session was created through
HTTP. To make this easier, an enable_bidi flag is added to the
top-level session which automatically sets the capabilities required
to enable BiDi.

The session allows sending commands, and registering handlers for
events.

In addition to the session class, this adds some initial work on
creating a structure for commands, with a BidiModule abstract class
that can be used as the basis for command implementations and an
@command decorator that handles actually sending the command and
waiting for the result.

Co-authored-by: Maksim Sadym <69349599+sadym-chromium@users.noreply.github.com>
diff --git a/tools/webdriver/webdriver/bidi/__init__.py b/tools/webdriver/webdriver/bidi/__init__.py
new file mode 100644
index 0000000..e7c5633
--- /dev/null
+++ b/tools/webdriver/webdriver/bidi/__init__.py
@@ -0,0 +1,3 @@
+# flake8: noqa
+
+from .client import BidiSession
diff --git a/tools/webdriver/webdriver/bidi/client.py b/tools/webdriver/webdriver/bidi/client.py
new file mode 100644
index 0000000..4309a25
--- /dev/null
+++ b/tools/webdriver/webdriver/bidi/client.py
@@ -0,0 +1,393 @@
+import asyncio
+import functools
+import json
+import logging
+import sys
+from collections import defaultdict
+from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Mapping, MutableMapping
+from urllib.parse import urljoin, urlparse
+
+import websockets
+
+
+logger = logging.getLogger("webdriver.bidi")
+
+
+class BidiException(Exception):
+    def __init__(self, err: str, msg: str, stack: Optional[str] = None):
+        self.err = err
+        self.msg = msg
+        self.stack = stack
+
+
+def get_running_loop() -> asyncio.AbstractEventLoop:
+    if sys.version_info >= (3, 7):
+        return asyncio.get_running_loop()
+    # Unlike the above, this will actually create an event loop
+    # if there isn't one; hopefully running tests in Python >= 3.7
+    # will allow us to catch any behaviour difference
+    return asyncio.get_event_loop()
+
+
+class BidiSession:
+    """A WebDriver BiDi session.
+
+    This is the main representation of a BiDi session and provides the
+    interface for running commands in the session, and for attaching
+    event handlers to the session. For example:
+
+    async def on_log(method, data):
+        print(data)
+
+    session = BidiSession.bidi_only("ws://localhost:4445", capabilities)
+    session.add_event_listener("log.entryAdded", on_log)
+    await session.start()
+    await session.session.subscribe(events=["log.entryAdded"])
+    # Do some stuff with the session
+    await session.end()
+
+    If the session id is provided it's assumed that the underlying
+    WebDriver session was already created, and the WebSocket URL was
+    taken from the new session response. If no session id is provided, it's
+    assumed that a BiDi-only session should be created when start() is called.
+
+    It can also be used as a context manager, with the WebSocket transport
+    implicitly being created when the context is entered, and closed when
+    the context is exited.
+
+    :param websocket_url: WebSockets URL on which to connect to the session.
+                          This excludes any path component.
+    :param session_id: String id of existing HTTP session
+    :param capabilities: Capabilities response of existing session
+    :param requested_capabilities: Dictionary representing the capabilities request.
+
+    """
+
+    def __init__(self,
+                 websocket_url: str,
+                 session_id: Optional[str] = None,
+                 capabilities: Optional[Mapping[str, Any]] = None,
+                 requested_capabilities: Optional[Mapping[str, Any]] = None,
+                 loop: Optional[asyncio.AbstractEventLoop] = None):
+        self.transport: Optional[Transport] = None
+
+        # The full URL for a websocket looks like
+        # ws://<host>:<port>/session when we're creating a session and
+        # ws://<host>:<port>/session/<sessionid> when we're connecting to an existing session.
+        # To be user friendly, accept either a full URL including the path (which must
+        # match the session id, if any) or just a server URL with an empty or "/" path.
+        parsed_url = urlparse(websocket_url)
+        if parsed_url.path == "" or parsed_url.path == "/":
+            if session_id is None:
+                websocket_url = urljoin(websocket_url, "session")
+            else:
+                websocket_url = urljoin(websocket_url, f"session/{session_id}")
+        else:
+            if session_id is not None:
+                if parsed_url.path != f"/session/{session_id}":
+                    raise ValueError(f"WebSocket URL {websocket_url} doesn't match session id {session_id}")
+            else:
+                if parsed_url.path != "/session":
+                    raise ValueError(f"WebSocket URL {websocket_url} doesn't match session url")
+
+        if session_id is None and capabilities is not None:
+            raise ValueError("Tried to create BiDi-only session with existing capabilities")
+
+        self.websocket_url = websocket_url
+        self.requested_capabilities = requested_capabilities
+        self.capabilities = capabilities
+        self.session_id = session_id
+
+        self.command_id = 0
+        self.pending_commands: MutableMapping[int, "asyncio.Future[Any]"] = {}
+        self.event_listeners: MutableMapping[Optional[str], List[Callable[[str, Mapping[str, Any]], Any]]] = defaultdict(list)
+
+        # Modules.
+        # For each module, have a property representing that module
+        self.session = Session(self)
+
+        if loop is None:
+            loop = get_running_loop()
+        self.loop = loop
+
+    @classmethod
+    def from_http(cls,
+                  session_id: str,
+                  capabilities: Mapping[str, Any],
+                  loop: Optional[asyncio.AbstractEventLoop] = None) -> "BidiSession":
+        """Create a BiDi session from an existing HTTP session
+
+        :param session_id: String id of the session
+        :param capabilities: Capabilities returned in the New Session HTTP response."""
+        websocket_url = capabilities.get("webSocketUrl")
+        if websocket_url is None:
+            raise ValueError("No webSocketUrl found in capabilities")
+        if not isinstance(websocket_url, str):
+            raise ValueError("webSocketUrl is not a string")
+        return cls(websocket_url, session_id=session_id, capabilities=capabilities, loop=loop)
+
+    @classmethod
+    def bidi_only(cls,
+                  websocket_url: str,
+                  requested_capabilities: Optional[Mapping[str, Any]],
+                  loop: Optional[asyncio.AbstractEventLoop] = None) -> "BidiSession":
+        """Create a BiDi session where there is no existing HTTP session
+
+        :param websocket_url: URL to the WebSocket server listening for BiDi connections
+        :param requested_capabilities: Capabilities request for establishing the session."""
+        return cls(websocket_url, requested_capabilities=requested_capabilities, loop=loop)
+
+    async def __aenter__(self) -> "BidiSession":
+        await self.start()
+        return self
+
+    async def __aexit__(self, *args: Any) -> None:
+        await self.end()
+
+    async def start(self) -> None:
+        """Connect to the WebDriver BiDi remote via WebSockets"""
+        self.transport = Transport(self.websocket_url, self.on_message, loop=self.loop)
+        # Connect before session.new: its reply can only arrive once the socket is being read.
+        await self.transport.start()
+
+        if self.session_id is None:
+            self.session_id, self.capabilities = await self.session.new(capabilities=self.requested_capabilities)
+
+    async def send_command(self, method: str, params: Mapping[str, Any]) -> Awaitable[Mapping[str, Any]]:
+        """Send a command to the remote server"""
+        # this isn't threadsafe
+        self.command_id += 1
+        command_id = self.command_id
+
+        body = {
+            "id": command_id,
+            "method": method,
+            "params": params
+        }
+        assert command_id not in self.pending_commands
+        self.pending_commands[command_id] = self.loop.create_future()
+        assert self.transport is not None
+        await self.transport.send(body)
+
+        return self.pending_commands[command_id]
+
+    async def on_message(self, data: Mapping[str, Any]) -> None:
+        """Handle a message from the remote server"""
+        if "id" in data:
+            # This is a command response or error
+            future = self.pending_commands.get(data["id"])
+            if future is None:
+                raise ValueError(f"No pending command with id {data['id']}")
+            if "result" in data:
+                future.set_result(data["result"])
+            elif "error" in data and "message" in data:
+                assert isinstance(data["error"], str)
+                assert isinstance(data["message"], str)
+                future.set_exception(BidiException(data["error"],
+                                                   data["message"],
+                                                   data.get("stacktrace")))
+            else:
+                raise ValueError(f"Unexpected message: {data!r}")
+        elif "method" in data and "params" in data:
+            # This is an event
+            method = data["method"]
+            listeners = self.event_listeners.get(method, [])
+            if not listeners:
+                listeners = self.event_listeners.get(None, [])
+            for listener in listeners:
+                await listener(method, data["params"])
+        else:
+            raise ValueError(f"Unexpected message: {data!r}")
+
+    async def end(self) -> None:
+        """Close websocket connection."""
+        assert self.transport is not None
+        await self.transport.end()
+
+    def add_event_listener(self,
+                           name: Optional[str],
+                           fn: Callable[[str, Mapping[str, Any]], Awaitable[Any]]) -> None:
+        """Add a listener for the event with a given name.
+
+        If name is None, the listener is called for all messages that are not otherwise
+        handled.
+
+        :param name: Name of event to listen for or None to register a default handler
+        :param fn: Async callback function that receives event data
+        """
+        self.event_listeners[name].append(fn)
+
+
+class Transport:
+    """Low level message handler for the WebSockets connection"""
+    def __init__(self, url: str,
+                 msg_handler: Callable[[Mapping[str, Any]], Coroutine[Any, Any, None]],
+                 loop: Optional[asyncio.AbstractEventLoop] = None):
+        self.url = url
+        self.connection: Optional[websockets.WebSocketClientProtocol] = None
+        self.msg_handler = msg_handler
+        self.send_buf: List[Mapping[str, Any]] = []
+
+        if loop is None:
+            loop = get_running_loop()
+        self.loop = loop
+
+        self.read_message_task: Optional[asyncio.Task[Any]] = None
+
+    async def start(self) -> None:
+        self.connection = await websockets.client.connect(self.url)
+        self.read_message_task = self.loop.create_task(self.read_messages())
+
+        for msg in self.send_buf:
+            await self._send(self.connection, msg)
+
+    async def send(self, data: Mapping[str, Any]) -> None:
+        if self.connection is not None:
+            await self._send(self.connection, data)
+        else:
+            self.send_buf.append(data)
+
+    @staticmethod
+    async def _send(connection: websockets.WebSocketClientProtocol, data: Mapping[str, Any]) -> None:
+        msg = json.dumps(data)
+        logger.debug("→ %s", msg)
+        await connection.send(msg)
+
+    async def handle(self, msg: str) -> None:
+        logger.debug("← %s", msg)
+        data = json.loads(msg)
+        await self.msg_handler(data)
+
+    async def end(self) -> None:
+        if self.connection:
+            await self.connection.close()
+            self.connection = None
+
+    async def read_messages(self) -> None:
+        assert self.connection is not None
+        async for msg in self.connection:
+            if not isinstance(msg, str):
+                raise ValueError("Got a binary message")
+            await self.handle(msg)
+
+
+class command:
+    """Decorator for implementing bidi commands
+
+    Implementing a command involves specifying an async function that
+    builds the parameters to the command. The decorator arranges those
+    parameters to be turned into a send_command call, using the class
+    and method names to determine the method in the call.
+
+    Commands decorated in this way don't return a future, but await
+    the actual response. In some cases it can be useful to
+    post-process this response before returning it to the client. This
+    can be done by specifying a second decorated method like
+    @command_name.result. That method will then be called once the
+    result of the original command is known, and the return value of
+    the method used as the response of the command.
+
+    So for an example, if we had a command test.testMethod, which
+    returned a result which we want to convert to a TestResult type,
+    the implementation might look like:
+
+    class Test(BidiModule):
+        @command
+        def test_method(self, test_data=None):
+            return {"testData": test_data}
+
+       @test_method.result
+       def convert_test_method_result(self, result):
+           return TestData(**result)
+    """
+
+    def __init__(self, fn: Callable[..., Mapping[str, Any]]):
+        self.params_fn = fn
+        self.result_fn: Optional[Callable[..., Any]] = None
+
+    def result(self, fn: Callable[[Any, MutableMapping[str, Any]], Mapping[str, Any]]) -> None:
+        self.result_fn = fn
+
+    def __set_name__(self, owner: Any, name: str) -> None:
+        # This is called when the class is created
+        # see https://docs.python.org/3/reference/datamodel.html#object.__set_name__
+        params_fn = self.params_fn
+        result_fn = self.result_fn
+
+        @functools.wraps(params_fn)
+        async def inner(self: Any, **kwargs: Any) -> Any:
+            params = params_fn(self, **kwargs)
+
+            # Convert the classname and the method name to a bidi command name
+            mod_name = owner.__name__.lower()
+            if hasattr(owner, "prefix"):
+                mod_name = f"{owner.prefix}:{mod_name}"
+            cmd_name = f"{mod_name}.{to_camelcase(name)}"
+
+            future = await self.session.send_command(cmd_name, params)
+            result = await future
+
+            if result_fn is not None:
+                # Convert the result if we have a conversion function defined
+                result = result_fn(self, result)
+            return result
+
+        # Overwrite the method on the owner class with the wrapper
+        setattr(owner, name, inner)
+
+    def __call__(*args: Any, **kwargs: Any) -> Awaitable[Any]:
+        # This isn't really used, but mypy doesn't understand __set_name__
+        pass
+
+
+def to_camelcase(name: str) -> str:
+    """Convert a python style method name foo_bar to a BiDi command name fooBar"""
+    parts = name.split("_")
+    parts[0] = parts[0].lower()
+    for i in range(1, len(parts)):
+        parts[i] = parts[i].title()
+    return "".join(parts)
+
+
+class BidiModule:
+    def __init__(self, session: BidiSession):
+        self.session = session
+
+
+class Session(BidiModule):
+    @command
+    def new(self, capabilities: Mapping[str, Any]) -> Mapping[str, Mapping[str, Any]]:
+        return {"capabilities": capabilities}
+
+    @new.result
+    def _new(self, result: Mapping[str, Any]) -> Any:
+        return result.get("sessionId", result.get("session_id")), result.get("capabilities", {})  # spec key first
+
+    @command
+    def subscribe(self,
+                  events: Optional[List[str]] = None,
+                  contexts: Optional[List[str]] = None) -> Mapping[str, Any]:
+        params: MutableMapping[str, Any] = {"events": events if events is not None else []}
+        if contexts is not None:
+            params["contexts"] = contexts
+        return params
+
+    @command
+    def unsubscribe(self,
+                    events: Optional[List[str]] = None,
+                    contexts: Optional[List[str]] = None) -> Mapping[str, Any]:
+        params: MutableMapping[str, Any] = {"events": events if events is not None else []}
+        if contexts is not None:
+            params["contexts"] = contexts
+        return params
+
+
+class Test(BidiModule):
+    """Very temporary module that does nothing, except demonstrate a vendor prefix and
+    provide a way to work with Gecko's current skeleton implementation."""
+
+    prefix = "moz"
+
+    @command
+    def test_method(self, **kwargs):
+        return kwargs
diff --git a/tools/webdriver/webdriver/client.py b/tools/webdriver/webdriver/client.py
index 69ab749..a7a1357 100644
--- a/tools/webdriver/webdriver/client.py
+++ b/tools/webdriver/webdriver/client.py
@@ -1,6 +1,7 @@
 from . import error
 from . import protocol
 from . import transport
+from .bidi.client import BidiSession
 
 from urllib import parse as urlparse
 
@@ -473,8 +474,16 @@
                  host,
                  port,
                  url_prefix="/",
+                 enable_bidi=False,
                  capabilities=None,
                  extension=None):
+
+        if enable_bidi:
+            if capabilities is not None:
+                capabilities.setdefault("alwaysMatch", {}).update({"webSocketUrl": True})
+            else:
+                capabilities = {"alwaysMatch": {"webSocketUrl": True}}
+
         self.transport = transport.HTTPWireProtocol(host, port, url_prefix)
         self.requested_capabilities = capabilities
         self.capabilities = None
@@ -482,6 +491,8 @@
         self.timeouts = None
         self.window = None
         self.find = None
+        self.enable_bidi = enable_bidi
+        self.bidi_session = None
         self.extension = None
         self.extension_cls = extension
 
@@ -533,6 +544,14 @@
         self.session_id = value["sessionId"]
         self.capabilities = value["capabilities"]
 
+        if "webSocketUrl" in self.capabilities:
+            self.bidi_session = BidiSession.from_http(self.session_id,
+                                                      self.capabilities)
+        elif self.enable_bidi:
+            self.end()
+            raise error.SessionNotCreatedException(
+                "Requested bidi session, but webSocketUrl capability not found")
+
         if self.extension_cls:
             self.extension = self.extension_cls(self)
 
diff --git a/webdriver/tests/bidi/new_session/connect.py b/webdriver/tests/bidi/new_session/connect.py
index 2b02d8a..357da31 100644
--- a/webdriver/tests/bidi/new_session/connect.py
+++ b/webdriver/tests/bidi/new_session/connect.py
@@ -8,7 +8,6 @@
 @pytest.mark.asyncio
 @pytest.mark.capabilities({"webSocketUrl": True})
 async def test_websocket_url_connect(session):
-    assert not isinstance(session, webdriver.BidiSession)
     websocket_url = session.capabilities["webSocketUrl"]
     async with websockets.connect(websocket_url) as websocket:
         await websocket.send("Hello world!")
@@ -17,14 +16,14 @@
 # using bidi_session is the recommended way to test bidi
 @pytest.mark.asyncio
 async def test_bidi_session_send(bidi_session):
-    await bidi_session.websocket_transport.send("test_bidi_session: send")
+    await bidi_session.send_command("test.test", {})
 
 # bidi session following a bidi session with a different capabilities
 # to test session recreation
 @pytest.mark.asyncio
 @pytest.mark.capabilities({"acceptInsecureCerts": True})
 async def test_bidi_session_with_different_capability(bidi_session):
-    await bidi_session.websocket_transport.send("test_bidi_session: different capability")
+    await bidi_session.send_command("test.test", {})
 
 # classic session following a bidi session to test session
 # recreation
@@ -32,4 +31,4 @@
 # classic session is not impacted by bidi tests
 @pytest.mark.asyncio
 def test_classic_after_bidi_session(session):
-    assert not isinstance(session, webdriver.BidiSession)
+    assert not isinstance(session, webdriver.bidi.BidiSession)
diff --git a/webdriver/tests/support/fixtures.py b/webdriver/tests/support/fixtures.py
index 6f166a9..db32931 100644
--- a/webdriver/tests/support/fixtures.py
+++ b/webdriver/tests/support/fixtures.py
@@ -195,13 +195,15 @@
     await reset_current_session_if_necessary(caps, True)
 
     if _current_session is None:
-        _current_session = webdriver.BidiSession(
+        _current_session = webdriver.Session(
             configuration["host"],
             configuration["port"],
-            capabilities=caps)
-    try:
-        await _current_session.start()
+            capabilities=caps,
+            enable_bidi=True)
 
+    try:
+        _current_session.start()
+        await _current_session.bidi_session.start()
     except webdriver.error.SessionNotCreatedException:
         if not _current_session.session_id:
             raise
@@ -210,8 +212,9 @@
     _current_session.window.size = defaults.WINDOW_SIZE
     _current_session.window.position = defaults.WINDOW_POSITION
 
-    yield _current_session
+    yield _current_session.bidi_session
 
+    await _current_session.bidi_session.end()
     cleanup_session(_current_session)