Skip to content

Commit

Permalink
adding ability to stop the websocket service
Browse files Browse the repository at this point in the history
  • Loading branch information
camelpac committed Mar 16, 2021
1 parent 3734b28 commit d214cbf
Showing 1 changed file with 75 additions and 40 deletions.
115 changes: 75 additions & 40 deletions alpaca_trade_api/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
import msgpack
import os
import re

import concurrent
import websockets
import queue

from .common import get_base_url, get_data_stream_url, get_credentials, URL
from .entity import Entity
Expand Down Expand Up @@ -43,6 +44,7 @@ def __init__(self,
self._ws = None
self._running = False
self._raw_data = raw_data
self._stop_stream_queue = queue.Queue()

async def _connect(self):
self._ws = await websockets.connect(
Expand Down Expand Up @@ -183,10 +185,15 @@ async def _start_ws(self):

async def _consume(self):
while True:
r = await self._ws.recv()
msgs = msgpack.unpackb(r)
for msg in msgs:
await self._dispatch(msg)
if not self._stop_stream_queue.empty():
self._stop_stream_queue.get()
await self.close()
break
else:
r = await self._ws.recv()
msgs = msgpack.unpackb(r)
for msg in msgs:
await self._dispatch(msg)

async def _run_forever(self):
# do not start the websocket connection until we subscribe to something
Expand All @@ -195,46 +202,53 @@ async def _run_forever(self):
await asyncio.sleep(0.1)
log.info('started data stream')
retries = 0
self._running = False
while True:
self._running = False
try:
await self._start_ws()
self._running = True
retries = 0
await self._consume()
if not self._running:
await self._start_ws()
self._running = True
retries = 0
await self._consume()
except websockets.WebSocketException as wse:
retries += 1
if retries > int(os.environ.get('APCA_RETRY_MAX', 3)):
await self.close()
self._running = False
raise ConnectionError("max retries exceeded")
if retries > 1:
await asyncio.sleep(
int(os.environ.get('APCA_RETRY_WAIT', 3)))
logging.warn('websocket error, restarting connection: ' +
str(wse))
finally:
await self.close()

def run(self):
try:
asyncio.get_event_loop().run_until_complete(self._run_forever())
except KeyboardInterrupt:
log.info('exited')
if not self._running:
break
await asyncio.sleep(0.01)

async def close(self):
if self._ws:
await self._ws.close()
self._ws = None
self._running = False

async def stop_ws(self):
self._stop_stream_queue.put_nowait({"should_stop": True})


class TradingStream:
def __init__(self, key_id: str, secret_key: str, base_url: URL):
def __init__(self,
key_id: str,
secret_key: str,
base_url: URL):
self._key_id = key_id
self._secret_key = secret_key
base_url = re.sub(r'^http', 'ws', base_url)
self._endpoint = base_url + '/stream/'
self._trade_updates_handler = None
self._ws = None
self._running = False
self._stop_stream_queue = queue.Queue()

async def _connect(self):
self._ws = await websockets.connect(self._endpoint)
Expand Down Expand Up @@ -284,49 +298,53 @@ async def _start_ws(self):

async def _consume(self):
while True:
r = await self._ws.recv()
msg = json.loads(r)
await self._dispatch(msg)
if not self._stop_stream_queue.empty():
self._stop_stream_queue.get()
await self.close()
break
else:
r = await self._ws.recv()
msg = json.loads(r)
await self._dispatch(msg)

async def _run_forever(self):
# do not start the websocket connection until we subscribe to something
while not self._trade_updates_handler:
await asyncio.sleep(0.1)
log.info('started trading stream')
retries = 0
self._running = False
while True:
self._running = False
try:
await self._start_ws()
self._running = True
retries = 0
await self._consume()
except asyncio.CancelledError:
log.info('cancelled, closing trading stream connection')
return
if not self._running:
await self._start_ws()
self._running = True
retries = 0
await self._consume()
except websockets.WebSocketException as wse:
retries += 1
if retries > int(os.environ.get('APCA_RETRY_MAX', 3)):
await self.close()
self._running = False
raise ConnectionError("max retries exceeded")
if retries > 1:
await asyncio.sleep(
int(os.environ.get('APCA_RETRY_WAIT', 3)))
logging.warn('websocket error, restarting connection: ' +
str(wse))
finally:
await self.close()

def run(self):
try:
asyncio.get_event_loop().run_until_complete(self._run_forever())
except KeyboardInterrupt:
log.info('exited')
if not self._running:
break
await asyncio.sleep(0.01)

async def close(self):
if self._ws:
await self._ws.close()
self._ws = None
self._running = False

async def stop_ws(self):
self._stop_stream_queue.put_nowait({"should_stop": True})

class Stream:
def __init__(self,
Expand All @@ -339,12 +357,16 @@ def __init__(self,
self._key_id, self._secret_key, _ = get_credentials(key_id, secret_key)
self._base_url = base_url or get_base_url()
self._data_steam_url = data_stream_url or get_data_stream_url()
print(self._data_steam_url)

self._trading_ws = TradingStream(self._key_id, self._secret_key,
self._trading_ws = TradingStream(self._key_id,
self._secret_key,
self._base_url)
self._data_ws = DataStream(self._key_id, self._secret_key,
self._data_steam_url, raw_data, data_feed)
self._data_ws = DataStream(self._key_id,
self._secret_key,
self._data_steam_url,
raw_data,
data_feed)


def subscribe_trade_updates(self, handler):
self._trading_ws.subscribe_trade_updates(handler)
Expand Down Expand Up @@ -403,3 +425,16 @@ def run(self):
except KeyboardInterrupt:
print('keyboard interrupt, bye')
pass

async def stop_ws(self):
"""
Signal the ws connections to stop listenning to api stream.
"""
if self._trading_ws:
logging.info("Stopping the trading websocket connection")
await self._trading_ws.stop_ws()

if self._data_ws:
logging.info("Stopping the data websocket connection")
await self._data_ws.stop_ws()

0 comments on commit d214cbf

Please sign in to comment.