Skip to content

Commit

Permalink
Reconnect fix (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
olalid authored Nov 29, 2024
1 parent 234e8ec commit 2637486
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 13 deletions.
14 changes: 5 additions & 9 deletions src/pysignalr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@
from websockets.exceptions import InvalidStatusCode


class NegotiationTimeout(Exception):
class NegotiationFailure(Exception):
"""
Exception raised when the connection URL generated during negotiation is no longer valid.
Exception raised when the connection fails.
"""
pass


async def __aiter__(
self: websockets.legacy.client.Connect,
) -> AsyncIterator[websockets.legacy.client.WebSocketClientProtocol]:
Expand All @@ -43,12 +42,9 @@ async def __aiter__(
async with self as protocol:
yield protocol

# Handle expired connection URLs by raising a NegotiationTimeout exception.
except InvalidStatusCode as e:
if e.status_code == HTTPStatus.NOT_FOUND:
raise NegotiationTimeout from e
except asyncio.TimeoutError as e:
raise NegotiationTimeout from e
# Handle expired connection URLs by raising a NegotiationFailure exception.
except (InvalidStatusCode, asyncio.TimeoutError) as e:
raise NegotiationFailure from e

except Exception:
# Add a random initial delay between 0 and 5 seconds.
Expand Down
10 changes: 9 additions & 1 deletion src/pysignalr/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
DEFAULT_CONNECTION_TIMEOUT,
DEFAULT_MAX_SIZE,
DEFAULT_PING_INTERVAL,
DEFAULT_RETRY_SLEEP,
DEFAULT_RETRY_MULTIPLIER,
DEFAULT_RETRY_COUNT,
WebsocketTransport,
)

Expand All @@ -35,7 +38,6 @@
MessageCallback = Callable[[Message], Awaitable[None]]
CompletionMessageCallback = Callable[[CompletionMessage], Awaitable[None]]


class ClientStream:
"""
Client to server streaming implementation.
Expand Down Expand Up @@ -100,6 +102,9 @@ def __init__(
ping_interval: int = DEFAULT_PING_INTERVAL,
connection_timeout: int = DEFAULT_CONNECTION_TIMEOUT,
max_size: int | None = DEFAULT_MAX_SIZE,
retry_sleep: float = DEFAULT_RETRY_SLEEP,
retry_multiplier: float = DEFAULT_RETRY_MULTIPLIER,
retry_count: int = DEFAULT_RETRY_COUNT,
access_token_factory: Callable[[], str] | None = None,
ssl: ssl.SSLContext | None = None,
) -> None:
Expand All @@ -121,6 +126,9 @@ def __init__(
callback=self._on_message,
headers=self._headers,
ping_interval=ping_interval,
retry_sleep=retry_sleep,
retry_multiplier=retry_multiplier,
retry_count=retry_count,
connection_timeout=connection_timeout,
max_size=max_size,
access_token_factory=access_token_factory,
Expand Down
24 changes: 21 additions & 3 deletions src/pysignalr/transport/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from websockets.protocol import State

import pysignalr.exceptions as exceptions
from pysignalr import NegotiationTimeout
from pysignalr import NegotiationFailure
from pysignalr.messages import CompletionMessage, Message, PingMessage
from pysignalr.protocol.abstract import Protocol
from pysignalr.transport.abstract import ConnectionState, Transport
Expand All @@ -24,6 +24,10 @@
DEFAULT_PING_INTERVAL = 10
DEFAULT_CONNECTION_TIMEOUT = 10

DEFAULT_RETRY_SLEEP = 1
DEFAULT_RETRY_MULTIPLIER = 1.1
DEFAULT_RETRY_COUNT = 10

_logger = logging.getLogger('pysignalr.transport')


Expand Down Expand Up @@ -53,6 +57,9 @@ def __init__(
skip_negotiation: bool = False,
ping_interval: int = DEFAULT_PING_INTERVAL,
connection_timeout: int = DEFAULT_CONNECTION_TIMEOUT,
retry_sleep: float = DEFAULT_RETRY_SLEEP,
retry_multiplier: float = DEFAULT_RETRY_MULTIPLIER,
retry_count: int = DEFAULT_RETRY_COUNT,
max_size: int | None = DEFAULT_MAX_SIZE,
access_token_factory: Callable[[], str] | None = None,
ssl: ssl.SSLContext | None = None,
Expand Down Expand Up @@ -81,6 +88,9 @@ def __init__(
self._connection_timeout = connection_timeout
self._max_size = max_size
self._access_token_factory = access_token_factory
self._retry_sleep = retry_sleep
self._retry_multiplier = retry_multiplier
self._retry_count = retry_count
self._ssl = ssl

self._state = ConnectionState.disconnected
Expand Down Expand Up @@ -121,9 +131,17 @@ async def run(self) -> None:
Runs the WebSocket transport, managing the connection lifecycle.
"""
while True:
with suppress(NegotiationTimeout):
try:
await self._loop()
await self._set_state(ConnectionState.disconnected)
except NegotiationFailure as e:
await self._set_state(ConnectionState.disconnected)
self._retry_count -= 1
if self._retry_count <= 0:
raise e
self._retry_sleep *= self._retry_multiplier
await asyncio.sleep(self._retry_sleep)
else:
await self._set_state(ConnectionState.disconnected)

async def send(self, message: Message) -> None:
"""
Expand Down

0 comments on commit 2637486

Please sign in to comment.