From aaf1341af7b60c014df37240cc49caec86347397 Mon Sep 17 00:00:00 2001 From: BlinkDL Date: Tue, 10 Jan 2023 12:14:47 +0000 Subject: [PATCH] better chat --- RWKV-v4neo/chat.py | 58 +++++++++++++++++++++++++++------------------- 1 file changed, 34 insertions(+), 24 deletions(-) diff --git a/RWKV-v4neo/chat.py b/RWKV-v4neo/chat.py index f3a833c3..b2aad520 100644 --- a/RWKV-v4neo/chat.py +++ b/RWKV-v4neo/chat.py @@ -55,16 +55,19 @@ bot = "Bot" interface = ":" +# The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite. +# The following is a conversation between a highly knowledgeable and intelligent AI called {bot}, and a human called {user}. In the following interactions, {user} and {bot} converse in natural language, and {bot} do its best to answer {user}'s questions. {bot} is respectful, polite and inclusive. {bot} knows a lot, and always tells the truth. + init_prompt = f''' -The following is a verbose and detailed conversation between a highly knowledgeable and intelligent AI assistant called {bot}, and a human user called {user}. {bot} always tells the truth and facts. {bot} is polite and humorous. The conversation begins. +The following is a verbose and detailed conversation between an AI assistant called {bot}, and a human user called {user}. {bot} is intelligent, knowledgeable, wise and polite. {user}{interface} french revolution what year -{bot}{interface} The French revolution started in 1789, and lasted 10 years until 1799. +{bot}{interface} The French Revolution started in 1789, and lasted 10 years until 1799. {user}{interface} 3+5=? -{bot}{interface} 3 + 5 = 8, so the answer is 8. +{bot}{interface} The answer is 8. {user}{interface} guess i marry who ? @@ -76,7 +79,7 @@ {user}{interface} wat is lhc -{bot}{interface} LHC is a large and very expensive piece of science equipment. It’s a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012. +{bot}{interface} LHC is a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012. ''' @@ -154,7 +157,7 @@ def on_message(message): return x_temp = 1.0 - x_top_p = 0.8 + x_top_p = 0.85 if ("-temp=" in msg): x_temp = float(msg.split("-temp=")[1].split(" ")[0]) msg = msg.replace("-temp="+f'{x_temp:g}', "") @@ -170,25 +173,25 @@ def on_message(message): if x_top_p <= 0: x_top_p = 0 - if msg == '+reset_rwkv' or msg == '+rwkv_reset': + if msg == '+reset': out = load_all_stat('', 'chat_init') save_all_stat(srv, 'chat', out) reply_msg("Chat reset.") return - elif msg[:10] == '+rwkv_gen ' or msg[:9] == '+rwkv_qa ' or msg == '+rwkv_more' or msg == '+rwkv_retry' or msg == '+rwkv_again': + elif msg[:5] == '+gen ' or msg[:4] == '+qa ' or msg == '+more' or msg == '+retry': - if msg[:10] == '+rwkv_gen ': - new = '\n' + msg[10:].strip() + if msg[:5] == '+gen ': + new = '\n' + msg[5:].strip() # print(f'### prompt ###\n[{new}]') current_state = None out = run_rnn(tokenizer.tokenizer.encode(new)) save_all_stat(srv, 'gen_0', out) - elif msg[:9] == '+rwkv_qa ': + elif msg[:4] == '+qa ': out = load_all_stat('', 'chat_init') - real_msg = msg[9:].strip() + real_msg = msg[4:].strip() new = f"{user}{interface} {real_msg}\n\n{bot}{interface}" # print(f'### qa ###\n[{new}]') @@ -201,14 +204,14 @@ def on_message(message): # out = run_rnn(tokenizer.tokenizer.encode(new)) # save_all_stat(srv, 'gen_0', out) - elif msg == '+rwkv_more': + elif msg == '+more': try: out = load_all_stat(srv, 'gen_1') save_all_stat(srv, 'gen_0', out) except: return - elif msg == '+rwkv_retry' or msg == '+rwkv_again': + elif msg == '+retry': try: out = load_all_stat(srv, 'gen_0') except: @@ -224,9 +227,9 @@ def on_message(message): top_p_usual=x_top_p, top_p_newline=x_top_p, ) - if msg[:9] == '+rwkv_qa ': - out = run_rnn([token], newline_adj=-2) - else: + if msg[:4] == '+qa ': + out = run_rnn([token], newline_adj=-1) + else: out = run_rnn([token]) send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip() # print(f'### send ###\n[{send_msg}]') @@ -234,7 +237,7 @@ def on_message(message): save_all_stat(srv, 'gen_1', out) else: - if msg == '+rwkv_alt': + if msg == '+alt': try: out = load_all_stat(srv, 'chat_pre') except: @@ -269,22 +272,29 @@ def on_message(message): out = run_rnn([token], newline_adj=newline_adj) if tokenizer.tokenizer.decode(model_tokens[-10:]).endswith(f'\n\n'): break + # tail = tokenizer.tokenizer.decode(model_tokens[-10:]).strip() + # if tail.endswith(f'User:') or tail.endswith(f'Bot:'): + # break send_msg = tokenizer.tokenizer.decode(model_tokens[begin:]).strip() + # if send_msg.endswith(f'User:'): + # send_msg = send_msg[:-5].strip() + # if send_msg.endswith(f'Bot:'): + # send_msg = send_msg[:-4].strip() # print(f'### send ###\n[{send_msg}]') reply_msg(send_msg) save_all_stat(srv, 'chat', out) print('''Commands: -+rwkv_alt --> alternate chat reply -+rwkv_reset --> reset chat ++alt --> alternate chat reply ++reset --> reset chat -+rwkv_gen YOUR PROMPT --> free generation with your prompt -+rwkv_qa YOUR QUESTION --> free generation - ask any question and get answer (just ask the question) -+rwkv_more --> continue last free generation [does not work for chat] -+rwkv_retry --> retry last free generation ++gen YOUR PROMPT --> free generation with your prompt ++qa YOUR QUESTION --> free generation - ask any question and get answer (just ask the question) ++more --> continue last free generation [does not work for chat] ++retry --> retry last free generation -Now talk with the bot and enjoy. Remember to +rwkv_reset periodically to clean up the bot's memory. Use RWKV-4 14B for best results. +Now talk with the bot and enjoy. Remember to +reset periodically to clean up the bot's memory. Use RWKV-4 14B for best results. This is not instruct-tuned for conversation yet, so don't expect good quality. Better use +rwkv_gen for free generation. ''')