Skip to content

Commit

Permalink
fix asyncio bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
51bitquant committed May 13, 2023
1 parent 3c36163 commit fa74520
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 39 deletions.
2 changes: 1 addition & 1 deletion howtrader/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.3.5"
__version__ = "3.3.6"
34 changes: 28 additions & 6 deletions howtrader/api/rest/rest_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,23 @@
from types import TracebackType, coroutine
from threading import Thread
from asyncio import (
get_event_loop,
get_running_loop,
new_event_loop,
set_event_loop,
run_coroutine_threadsafe,
AbstractEventLoop,
Future
Future,
set_event_loop_policy
)

from json import loads

from aiohttp import ClientSession, ClientResponse

# 在Windows系统上必须使用Selector事件循环,否则可能导致程序崩溃
if sys.platform == 'win32': # if platform.system() == 'Windows':
from asyncio import WindowsSelectorEventLoopPolicy
set_event_loop_policy(WindowsSelectorEventLoopPolicy())

CALLBACK_TYPE = Callable[[Union[dict, list], "Request"], None]
ON_FAILED_TYPE = Callable[[int, "Request"], None]
Expand Down Expand Up @@ -113,10 +120,11 @@ class RestClient(object):
def __init__(self):
""""""
self.url_base: str = ""
self.proxy: str = ""
self.proxy: str = None

self.session: ClientSession = ClientSession(trust_env=True)
self.session: ClientSession = None
self.loop: AbstractEventLoop = None
self._active = False

