Skip to content

Commit

Permalink
fix: wechatmp's deadloop when reply is None from @JS00000 #789
Browse files Browse the repository at this point in the history
  • Loading branch information
lanvent committed Apr 9, 2023
1 parent f1e8344 commit fcfafb0
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 21 deletions.
5 changes: 5 additions & 0 deletions channel/chat_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ def _send(self, reply: Reply, context: Context, retry_cnt = 0):
time.sleep(3+3*retry_cnt)
self._send(reply, context, retry_cnt+1)

def _success_callback(self, session_id, **kwargs):# 线程正常结束时的回调函数
logger.debug("Worker return success, session_id = {}".format(session_id))

def _fail_callback(self, session_id, exception, **kwargs): # 线程异常结束时的回调函数
logger.exception("Worker return exception: {}".format(exception))

Expand All @@ -242,6 +245,8 @@ def func(worker:Future):
worker_exception = worker.exception()
if worker_exception:
self._fail_callback(session_id, exception = worker_exception, **kwargs)
else:
self._success_callback(session_id, **kwargs)
except CancelledError as e:
logger.info("Worker cancelled, session_id = {}".format(session_id))
except Exception as e:
Expand Down
8 changes: 4 additions & 4 deletions channel/wechatmp/ServiceAccount.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def GET(self):

def POST(self):
# Make sure to return the instance that first created, @singleton will do that.
channel_instance = WechatMPChannel()
channel = WechatMPChannel()
try:
webData = web.data()
# logger.debug("[wechatmp] Receive request:\n" + webData.decode("utf-8"))
Expand All @@ -27,14 +27,14 @@ def POST(self):
message_id = wechatmp_msg.msg_id

logger.info("[wechatmp] {}:{} Receive post query {} {}: {}".format(web.ctx.env.get('REMOTE_ADDR'), web.ctx.env.get('REMOTE_PORT'), from_user, message_id, message))
context = channel_instance._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg)
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg)
if context:
# set private openai_api_key
# if from_user is not changed in itchat, this can be placed at chat_channel
user_data = conf().get_user_data(from_user)
context['openai_api_key'] = user_data.get('openai_api_key') # None or user openai_api_key
channel_instance.produce(context)
# The reply will be sent by channel_instance.send() in another thread
channel.produce(context)
# The reply will be sent by channel.send() in another thread
return "success"

elif wechatmp_msg.msg_type == 'event':
Expand Down
29 changes: 18 additions & 11 deletions channel/wechatmp/SubscribeAccount.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def POST(self):
context = channel._compose_context(ContextType.TEXT, message, isgroup=False, msg=wechatmp_msg)
logger.debug("[wechatmp] context: {} {}".format(context, wechatmp_msg))
if message_id in channel.received_msgs: # received and finished
return
# no return because of bandwords or other reasons
return "success"
if supported and context:
# set private openai_api_key
# if from_user is not changed in itchat, this can be placed at chat_channel
Expand Down Expand Up @@ -71,11 +72,12 @@ def POST(self):
channel.query1[cache_key] = False
channel.query2[cache_key] = False
channel.query3[cache_key] = False
# Request again
# User request again, and the answer is not ready
elif cache_key in channel.running and channel.query1.get(cache_key) == True and channel.query2.get(cache_key) == True and channel.query3.get(cache_key) == True:
channel.query1[cache_key] = False #To improve waiting experience, this can be set to True.
channel.query2[cache_key] = False #To improve waiting experience, this can be set to True.
channel.query3[cache_key] = False
# User request again, and the answer is ready
elif cache_key in channel.cache_dict:
# Skip the waiting phase
channel.query1[cache_key] = True
Expand All @@ -89,7 +91,7 @@ def POST(self):
logger.debug("[wechatmp] query1 {}".format(cache_key))
channel.query1[cache_key] = True
cnt = 0
while cache_key not in channel.cache_dict and cnt < 45:
while cache_key in channel.running and cnt < 45:
cnt = cnt + 1
time.sleep(0.1)
if cnt == 45:
Expand All @@ -104,7 +106,7 @@ def POST(self):
logger.debug("[wechatmp] query2 {}".format(cache_key))
channel.query2[cache_key] = True
cnt = 0
while cache_key not in channel.cache_dict and cnt < 45:
while cache_key in channel.running and cnt < 45:
cnt = cnt + 1
time.sleep(0.1)
if cnt == 45:
Expand All @@ -119,7 +121,7 @@ def POST(self):
logger.debug("[wechatmp] query3 {}".format(cache_key))
channel.query3[cache_key] = True
cnt = 0
while cache_key not in channel.cache_dict and cnt < 40:
while cache_key in channel.running and cnt < 40:
cnt = cnt + 1
time.sleep(0.1)
if cnt == 40:
Expand All @@ -132,12 +134,17 @@ def POST(self):
else:
pass

if float(time.time()) - float(query_time) > 4.8:
reply_text = "【正在思考中,回复任意文字尝试获取回复】"
logger.info("[wechatmp] Timeout for {} {}, return".format(from_user, message_id))
replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
return replyPost


if cache_key not in channel.cache_dict and cache_key not in channel.running:
# no return because of bandwords or other reasons
return "success"

# if float(time.time()) - float(query_time) > 4.8:
# reply_text = "【正在思考中,回复任意文字尝试获取回复】"
# logger.info("[wechatmp] Timeout for {} {}, return".format(from_user, message_id))
# replyPost = reply.TextMsg(from_user, to_user, reply_text).send()
# return replyPost

if cache_key in channel.cache_dict:
content = channel.cache_dict[cache_key]
if len(content.encode('utf8'))<=MAX_UTF8_LEN:
Expand Down
16 changes: 10 additions & 6 deletions channel/wechatmp/wechatmp_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ def send(self, reply: Reply, context: Context):
if self.passive_reply:
receiver = context["receiver"]
self.cache_dict[receiver] = reply.content
self.running.remove(receiver)
logger.debug("[send] reply to {} saved to cache: {}".format(receiver, reply))
logger.info("[send] reply to {} saved to cache: {}".format(receiver, reply))
else:
receiver = context["receiver"]
reply_text = reply.content
Expand All @@ -116,10 +115,15 @@ def send(self, reply: Reply, context: Context):
return


def _fail_callback(self, session_id, exception, context, **kwargs):
logger.exception("[wechatmp] Fail to generation message to user, msgId={}, exception={}".format(context['msg'].msg_id, exception))
assert session_id not in self.cache_dict
self.running.remove(session_id)
def _success_callback(self, session_id, context, **kwargs): # 线程异常结束时的回调函数
logger.debug("[wechatmp] Success to generate reply, msgId={}".format(context['msg'].msg_id))
if self.passive_reply:
self.running.remove(session_id)


def _fail_callback(self, session_id, exception, context, **kwargs): # 线程异常结束时的回调函数
logger.exception("[wechatmp] Fail to generate reply to user, msgId={}, exception={}".format(context['msg'].msg_id, exception))
if self.passive_reply:
assert session_id not in self.cache_dict
self.running.remove(session_id)

0 comments on commit fcfafb0

Please sign in to comment.