Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: threadsafe waiting queue #301

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 12 additions & 23 deletions roborock/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
RoborockTimeout,
UnknownMethodError,
)
from .roborock_future import RoborockFuture
from .roborock_future import RequestKey, RoborockFuture, WaitingQueue
from .roborock_message import (
RoborockMessage,
RoborockMessageProtocol,
Expand All @@ -38,7 +38,7 @@ def __init__(self, device_info: DeviceData) -> None:
"""Initialize RoborockClient."""
self.device_info = device_info
self._nonce = secrets.token_bytes(16)
self._waiting_queue: dict[int, RoborockFuture] = {}
self._waiting_queue = WaitingQueue()
self._last_device_msg_in = time.monotonic()
self._last_disconnection = time.monotonic()
self.keep_alive = KEEPALIVE
Expand Down Expand Up @@ -89,33 +89,22 @@ async def validate_connection(self) -> None:
await self.async_disconnect()
await self.async_connect()

async def _wait_response(self, request_id: int, queue: RoborockFuture) -> Any:
async def _wait_response(self, request_key: RequestKey, future: RoborockFuture) -> Any:
try:
response = await queue.async_get(self.queue_timeout)
response = await future.async_get(self.queue_timeout)
if response == "unknown_method":
raise UnknownMethodError("Unknown method")
return response
except (asyncio.TimeoutError, asyncio.CancelledError):
raise RoborockTimeout(f"id={request_id} Timeout after {self.queue_timeout} seconds") from None
raise RoborockTimeout(f"id={request_key} Timeout after {self.queue_timeout} seconds") from None
finally:
self._waiting_queue.pop(request_id, None)

def _async_response(self, request_id: int, protocol_id: int = 0) -> Any:
queue = RoborockFuture(protocol_id)
if request_id in self._waiting_queue and not (
request_id == 2 and protocol_id == RoborockMessageProtocol.PING_REQUEST
):
new_id = get_next_int(10000, 32767)
self._logger.warning(
"Attempting to create a future with an existing id %s (%s)... New id is %s. "
"Code may not function properly.",
request_id,
protocol_id,
new_id,
)
request_id = new_id
self._waiting_queue[request_id] = queue
return asyncio.ensure_future(self._wait_response(request_id, queue))
self._waiting_queue.safe_pop(request_key)


def _async_response(self, request_key: RequestKey) -> Any:
future = RoborockFuture()
self._waiting_queue.put(request_key, future)
return asyncio.ensure_future(self._wait_response(request_key, future))

@abstractmethod
async def send_message(self, roborock_message: RoborockMessage):
Expand Down
26 changes: 11 additions & 15 deletions roborock/cloud_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .containers import DeviceData, UserData
from .exceptions import RoborockException, VacuumError
from .protocol import MessageParser, md5hex
from .roborock_future import RoborockFuture
from .roborock_future import RequestKey

_LOGGER = logging.getLogger(__name__)
CONNECT_REQUEST_ID = 0
Expand Down Expand Up @@ -72,32 +72,28 @@ def __init__(self, user_data: UserData, device_info: DeviceData) -> None:
self._mqtt_password = rriot.s
self._hashed_password = md5hex(self._mqtt_password + ":" + rriot.k)[16:]
self._mqtt_client.username_pw_set(self._hashed_user, self._hashed_password)
self._waiting_queue: dict[int, RoborockFuture] = {}
self._mutex = Lock()

def _mqtt_on_connect(self, *args, **kwargs):
_, __, ___, rc, ____ = args
connection_queue = self._waiting_queue.get(CONNECT_REQUEST_ID)
if not (connection_queue := self._waiting_queue.safe_pop(RequestKey(CONNECT_REQUEST_ID), "connect")):
self._logger.info("Received unexpected connect event")
return
if rc != mqtt.MQTT_ERR_SUCCESS:
message = f"Failed to connect ({mqtt.error_string(rc)})"
self._logger.error(message)
if connection_queue:
connection_queue.set_exception(VacuumError(message))
else:
self._logger.debug("Failed to notify connect future, not in queue")
connection_queue.set_exception(VacuumError(message))
return
self._logger.info(f"Connected to mqtt {self._mqtt_host}:{self._mqtt_port}")
topic = f"rr/m/o/{self._mqtt_user}/{self._hashed_user}/{self.device_info.device.duid}"
(result, mid) = self._mqtt_client.subscribe(topic)
if result != 0:
message = f"Failed to subscribe ({mqtt.error_string(rc)})"
self._logger.error(message)
if connection_queue:
connection_queue.set_exception(VacuumError(message))
connection_queue.set_exception(VacuumError(message))
return
self._logger.info(f"Subscribed to topic {topic}")
if connection_queue:
connection_queue.set_result(True)
connection_queue.set_result(True)

def _mqtt_on_message(self, *args, **kwargs):
client, __, msg = args
Expand All @@ -112,8 +108,7 @@ def _mqtt_on_disconnect(self, *args, **kwargs):
try:
exc = RoborockException(mqtt.error_string(rc)) if rc != mqtt.MQTT_ERR_SUCCESS else None
super().on_connection_lost(exc)
connection_queue = self._waiting_queue.get(DISCONNECT_REQUEST_ID)
if connection_queue:
if connection_queue := self._waiting_queue.safe_pop(RequestKey(DISCONNECT_REQUEST_ID), "disconnect"):
connection_queue.set_result(True)
except Exception as ex:
self._logger.exception(ex)
Expand All @@ -124,10 +119,11 @@ def is_connected(self) -> bool:

def _sync_disconnect(self) -> Any:
if not self.is_connected():
self._logger.debug("Already disconnected from mqtt")
return None

self._logger.info("Disconnecting from mqtt")
disconnected_future = self._async_response(DISCONNECT_REQUEST_ID)
disconnected_future = self._async_response(RequestKey(DISCONNECT_REQUEST_ID))
rc = self._mqtt_client.disconnect()

if rc == mqtt.MQTT_ERR_NO_CONN:
Expand All @@ -149,7 +145,7 @@ def _sync_connect(self) -> Any:
raise RoborockException("Mqtt information was not entered. Cannot connect.")

self._logger.debug("Connecting to mqtt")
connected_future = self._async_response(CONNECT_REQUEST_ID)
connected_future = self._async_response(RequestKey(CONNECT_REQUEST_ID))
self._mqtt_client.connect(host=self._mqtt_host, port=self._mqtt_port, keepalive=KEEPALIVE)
self._mqtt_client.maybe_restart_loop()
return connected_future
Expand Down
60 changes: 57 additions & 3 deletions roborock/roborock_future.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,69 @@
from __future__ import annotations

import logging
from asyncio import Future
from dataclasses import dataclass
from threading import Lock
from typing import Any

import async_timeout

from .exceptions import VacuumError
from .roborock_message import RoborockMessageProtocol

_LOGGER = logging.getLogger(__name__)


@dataclass(frozen=True)
class RequestKey:
    """Identify one pending Roborock request.

    The dataclass is frozen, which makes instances immutable and hashable,
    so a key can be used to index a dictionary of pending futures.
    """

    # Identifier of the request this key refers to.
    request_id: int
    # Protocol paired with the request id; 0 when the caller gives none.
    protocol: RoborockMessageProtocol | int = 0

    def __str__(self) -> str:
        """Return the key formatted as ``<request_id>-<protocol>``."""
        return f"{self.request_id}-{self.protocol}"


class WaitingQueue:
"""A threadsafe waiting queue for Roborock messages."""

def __init__(self) -> None:
"""Initialize the waiting queue."""
self._lock = Lock()
self._queue: dict[RequestKey, RoborockFuture] = {}

def put(self, request_key: RequestKey, future: RoborockFuture) -> None:
"""Create a future for the given protocol."""
_LOGGER.debug("Putting request key %s in the queue", request_key)
with self._lock:
if request_key in self._queue:
raise ValueError(f"Request key {request_key} already exists in the queue")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this should be a subclass of RoborockException for easier error handling?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My thinking is that if this happens, it's a bug (a "shouldn't happen" case), so it shouldn't be treated like a server-side error. Put another way: what would we expect a caller to do to handle this?

self._queue[request_key] = future

def safe_pop(self, request_key: RequestKey, label: str | None = None) -> RoborockFuture | None:
"""Get the future from the queue if it has not yet been popped, otherwise ignore.

The label is used for logging when the request key is not found in the queue.
"""
_LOGGER.debug("Popping request key %s (%s) from the queue", request_key, label)
with self._lock:
future = self._queue.pop(request_key, None)
if future is None and label is not None:
_LOGGER.warning("Received message for key %s (%s) not found in the queue", request_key, label)
return future


class RoborockFuture:
def __init__(self, protocol: int):
self.protocol = protocol
"""A threadsafe asyncio Future for Roborock messages.

The results may be set from a background thread. The future
must be awaited in an asyncio event loop.
"""

def __init__(self):
"""Initialize the Roborock future."""
self.fut: Future = Future()
self.loop = self.fut.get_loop()

Expand All @@ -28,7 +81,8 @@ def _set_exception(self, exc: VacuumError) -> None:
def set_exception(self, exc: VacuumError) -> None:
self.loop.call_soon_threadsafe(self._set_exception, exc)

async def async_get(self, timeout: float | int) -> tuple[Any, VacuumError | None]:
async def async_get(self, timeout: float | int) -> Any:
"""Get the result from the future or raises an error."""
try:
async with async_timeout.timeout(timeout):
return await self.fut
Expand Down
19 changes: 7 additions & 12 deletions roborock/version_1_apis/roborock_client_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
WashTowelMode,
)
from roborock.protocol import Utils
from roborock.roborock_future import RequestKey
from roborock.roborock_message import (
ROBOROCK_DATA_CONSUMABLE_PROTOCOL,
ROBOROCK_DATA_STATUS_PROTOCOL,
Expand Down Expand Up @@ -391,8 +392,8 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
if data_point_number == "102":
data_point_response = json.loads(data_point)
request_id = data_point_response.get("id")
queue = self._waiting_queue.get(request_id)
if queue and queue.protocol == protocol:
request_key = RequestKey(request_id, protocol)
if queue := self._waiting_queue.safe_pop(request_key, "v1_rpc"):
error = data_point_response.get("error")
if error:
queue.set_exception(
Expand All @@ -406,8 +407,6 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
if isinstance(result, list) and len(result) == 1:
result = result[0]
queue.set_result(result)
else:
self._logger.debug("Received response for unknown request id %s", request_id)
else:
try:
data_protocol = RoborockDataProtocol(int(data_point_number))
Expand Down Expand Up @@ -467,19 +466,15 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
except ValueError as err:
raise RoborockException(f"Failed to decode {data.payload!r} for {data.protocol}") from err
decompressed = Utils.decompress(decrypted)
queue = self._waiting_queue.get(request_id)
if queue:
request_key = RequestKey(request_id, protocol)
if queue := self._waiting_queue.safe_pop(request_key, "v1_map"):
if isinstance(decompressed, list):
decompressed = decompressed[0]
queue.set_result(decompressed)
else:
self._logger.debug("Received response for unknown request id %s", request_id)
else:
queue = self._waiting_queue.get(data.seq)
if queue:
request_key = RequestKey(data.seq, protocol)
if queue := self._waiting_queue.safe_pop(request_key, "v1_other"):
queue.set_result(data.payload)
else:
self._logger.debug("Received response for unknown request id %s", data.seq)
except Exception as ex:
self._logger.exception(ex)

Expand Down
9 changes: 7 additions & 2 deletions roborock/version_1_apis/roborock_local_client_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .. import CommandVacuumError, DeviceData, RoborockCommand, RoborockException
from ..exceptions import VacuumError
from ..protocol import MessageParser
from ..roborock_future import RequestKey
from ..roborock_message import MessageRetry, RoborockMessage, RoborockMessageProtocol
from ..util import RoborockLoggerAdapter
from .roborock_client_v1 import COMMANDS_SECURED, RoborockClientV1
Expand Down Expand Up @@ -54,15 +55,19 @@ async def send_message(self, roborock_message: RoborockMessage):
response_protocol = request_id + 1
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this trying to set the response protocol to the next thing, e.g. HELLO_REQUEST to HELLO_RESPONSE? I think it ends up mangling things, and I can't tell what this is trying to do.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.
I believe there might be a proper mapping between dps request and response protocols. But since I didn't have that mapping, I relied on what I noticed: the response protocol was always one number higher than the request.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, great. Maybe I can figure out another way to use the new protocol types. I'll think about this more, and maybe:
(1) Add some tests to exercise these cases explicitly
(2) Address it first as a separate change
(3) Come back to this

else:
request_id = roborock_message.get_request_id()
_LOGGER.debug("Getting next request id: %s", request_id)
response_protocol = RoborockMessageProtocol.GENERAL_REQUEST
if request_id is None:
raise RoborockException(f"Failed build message {roborock_message}")
local_key = self.device_info.device.local_key
msg = MessageParser.build(roborock_message, local_key=local_key)
request_key = RequestKey(request_id, response_protocol)
if method:
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
self._logger.debug(f"id={request_key} Requesting method {method} with {params}")
else:
self._logger.debug(f"id={request_key} Requesting with {params}")
# Send the command to the Roborock device
async_response = self._async_response(request_id, response_protocol)
async_response = self._async_response(request_key)
self._send_msg_raw(msg)
diagnostic_key = method if method is not None else "unknown"
try:
Expand Down
11 changes: 6 additions & 5 deletions roborock/version_1_apis/roborock_mqtt_client_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ..containers import DeviceData, UserData
from ..exceptions import CommandVacuumError, RoborockException, VacuumError
from ..protocol import MessageParser, Utils
from ..roborock_future import RequestKey
from ..roborock_message import (
RoborockMessage,
RoborockMessageProtocol,
Expand Down Expand Up @@ -47,11 +48,11 @@ async def send_message(self, roborock_message: RoborockMessage):
response_protocol = (
RoborockMessageProtocol.MAP_RESPONSE if method in COMMANDS_SECURED else RoborockMessageProtocol.RPC_RESPONSE
)

request_key = RequestKey(request_id, response_protocol)
local_key = self.device_info.device.local_key
msg = MessageParser.build(roborock_message, local_key, False)
self._logger.debug(f"id={request_id} Requesting method {method} with {params}")
async_response = self._async_response(request_id, response_protocol)
self._logger.debug(f"id={request_key} Requesting method {method} with {params}")
async_response = self._async_response(request_key)
self._send_msg_raw(msg)
diagnostic_key = method if method is not None else "unknown"
try:
Expand All @@ -67,9 +68,9 @@ async def send_message(self, roborock_message: RoborockMessage):
"response": response,
}
if response_protocol == RoborockMessageProtocol.MAP_RESPONSE:
self._logger.debug(f"id={request_id} Response from {method}: {len(response)} bytes")
self._logger.debug(f"id={request_key} Response from {method}: {len(response)} bytes")
else:
self._logger.debug(f"id={request_id} Response from {method}: {response}")
self._logger.debug(f"id={request_key} Response from {method}: {response}")
return response

async def _send_command(
Expand Down
7 changes: 4 additions & 3 deletions roborock/version_a01_apis/roborock_client_a01.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
ZeoTemperature,
)
from roborock.containers import DyadProductInfo, DyadSndState, RoborockCategory
from roborock.roborock_future import RequestKey
from roborock.roborock_message import (
RoborockDyadDataProtocol,
RoborockMessage,
Expand Down Expand Up @@ -142,9 +143,9 @@ def on_message_received(self, messages: list[RoborockMessage]) -> None:
if data_point_protocol in entries:
# Auto convert into data struct we want.
converted_response = entries[data_point_protocol].post_process_fn(data_point)
queue = self._waiting_queue.get(int(data_point_number))
if queue and queue.protocol == protocol:
queue.set_result(converted_response)
request_key = RequestKey(int(data_point_number), protocol)
if future := self._waiting_queue.safe_pop(request_key, "a01"):
future.set_result(converted_response)

@abstractmethod
async def update_values(
Expand Down
3 changes: 2 additions & 1 deletion roborock/version_a01_apis/roborock_mqtt_client_a01.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from roborock.containers import DeviceData, RoborockCategory, UserData
from roborock.exceptions import RoborockException
from roborock.protocol import MessageParser
from roborock.roborock_future import RequestKey
from roborock.roborock_message import (
RoborockDyadDataProtocol,
RoborockMessage,
Expand Down Expand Up @@ -50,7 +51,7 @@ async def send_message(self, roborock_message: RoborockMessage):
futures = []
if "10000" in payload["dps"]:
for dps in json.loads(payload["dps"]["10000"]):
futures.append(self._async_response(dps, response_protocol))
futures.append(self._async_response(RequestKey(dps, response_protocol)))
self._send_msg_raw(m)
responses = await asyncio.gather(*futures, return_exceptions=True)
dps_responses: dict[int, typing.Any] = {}
Expand Down
Loading
Loading