forked from web-platform-tests/wpt
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Start creating a proper client for WebDriver BiDi (web-platform-tests…
…#28381) Add a webdriver.bidi.client module with a BidiSession class that implements only the BiDi part of the protocol. An ordinary Session may have an associated BidiSession if the session was created through HTTP. To make this easier, an enable_bidi flag is added to the top-level session which automatically sets the capabilities required to enable BiDi. The session allows sending commands, and registering handlers for events. In addition to the session class, this adds some intial work on creating a structure for commands, with a BidiModule abstract class that can be used as the basis for command implementations and an @command decorator that handles actually sending the command and waiting for the result. Co-authored-by: Maksim Sadym <[email protected]>
- Loading branch information
1 parent
545f8a5
commit a38612f
Showing
5 changed files
with
426 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# flake8: noqa | ||
|
||
from .client import BidiSession |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,393 @@ | ||
import asyncio | ||
import functools | ||
import json | ||
import logging | ||
import sys | ||
from collections import defaultdict | ||
from typing import Any, Awaitable, Callable, Coroutine, List, Optional, Mapping, MutableMapping | ||
from urllib.parse import urljoin, urlparse | ||
|
||
import websockets | ||
|
||
|
||
logger = logging.getLogger("webdriver.bidi") | ||
|
||
|
||
class BidiException(Exception): | ||
def __init__(self, err: str, msg: str, stack: Optional[str] = None): | ||
self.err = err | ||
self.msg = msg | ||
self.stack = stack | ||
|
||
|
||
def get_running_loop() -> asyncio.AbstractEventLoop: | ||
if sys.version_info >= (3, 7): | ||
return asyncio.get_running_loop() | ||
# Unlike the above, this will actually create an event loop | ||
# if there isn't one; hopefully running tests in Python >= 3.7 | ||
# will allow us to catch any behaviour difference | ||
return asyncio.get_event_loop() | ||
|
||
|
||
class BidiSession: | ||
"""A WebDriver BiDi session. | ||
This is the main representation of a BiDi session and provides the | ||
interface for running commands in the session, and for attaching | ||
event handlers to the session. For example: | ||
async def on_log(data): | ||
print(data) | ||
session = BidiSession("ws://localhost:4445", capabilities) | ||
session.add_event_listener("log.entryAdded", on_log) | ||
await session.start() | ||
await session.subscribe("log.entryAdded") | ||
# Do some stuff with the session | ||
session.end() | ||
If the session id is provided it's assumed that the underlying | ||
WebDriver session was already created, and the WebSocket URL was | ||
taken from the new session response. If no session id is provided, it's | ||
assumed that a BiDi-only session should be created when start() is called. | ||
It can also be used as a context manager, with the WebSocket transport | ||
implictly being created when the context is entered, and closed when | ||
the context is exited. | ||
:param websocket_url: WebSockets URL on which to connect to the session. | ||
This excludes any path component. | ||
:param session_id: String id of existing HTTP session | ||
:param capabilities: Capabilities response of existing session | ||
:param requested_capabilities: Dictionary representing the capabilities request. | ||
""" | ||
|
||
def __init__(self, | ||
websocket_url: str, | ||
session_id: Optional[str] = None, | ||
capabilities: Optional[Mapping[str, Any]] = None, | ||
requested_capabilities: Optional[Mapping[str, Any]] = None, | ||
loop: Optional[asyncio.AbstractEventLoop] = None): | ||
self.transport: Optional[Transport] = None | ||
|
||
# The full URL for a websocket looks like | ||
# ws://<host>:<port>/session when we're creating a session and | ||
# ws://<host>:<port>/session/<sessionid> when we're connecting to an existing session. | ||
# To be user friendly, handle the case where the class was created with either a | ||
# full URL including the path, and also the case where just a server url is passed in. | ||
parsed_url = urlparse(websocket_url) | ||
if parsed_url.path == "" or parsed_url.path == "/": | ||
if session_id is None: | ||
websocket_url = urljoin(websocket_url, "session") | ||
else: | ||
websocket_url = urljoin(websocket_url, f"session/{session_id}") | ||
else: | ||
if session_id is not None: | ||
if parsed_url.path != f"/session/{session_id}": | ||
raise ValueError(f"WebSocket URL {session_id} doesn't match session id") | ||
else: | ||
if parsed_url.path != "/session": | ||
raise ValueError(f"WebSocket URL {session_id} doesn't match session url") | ||
|
||
if session_id is None and capabilities is not None: | ||
raise ValueError("Tried to create BiDi-only session with existing capabilities") | ||
|
||
self.websocket_url = websocket_url | ||
self.requested_capabilities = requested_capabilities | ||
self.capabilities = capabilities | ||
self.session_id = session_id | ||
|
||
self.command_id = 0 | ||
self.pending_commands: MutableMapping[int, "asyncio.Future[Any]"] = {} | ||
self.event_listeners: MutableMapping[Optional[str], List[Callable[[str, Mapping[str, Any]], Any]]] = defaultdict(list) | ||
|
||
# Modules. | ||
# For each module, have a property representing that module | ||
self.session = Session(self) | ||
|
||
if loop is None: | ||
loop = get_running_loop() | ||
self.loop = loop | ||
|
||
@classmethod | ||
def from_http(cls, | ||
session_id: str, | ||
capabilities: Mapping[str, Any], | ||
loop: Optional[asyncio.AbstractEventLoop] = None) -> "BidiSession": | ||
"""Create a BiDi session from an existing HTTP session | ||
:param session_id: String id of the session | ||
:param capabilities: Capabilities returned in the New Session HTTP response.""" | ||
websocket_url = capabilities.get("webSocketUrl") | ||
if websocket_url is None: | ||
raise ValueError("No webSocketUrl found in capabilities") | ||
if not isinstance(websocket_url, str): | ||
raise ValueError("webSocketUrl is not a string") | ||
return cls(websocket_url, session_id=session_id, capabilities=capabilities, loop=loop) | ||
|
||
@classmethod | ||
def bidi_only(cls, | ||
websocket_url: str, | ||
requested_capabilities: Optional[Mapping[str, Any]], | ||
loop: Optional[asyncio.AbstractEventLoop] = None) -> "BidiSession": | ||
"""Create a BiDi session where there is no existing HTTP session | ||
:param webdocket_url: URL to the WebSocket server listening for BiDi connections | ||
:param requested_capabilities: Capabilities request for establishing the session.""" | ||
return cls(websocket_url, requested_capabilities=requested_capabilities, loop=loop) | ||
|
||
async def __aenter__(self) -> "BidiSession": | ||
await self.start() | ||
return self | ||
|
||
async def __aexit__(self, *args: Any) -> None: | ||
await self.end() | ||
|
||
async def start(self) -> None: | ||
"""Connect to the WebDriver BiDi remote via WebSockets""" | ||
self.transport = Transport(self.websocket_url, self.on_message, loop=self.loop) | ||
|
||
if self.session_id is None: | ||
self.session_id, self.capabilities = await self.session.new(self.requested_capabilities) | ||
|
||
await self.transport.start() | ||
|
||
async def send_command(self, method: str, params: Mapping[str, Any]) -> Awaitable[Mapping[str, Any]]: | ||
"""Send a command to the remote server""" | ||
# this isn't threadsafe | ||
self.command_id += 1 | ||
command_id = self.command_id | ||
|
||
body = { | ||
"id": command_id, | ||
"method": method, | ||
"params": params | ||
} | ||
assert command_id not in self.pending_commands | ||
self.pending_commands[command_id] = self.loop.create_future() | ||
assert self.transport is not None | ||
await self.transport.send(body) | ||
|
||
return self.pending_commands[command_id] | ||
|
||
async def on_message(self, data: Mapping[str, Any]) -> None: | ||
"""Handle a message from the remote server""" | ||
if "id" in data: | ||
# This is a command response or error | ||
future = self.pending_commands.get(data["id"]) | ||
if future is None: | ||
raise ValueError(f"No pending command with id {data['id']}") | ||
if "result" in data: | ||
future.set_result(data["result"]) | ||
elif "error" in data and "message" in data: | ||
assert isinstance(data["error"], str) | ||
assert isinstance(data["message"], str) | ||
future.set_exception(BidiException(data["error"], | ||
data["message"], | ||
data.get("stacktrace"))) | ||
else: | ||
raise ValueError(f"Unexpected message: {data!r}") | ||
elif "method" in data and "params" in data: | ||
# This is an event | ||
method = data["method"] | ||
listeners = self.event_listeners.get(method, []) | ||
if not listeners: | ||
listeners = self.event_listeners.get(None, []) | ||
for listener in listeners: | ||
await listener(method, data["params"]) | ||
else: | ||
raise ValueError(f"Unexpected message: {data!r}") | ||
|
||
async def end(self) -> None: | ||
"""Close websocket connection.""" | ||
assert self.transport is not None | ||
await self.transport.end() | ||
|
||
def add_event_listener(self, | ||
name: Optional[str], | ||
fn: Callable[[str, Mapping[str, Any]], Awaitable[Any]]) -> None: | ||
"""Add a listener for the event with a given name. | ||
If name is None, the listener is called for all messages that are not otherwise | ||
handled. | ||
:param name: Name of event to listen for or None to register a default handler | ||
:param fn: Async callback function that receives event data | ||
""" | ||
self.event_listeners[name].append(fn) | ||
|
||
|
||
class Transport: | ||
"""Low level message handler for the WebSockets connection""" | ||
def __init__(self, url: str, | ||
msg_handler: Callable[[Mapping[str, Any]], Coroutine[Any, Any, None]], | ||
loop: Optional[asyncio.AbstractEventLoop] = None): | ||
self.url = url | ||
self.connection: Optional[websockets.WebSocketClientProtocol] = None | ||
self.msg_handler = msg_handler | ||
self.send_buf: List[Mapping[str, Any]] = [] | ||
|
||
if loop is None: | ||
loop = get_running_loop() | ||
self.loop = loop | ||
|
||
self.read_message_task: Optional[asyncio.Task[Any]] = None | ||
|
||
async def start(self) -> None: | ||
self.connection = await websockets.client.connect(self.url) | ||
self.read_message_task = self.loop.create_task(self.read_messages()) | ||
|
||
for msg in self.send_buf: | ||
await self._send(self.connection, msg) | ||
|
||
async def send(self, data: Mapping[str, Any]) -> None: | ||
if self.connection is not None: | ||
await self._send(self.connection, data) | ||
else: | ||
self.send_buf.append(data) | ||
|
||
@staticmethod | ||
async def _send(connection: websockets.WebSocketClientProtocol, data: Mapping[str, Any]) -> None: | ||
msg = json.dumps(data) | ||
logger.debug("→ %s", msg) | ||
await connection.send(msg) | ||
|
||
async def handle(self, msg: str) -> None: | ||
logger.debug("← %s", msg) | ||
data = json.loads(msg) | ||
await self.msg_handler(data) | ||
|
||
async def end(self) -> None: | ||
if self.connection: | ||
await self.connection.close() | ||
self.connection = None | ||
|
||
async def read_messages(self) -> None: | ||
assert self.connection is not None | ||
async for msg in self.connection: | ||
if not isinstance(msg, str): | ||
raise ValueError("Got a binary message") | ||
await self.handle(msg) | ||
|
||
|
||
class command: | ||
"""Decorator for implementing bidi commands | ||
Implementing a command involves specifying an async function that | ||
builds the parameters to the command. The decorator arranges those | ||
parameters to be turned into a send_command call, using the class | ||
and method names to determine the method in the call. | ||
Commands decorated in this way don't return a future, but await | ||
the actual response. In some cases it can be useful to | ||
post-process this response before returning it to the client. This | ||
can be done by specifying a second decorated method like | ||
@command_name.result. That method will then be called once the | ||
result of the original command is known, and the return value of | ||
the method used as the response of the command. | ||
So for an example, if we had a command test.testMethod, which | ||
returned a result which we want to convert to a TestResult type, | ||
the implementation might look like: | ||
class Test(BidiModule): | ||
@command | ||
def test_method(self, test_data=None): | ||
return {"testData": test_data} | ||
@test_method.result | ||
def convert_test_method_result(self, result): | ||
return TestData(**result) | ||
""" | ||
|
||
def __init__(self, fn: Callable[..., Mapping[str, Any]]): | ||
self.params_fn = fn | ||
self.result_fn: Optional[Callable[..., Any]] = None | ||
|
||
def result(self, fn: Callable[[Any, MutableMapping[str, Any]], Mapping[str, Any]]) -> None: | ||
self.result_fn = fn | ||
|
||
def __set_name__(self, owner: Any, name: str) -> None: | ||
# This is called when the class is created | ||
# see https://docs.python.org/3/reference/datamodel.html#object.__set_name__ | ||
params_fn = self.params_fn | ||
result_fn = self.result_fn | ||
|
||
@functools.wraps(params_fn) | ||
async def inner(self: Any, **kwargs: Any) -> Any: | ||
params = params_fn(self, **kwargs) | ||
|
||
# Convert the classname and the method name to a bidi command name | ||
mod_name = owner.__name__.lower() | ||
if hasattr(owner, "prefix"): | ||
mod_name = f"{owner.prefix}:{mod_name}" | ||
cmd_name = f"{mod_name}.{to_camelcase(name)}" | ||
|
||
future = await self.session.send_command(cmd_name, params) | ||
result = await future | ||
|
||
if result_fn is not None: | ||
# Convert the result if we have a conversion function defined | ||
result = result_fn(self, result) | ||
return result | ||
|
||
# Overwrite the method on the owner class with the wrapper | ||
setattr(owner, name, inner) | ||
|
||
def __call__(*args: Any, **kwargs: Any) -> Awaitable[Any]: | ||
# This isn't really used, but mypy doesn't understand __set_name__ | ||
pass | ||
|
||
|
||
def to_camelcase(name: str) -> str: | ||
"""Convert a python style method name foo_bar to a BiDi command name fooBar""" | ||
parts = name.split("_") | ||
parts[0] = parts[0].lower() | ||
for i in range(1, len(parts)): | ||
parts[i] = parts[i].title() | ||
return "".join(parts) | ||
|
||
|
||
class BidiModule: | ||
def __init__(self, session: BidiSession): | ||
self.session = session | ||
|
||
|
||
class Session(BidiModule): | ||
@command | ||
def new(self, capabilities: Mapping[str, Any]) -> Mapping[str, Mapping[str, Any]]: | ||
return {"capabilities": capabilities} | ||
|
||
@new.result | ||
def _new(self, result: Mapping[str, Any]) -> Any: | ||
return result.get("session_id"), result.get("capabilities", {}) | ||
|
||
@command | ||
def subscribe(self, | ||
events: Optional[List[str]] = None, | ||
contexts: Optional[List[str]] = None) -> Mapping[str, Any]: | ||
params: MutableMapping[str, Any] = {"events": events if events is not None else []} | ||
if contexts is not None: | ||
params["contexts"] = contexts | ||
return params | ||
|
||
@command | ||
def unsubscribe(self, | ||
events: Optional[List[str]] = None, | ||
contexts: Optional[List[str]] = None) -> Mapping[str, Any]: | ||
params: MutableMapping[str, Any] = {"events": events if events is not None else []} | ||
if contexts is not None: | ||
params["contexts"] = contexts | ||
return params | ||
|
||
|
||
class Test(BidiModule): | ||
"""Very temporary module that does nothing, except demonstrate a vendor prefix and | ||
provide a way to work with Gecko's current skeleton implementation.""" | ||
|
||
prefix = "moz" | ||
|
||
@command | ||
def test_method(self, **kwargs): | ||
return kwargs |
Oops, something went wrong.