From 2637486e6e5697f6c78f76c723d143ea56a3cdf5 Mon Sep 17 00:00:00 2001 From: Ola Lidholm <34105674+olalid@users.noreply.github.com> Date: Fri, 29 Nov 2024 06:15:49 +0200 Subject: [PATCH] Reconnect fix (#22) --- src/pysignalr/__init__.py | 14 +++++--------- src/pysignalr/client.py | 10 +++++++++- src/pysignalr/transport/websocket.py | 24 +++++++++++++++++++++--- 3 files changed, 35 insertions(+), 13 deletions(-) diff --git a/src/pysignalr/__init__.py b/src/pysignalr/__init__.py index c255a56..8a63431 100644 --- a/src/pysignalr/__init__.py +++ b/src/pysignalr/__init__.py @@ -12,13 +12,12 @@ from websockets.exceptions import InvalidStatusCode -class NegotiationTimeout(Exception): +class NegotiationFailure(Exception): """ - Exception raised when the connection URL generated during negotiation is no longer valid. + Exception raised when the connection fails. """ pass - async def __aiter__( self: websockets.legacy.client.Connect, ) -> AsyncIterator[websockets.legacy.client.WebSocketClientProtocol]: @@ -43,12 +42,9 @@ async def __aiter__( async with self as protocol: yield protocol - # Handle expired connection URLs by raising a NegotiationTimeout exception. - except InvalidStatusCode as e: - if e.status_code == HTTPStatus.NOT_FOUND: - raise NegotiationTimeout from e - except asyncio.TimeoutError as e: - raise NegotiationTimeout from e + # Handle expired connection URLs by raising a NegotiationFailure exception. + except (InvalidStatusCode, asyncio.TimeoutError) as e: + raise NegotiationFailure from e except Exception: # Add a random initial delay between 0 and 5 seconds. diff --git a/src/pysignalr/client.py b/src/pysignalr/client.py index a911c99..86a817f 100644 --- a/src/pysignalr/client.py +++ b/src/pysignalr/client.py @@ -27,6 +27,9 @@ DEFAULT_CONNECTION_TIMEOUT, DEFAULT_MAX_SIZE, DEFAULT_PING_INTERVAL, + DEFAULT_RETRY_SLEEP, + DEFAULT_RETRY_MULTIPLIER, + DEFAULT_RETRY_COUNT, WebsocketTransport, ) @@ -35,7 +38,6 @@ MessageCallback = Callable[[Message], Awaitable[None]] CompletionMessageCallback = Callable[[CompletionMessage], Awaitable[None]] - class ClientStream: """ Client to server streaming implementation. @@ -100,6 +102,9 @@ def __init__( ping_interval: int = DEFAULT_PING_INTERVAL, connection_timeout: int = DEFAULT_CONNECTION_TIMEOUT, max_size: int | None = DEFAULT_MAX_SIZE, + retry_sleep: float = DEFAULT_RETRY_SLEEP, + retry_multiplier: float = DEFAULT_RETRY_MULTIPLIER, + retry_count: int = DEFAULT_RETRY_COUNT, access_token_factory: Callable[[], str] | None = None, ssl: ssl.SSLContext | None = None, ) -> None: @@ -121,6 +126,9 @@ def __init__( callback=self._on_message, headers=self._headers, ping_interval=ping_interval, + retry_sleep=retry_sleep, + retry_multiplier=retry_multiplier, + retry_count=retry_count, connection_timeout=connection_timeout, max_size=max_size, access_token_factory=access_token_factory, diff --git a/src/pysignalr/transport/websocket.py b/src/pysignalr/transport/websocket.py index e5e4443..7a5dd03 100644 --- a/src/pysignalr/transport/websocket.py +++ b/src/pysignalr/transport/websocket.py @@ -14,7 +14,7 @@ from websockets.protocol import State import pysignalr.exceptions as exceptions -from pysignalr import NegotiationTimeout +from pysignalr import NegotiationFailure from pysignalr.messages import CompletionMessage, Message, PingMessage from pysignalr.protocol.abstract import Protocol from pysignalr.transport.abstract import ConnectionState, Transport @@ -24,6 +24,10 @@ DEFAULT_PING_INTERVAL = 10 DEFAULT_CONNECTION_TIMEOUT = 10 +DEFAULT_RETRY_SLEEP = 1 +DEFAULT_RETRY_MULTIPLIER = 1.1 +DEFAULT_RETRY_COUNT = 10 + _logger = logging.getLogger('pysignalr.transport') @@ -53,6 +57,9 @@ def __init__( skip_negotiation: bool = False, ping_interval: int = DEFAULT_PING_INTERVAL, connection_timeout: int = DEFAULT_CONNECTION_TIMEOUT, + retry_sleep: float = DEFAULT_RETRY_SLEEP, + retry_multiplier: float = DEFAULT_RETRY_MULTIPLIER, + retry_count: int = DEFAULT_RETRY_COUNT, max_size: int | None = DEFAULT_MAX_SIZE, access_token_factory: Callable[[], str] | None = None, ssl: ssl.SSLContext | None = None, @@ -81,6 +88,9 @@ def __init__( self._connection_timeout = connection_timeout self._max_size = max_size self._access_token_factory = access_token_factory + self._retry_sleep = retry_sleep + self._retry_multiplier = retry_multiplier + self._retry_count = retry_count self._ssl = ssl self._state = ConnectionState.disconnected @@ -121,9 +131,17 @@ async def run(self) -> None: Runs the WebSocket transport, managing the connection lifecycle. """ while True: - with suppress(NegotiationTimeout): + try: await self._loop() - await self._set_state(ConnectionState.disconnected) + except NegotiationFailure as e: + await self._set_state(ConnectionState.disconnected) + self._retry_count -= 1 + if self._retry_count <= 0: + raise e + self._retry_sleep *= self._retry_multiplier + await asyncio.sleep(self._retry_sleep) + else: + await self._set_state(ConnectionState.disconnected) async def send(self, message: Message) -> None: """