def init(
self,
Expand All @@ -132,13 +140,21 @@ def init(

def start(self) -> None:
"""start event loop"""
if not self.loop:
self.loop = get_event_loop()
if self._active:
return None

self._active = True

try:
self.loop = get_running_loop()
except RuntimeError:
self.loop = new_event_loop()

start_event_loop(self.loop)

def stop(self) -> None:
"""stop event loop"""
self._active = False
if self.loop and self.loop.is_running():
self.loop.stop()

Expand Down Expand Up @@ -236,6 +252,12 @@ async def _get_response(self, request: Request) -> Response:
request = self.sign(request)
url = self._make_full_url(request.path)

if not self.session:
self.session = ClientSession(trust_env=True)

if self.session.closed:
self.session = ClientSession(trust_env=True)

cr: ClientResponse = await self.session.request(
request.method,
url,
Expand Down
52 changes: 29 additions & 23 deletions howtrader/api/websocket/websocket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from types import coroutine
from threading import Thread
from asyncio import (
get_event_loop,
get_running_loop,
new_event_loop,
set_event_loop,
run_coroutine_threadsafe,
AbstractEventLoop
AbstractEventLoop,
TimeoutError
)

from aiohttp import ClientSession, ClientWebSocketResponse
Expand All @@ -31,11 +33,12 @@ def __init__(self):
self._active: bool = False
self._host: str = ""

self._session: ClientSession = ClientSession()
self._session: ClientSession = None
self.receive_timeout = 5 * 60 # 5 minutes for receiving timeout
self._ws: ClientWebSocketResponse = None
self._loop: AbstractEventLoop = None

self._proxy: str = ""
self._proxy: str = None
self._ping_interval: int = 60 # ping interval for 60 seconds
self._header: dict = {}

Expand Down Expand Up @@ -69,20 +72,16 @@ def start(self):
will call the on_connected callback when connected
subscribe the data when call the on_connected callback
"""
try:
if self._ws:
coro = self._ws.close()
run_coroutine_threadsafe(coro, self._loop)
except Exception as error:
pass

if self._active:
return None

self._active = True

if not self._loop:
self._loop = get_event_loop()
try:
self._loop = get_running_loop()
except RuntimeError:
self._loop = new_event_loop()

start_event_loop(self._loop)

run_coroutine_threadsafe(self._run(), self._loop)
Expand All @@ -97,6 +96,10 @@ def stop(self):
coro = self._ws.close()
run_coroutine_threadsafe(coro, self._loop)

if self._session: # need to close the session.
coro1 = self._session.close()
run_coroutine_threadsafe(coro1, self._loop)

if self._loop and self._loop.is_running():
self._loop.stop()

Expand Down Expand Up @@ -151,9 +154,6 @@ def on_error(
except Exception:
traceback.print_exc()

def on_exit_loop(self):
self.start()

def exception_detail(
self,
exception_type: type,
Expand All @@ -174,16 +174,24 @@ def exception_detail(

async def _run(self):
"""
run on the asyncio
"""
while self._active:
# try catch error/exception
try:
# connect ws server
if not self._session:
self._session = ClientSession()

if self._session.closed:
self._session = ClientSession()

self._ws = await self._session.ws_connect(
self._host,
proxy=self._proxy,
verify_ssl=False
verify_ssl=False,
heartbeat=self._ping_interval, # send ping interval
receive_timeout=self.receive_timeout,
)

# call the on_connected function
Expand All @@ -203,13 +211,11 @@ async def _run(self):
# call the on_disconnected
self.on_disconnected()
# on exception
except TimeoutError:
pass
except Exception:
et, ev, tb = sys.exc_info()
self.on_error(et, ev, tb)
break

self._active = False
self.on_exit_loop()

def _record_last_sent_text(self, text: str):
"""record the last send text for debugging"""
Expand All @@ -220,7 +226,7 @@ def _record_last_received_text(self, text: str):
self._last_received_text = text[:1000]


def start_event_loop(loop: AbstractEventLoop) -> AbstractEventLoop:
def start_event_loop(loop: AbstractEventLoop) -> None:
"""start event loop"""
# if the event loop is not running, then create the thread to run
if not loop.is_running():
Expand Down
4 changes: 1 addition & 3 deletions howtrader/gateway/binance/binance_inverse_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -1096,9 +1096,6 @@ def on_packet(self, packet: dict) -> None:
elif packet["e"] == "ORDER_TRADE_UPDATE":
self.on_order(packet)

def on_exit_loop(self):
self.gateway.rest_api.start_user_stream()

def on_account(self, packet: dict) -> None:
"""account data update"""
for acc_data in packet["a"]["B"]:
Expand Down Expand Up @@ -1180,6 +1177,7 @@ def __init__(self, gateway: BinanceInverseGateway) -> None:

self.ticks: Dict[str, TickData] = {}
self.reqid: int = 0
self.receive_timeout = 60

def connect(
self,
Expand Down
4 changes: 1 addition & 3 deletions howtrader/gateway/binance/binance_spot_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,9 +972,6 @@ def on_packet(self, packet: dict) -> None:
elif packet["e"] == "executionReport":
self.on_order(packet)

def on_exit_loop(self):
self.gateway.rest_api.start_user_stream()

def on_account(self, packet: dict) -> None:
"""account data update"""
for d in packet["B"]:
Expand Down Expand Up @@ -1032,6 +1029,7 @@ def __init__(self, gateway: BinanceSpotGateway) -> None:

self.ticks: Dict[str, TickData] = {}
self.reqid: int = 0
self.receive_timeout = 60

def connect(self, proxy_host: str, proxy_port: int):
"""connect market data ws"""
Expand Down
4 changes: 1 addition & 3 deletions howtrader/gateway/binance/binance_usdt_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,9 +1112,6 @@ def on_packet(self, packet: dict) -> None:
elif packet["e"] == "ORDER_TRADE_UPDATE":
self.on_order(packet)

def on_exit_loop(self):
self.gateway.rest_api.start_user_stream()

def on_account(self, packet: dict) -> None:
"""account data update"""
for acc_data in packet["a"]["B"]:
Expand Down Expand Up @@ -1196,6 +1193,7 @@ def __init__(self, gateway: BinanceUsdtGateway) -> None:

self.ticks: Dict[str, TickData] = {}
self.reqid: int = 0
self.receive_timeout = 60 # 1minute for receiving data timeout.

def connect(
self,
Expand Down

0 comments on commit fa74520

Please sign in to comment.