From 0ba8602094c8bb733157ede6acc5aeee4befb2d5 Mon Sep 17 00:00:00 2001 From: MxEmerson <2382413024@qq.com> Date: Thu, 7 Dec 2023 19:40:51 +0800 Subject: [PATCH 1/3] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84chat=20model?= =?UTF-8?q?=EF=BC=8C=E5=88=A0=E9=99=A4=E6=9C=AA=E4=BD=BF=E7=94=A8=E4=BB=A3?= =?UTF-8?q?=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/chat/__init__.py | 24 ++++-- src/plugins/chat/model.py | 138 ++++++++++++++++------------------- src/plugins/greeting/wiki.py | 33 +-------- 3 files changed, 80 insertions(+), 115 deletions(-) diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py index f60e293f..2803a2fb 100644 --- a/src/plugins/chat/__init__.py +++ b/src/plugins/chat/__init__.py @@ -4,26 +4,37 @@ from nonebot.rule import Rule from nonebot.typing import T_State from nonebot import on_message, get_driver, logger -import random -from .model import chat, del_session -from src.common.config import BotConfig, GroupConfig +from src.common.config import BotConfig, GroupConfig, plugin_config + try: from src.common.utils.speech.text_to_speech import text_2_speech TTS_AVAIABLE = True except Exception as error: - print('TTS not available, error:', error) + logger.error('TTS not available, error: ', error) TTS_AVAIABLE = False +try: + from .model import Chat +except Exception as error: + logger.error('Chat model import error: ', error) + raise error + TTS_MIN_LENGTH = 10 +try: + chat = Chat(plugin_config.chat_strategy) +except Exception as error: + logger.error('Chat model init error: ', error) + raise error + @BotConfig.handle_sober_up def on_sober_up(bot_id, group_id, drunkenness) -> bool: session = f'{bot_id}_{group_id}' logger.info( f'bot [{bot_id}] sober up in group [{group_id}], clear session [{session}]') - del_session(session) + chat.del_session(session) def is_drunk(bot: Bot, event: Event, state: T_State) -> bool: @@ -60,8 +71,7 @@ async def _(bot: Bot, event: GroupMessageEvent, state: T_State): text = text[:50] if not text: return - - ans = await asyncify(chat)(session, text) + ans = await asyncify(chat.chat)(session, text) logger.info(f'session [{session}]: {text} -> {ans}') if TTS_AVAIABLE and len(ans) >= TTS_MIN_LENGTH: diff --git a/src/plugins/chat/model.py b/src/plugins/chat/model.py index 092c1bab..2c94432d 100644 --- a/src/plugins/chat/model.py +++ b/src/plugins/chat/model.py @@ -10,88 +10,74 @@ # 这个要配个 ninja 啥的环境,能大幅提高推理速度,有需要可以自己弄下(仅支持 cuda 显卡) os.environ["RWKV_CUDA_ON"] = '0' - -from rwkv.model import RWKV # pip install rwkv -from .pipeline import PIPELINE, PIPELINE_ARGS +from rwkv.model import RWKV from .prompt import INIT_PROMPT, CHAT_FORMAT -from src.common.config import plugin_config - -# 这个可以照着原仓库的说明改一改,能省点显存啥的 -STRATEGY = 'cuda fp16' if cuda else 'cpu fp32' -if plugin_config.chat_strategy: - STRATEGY = plugin_config.chat_strategy - -MODEL_DIR = Path('resource/chat/models') -MODEL_EXT = '.pth' -MODEL_PATH = None -for f in MODEL_DIR.glob('*'): - if f.suffix != MODEL_EXT: - continue - MODEL_PATH = f.with_suffix('') - break - -print('Chat model:', MODEL_PATH) - -if not MODEL_PATH: - print(f'!!!!!!Chat model not found, please put it in {MODEL_DIR}!!!!!!') - print(f'!!!!!!Chat 模型不存在,请放到 {MODEL_DIR} 文件夹下!!!!!!') - raise Exception('Chat model not found') - -TOKEN_PATH = MODEL_DIR / '20B_tokenizer.json' - -if not TOKEN_PATH.exists(): - print( - f'AI Chat updated, please put token file to {TOKEN_PATH}, download: https://github.com/BlinkDL/ChatRWKV/blob/main/20B_tokenizer.json') - print( - f'牛牛的 AI Chat 版本更新了,把 token 文件放到 {TOKEN_PATH} 里再启动, 下载地址:https://github.com/BlinkDL/ChatRWKV/blob/main/20B_tokenizer.json') - raise Exception('Chat token not found') - -model = RWKV(model=str(MODEL_PATH), strategy=STRATEGY) -pipeline = PIPELINE(model, str(TOKEN_PATH)) -args = PIPELINE_ARGS( - temperature=1.0, - top_p=0.7, - alpha_frequency=0.25, - alpha_presence=0.25, - token_ban=[0], # ban the generation of some tokens - token_stop=[], # stop generation whenever you see any token here - ends=('\n'), - ends_if_too_long=("。", "!", "?", "\n")) - - -INIT_STATE = deepcopy(pipeline.generate( - INIT_PROMPT, token_count=200, args=args)[1]) -all_state = defaultdict(lambda: deepcopy(INIT_STATE)) -all_occurrence = {} - -chat_locker = Lock() - - -def chat(session: str, text: str, token_count: int = 50) -> str: - with chat_locker: - state = all_state[session] - ctx = CHAT_FORMAT.format(text) - occurrence = all_occurrence.get(session, {}) - - out, state, occurrence = pipeline.generate( - ctx, token_count=token_count, args=args, state=state, occurrence=occurrence) - - all_state[session] = deepcopy(state) - all_occurrence[session] = occurrence - return out.strip() - +from .pipeline import PIPELINE, PIPELINE_ARGS -def del_session(session: str): - with chat_locker: - if session in all_state: - del all_state[session] - if session in all_occurrence: - del all_occurrence[session] +DEFAULT_STRATEGY = 'cuda fp16' if cuda else 'cpu fp32' +DEFAULT_MODEL_DIR = Path('resource/chat/models') + + +class Chat: + def __init__(self, strategy=DEFAULT_STRATEGY, model_dir=DEFAULT_MODEL_DIR) -> None: + self.STRATEGY = strategy if strategy else DEFAULT_STRATEGY + self.MODEL_DIR = model_dir + self.MODEL_EXT = '.pth' + self.MODEL_PATH = None + self.TOKEN_PATH = self.MODEL_DIR / '20B_tokenizer.json' + for f in self.MODEL_DIR.glob('*'): + if f.suffix != self.MODEL_EXT: + continue + self.MODEL_PATH = f.with_suffix('') + break + if not self.MODEL_PATH: + raise Exception(f'Chat model not found in {self.MODEL_DIR}') + if not self.TOKEN_PATH.exists(): + raise Exception(f'Chat token not found in {self.TOKEN_PATH}') + model = RWKV(model=str(self.MODEL_PATH), strategy=self.STRATEGY) + self.pipeline = PIPELINE(model, str(self.TOKEN_PATH)) + self.args = PIPELINE_ARGS( + temperature=1.0, + top_p=0.7, + alpha_frequency=0.25, + alpha_presence=0.25, + token_ban=[0], # ban the generation of some tokens + token_stop=[], # stop generation whenever you see any token here + ends=('\n'), + ends_if_too_long=("。", "!", "?", "\n")) + + INIT_STATE = deepcopy(self.pipeline.generate( + INIT_PROMPT, token_count=200, args=self.args)[1]) + self.all_state = defaultdict(lambda: deepcopy(INIT_STATE)) + self.all_occurrence = {} + + self.chat_locker = Lock() + + def chat(self, session: str, text: str, token_count: int = 50) -> str: + with self.chat_locker: + state = self.all_state[session] + ctx = CHAT_FORMAT.format(text) + occurrence = self.all_occurrence.get(session, {}) + + out, state, occurrence = self.pipeline.generate( + ctx, token_count=token_count, args=self.args, state=state, occurrence=occurrence) + + self.all_state[session] = deepcopy(state) + self.all_occurrence[session] = occurrence + return out.strip() + + def del_session(self, session: str): + with self.chat_locker: + if session in self.all_state: + del self.all_state[session] + if session in self.all_occurrence: + del self.all_occurrence[session] if __name__ == "__main__": + chat = Chat('cpu fp32') while True: session = "main" text = input('text:') - result = chat(session, text) + result = chat.chat(session, text) print(result) diff --git a/src/plugins/greeting/wiki.py b/src/plugins/greeting/wiki.py index 6e6ecd6a..df4a5df9 100644 --- a/src/plugins/greeting/wiki.py +++ b/src/plugins/greeting/wiki.py @@ -1,6 +1,5 @@ import os import random -from src.common.utils.download_tools import DownloadTools # 这里的值是CN不代表是中文语音,wiki的定义有点怪,所有语言都叫CN_xx # 实际的url类似 'https://static.prts.wiki/voice_jp/char_485_pallas/CN_01.wav' @@ -30,29 +29,7 @@ voices_source = 'resource/voices' -class WikiVoice(DownloadTools): - def download_voice_from_wiki(self, operator, url, filename): - folder = f'{voices_source}/{operator}' - f = f'{folder}/{filename}' - if os.path.exists(f): - return - - print('Downloading', url, "as", filename, "to", folder) - content = self.request_file(url) - if content: - os.makedirs(folder, exist_ok=True) - with open(f, mode='wb+') as voice: - voice.write(content) - else: - print("Download failed!") - - def download_voices(self, folder, oper_id): - base_url = f'https://static.prts.wiki/voice/{oper_id}/' - for key, web_file in voice_dict.items(): - url = f'{base_url}{web_file}.wav' - filename = f'{key}.wav' - self.download_voice_from_wiki(folder, url, filename) - +class WikiVoice(): def get_voice_filename(self, operator, key): if key not in voice_dict: return None @@ -65,11 +42,3 @@ def get_voice_filename(self, operator, key): def get_random_voice(self, operator, ranges): key = random.choice([r for r in ranges if r in voice_dict]) return self.get_voice_filename(operator, key) - - -if __name__ == '__main__': - operator = 'Pallas' - wiki = WikiVoice() - wiki.download_voices('Pallas', 'char_485_pallas') - - print(wiki.get_random_voice(operator)) From 6c343158a5e487679a10214b8b36ce921cf2e489 Mon Sep 17 00:00:00 2001 From: MxEmerson <2382413024@qq.com> Date: Sat, 7 Sep 2024 21:26:01 +0800 Subject: [PATCH 2/3] =?UTF-8?q?feat:=20=E7=89=9B=E7=89=9B=E5=A4=BA?= =?UTF-8?q?=E8=88=8Dplus?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 顺便改了类型返回兼容 BREAKING CHANGE: 移除了默认启动的nonebot_go_cqhttp插件 Closes #113 --- pyproject.toml | 2 +- src/common/utils/array2cqcode/__init__.py | 4 +- src/plugins/take_name/__init__.py | 52 ++++++++++++++++++++++- 3 files changed, 53 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index dd3bec12..3af63507 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ nonebot-plugin-gocqhttp = "^0.6.4" nb-cli = "^1.2.0" [tool.nonebot] -plugins = ["nonebot_plugin_apscheduler", "nonebot_plugin_gocqhttp"] +plugins = ["nonebot_plugin_apscheduler"] plugin_dirs = ["src/plugins"] [build-system] diff --git a/src/common/utils/array2cqcode/__init__.py b/src/common/utils/array2cqcode/__init__.py index 390cf1e4..89e013a2 100644 --- a/src/common/utils/array2cqcode/__init__.py +++ b/src/common/utils/array2cqcode/__init__.py @@ -1,10 +1,10 @@ import json from .message_segment import BaseMessageSegment -from typing import Any +from typing import Union, Any -def try_convert_to_cqcode(data: Any) -> str | Any: +def try_convert_to_cqcode(data: Any) -> Union[str, Any]: try: msg = json.loads(data) if not isinstance(msg, list): diff --git a/src/plugins/take_name/__init__.py b/src/plugins/take_name/__init__.py index 2c7f52b5..4075a506 100644 --- a/src/plugins/take_name/__init__.py +++ b/src/plugins/take_name/__init__.py @@ -1,7 +1,10 @@ import random -from nonebot import require, logger, get_bot +from nonebot import require, logger, get_bot, on_notice +from nonebot.rule import Rule +from nonebot.typing import T_State +from nonebot.adapters import Bot, Event from nonebot.exception import ActionFailed -from nonebot.adapters.onebot.v11 import Message +from nonebot.adapters.onebot.v11 import Message, NoticeEvent from src.plugins.repeater.model import Chat from src.common.config import BotConfig @@ -77,3 +80,48 @@ async def change_name(): except ActionFailed: # 可能牛牛退群了 continue + + +async def is_change_name_notice(bot: Bot, event: NoticeEvent, state: T_State) -> bool: + config = BotConfig(event.self_id, event.group_id) + if event.notice_type == 'group_card' and event.user_id == config.taken_name(): + return True + return False + + +watch_name = on_notice(rule=Rule(is_change_name_notice), priority=4) + + +@watch_name.handle() +async def watch_name_handle(bot: Bot, event: NoticeEvent, state: T_State): + group_id = event.group_id + user_id = event.user_id + + try: + info = await bot.call_api('get_group_member_info', **{ + 'group_id': group_id, + 'user_id': user_id, + 'no_cache': True + }) + except ActionFailed: + return + + card = info['card'] if info['card'] else info['nickname'] + + logger.info( + 'bot [{}] watch name change by [{}] in group [{}]'.format( + bot.self_id, user_id, group_id)) + + config = BotConfig(bot.self_id, group_id) + + try: + await bot.call_api('set_group_card', **{ + 'group_id': group_id, + 'user_id': user_id, + 'card': card + }) + + config.update_taken_name(user_id) + + except ActionFailed: + return From 45c90f4cf5707a611847361c9a1447aed01bf676 Mon Sep 17 00:00:00 2001 From: MxEmerson <2382413024@qq.com> Date: Sat, 7 Sep 2024 21:34:00 +0800 Subject: [PATCH 3/3] =?UTF-8?q?ci:=20=E4=BC=98=E5=8C=96ci?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a5e59b4f..cb73f3a1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -26,7 +26,7 @@ jobs: id: metadata with: images: | - misteo/pallas-bot + ${{ github.repository_owner }}/pallas-bot tags: | type=raw,value=latest