Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
TheExplainthis committed Mar 20, 2023
1 parent 4f04c88 commit 30fecb8
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 135 deletions.
155 changes: 68 additions & 87 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from src.memory import Memory
from src.logger import logger
from src.storage import Storage
from src.utils import get_role_and_content

load_dotenv('.env')

Expand Down Expand Up @@ -45,71 +46,57 @@ def callback():
@handler.add(MessageEvent, message=TextMessage)
def handle_text_message(event):
user_id = event.source.user_id
text = event.message.text
text = event.message.text.strip()
logger.info(f'{user_id}: {text}')
if text.startswith('/註冊'):
if len(text.split(' ')) == 2:
try:
_, api_key = text.split(' ')
model = OpenAIModel(api_key=api_key)
sucessful = model.check_token_valid()
if not sucessful:
msg = TextSendMessage(text='Token 無效,請重新註冊,注意格式有空格,格式為 /註冊 sk-xxxxx')
else:
model_management[user_id] = model
api_keys[user_id] = api_key
storage.save(api_keys)
msg = TextSendMessage(text='Token 有效,註冊成功')
except Exception:
msg = TextSendMessage(text='Token 無效,請重新註冊,格式為 /註冊 sk-xxxxx')
else:
msg = TextSendMessage(text='Token 無效,請重新註冊,注意格式有空格,格式為 /註冊 sk-xxxxx')
elif text.startswith('/指令說明'):
msg = TextSendMessage(text="指令:\n/註冊 + API Token\n👉 API Token 請先到 https://platform.openai.com/ 註冊登入後取得\n\n/系統訊息 + Prompt\n👉 Prompt 可以命令機器人扮演某個角色,例如:請你扮演擅長做總結的人\n\n/清除\n👉 當前每一次都會紀錄最後兩筆歷史紀錄,這個指令能夠清除歷史訊息\n\n/圖像 + Prompt\n👉 會調用 DALL∙E 2 Model,以文字生成圖像\n\n語音輸入\n👉 會調用 Whisper 模型,先將語音轉換成文字,再調用 ChatGPT 以文字回覆\n\n其他文字輸入\n👉 調用 ChatGPT 以文字回覆")
elif text.startswith('/系統訊息'):
system_message = text[5:]
memory.change_system_message(user_id, system_message)
msg = TextSendMessage(text='輸入成功')

elif text.startswith('/清除'):
memory.remove(user_id)
msg = TextSendMessage(text='歷史訊息清除成功')

else:
if not model_management.get(user_id):
msg = TextSendMessage(text='請先註冊你的 API Token,格式為 /註冊 [API TOKEN]')

try:
if text.startswith('/註冊'):
api_key = text[3:].strip()
model = OpenAIModel(api_key=api_key)
is_successful, _, _ = model.check_token_valid()
if not is_successful:
raise ValueError('Invalid API token')
model_management[user_id] = model
api_keys[user_id] = api_key
storage.save(api_keys)
msg = TextSendMessage(text='Token 有效,註冊成功')

elif text.startswith('/指令說明'):
msg = TextSendMessage(text="指令:\n/註冊 + API Token\n👉 API Token 請先到 https://platform.openai.com/ 註冊登入後取得\n\n/系統訊息 + Prompt\n👉 Prompt 可以命令機器人扮演某個角色,例如:請你扮演擅長做總結的人\n\n/清除\n👉 當前每一次都會紀錄最後兩筆歷史紀錄,這個指令能夠清除歷史訊息\n\n/圖像 + Prompt\n👉 會調用 DALL∙E 2 Model,以文字生成圖像\n\n語音輸入\n👉 會調用 Whisper 模型,先將語音轉換成文字,再調用 ChatGPT 以文字回覆\n\n其他文字輸入\n👉 調用 ChatGPT 以文字回覆")

elif text.startswith('/系統訊息'):
memory.change_system_message(user_id, text[5:].strip())
msg = TextSendMessage(text='輸入成功')

elif text.startswith('/清除'):
memory.remove(user_id)
msg = TextSendMessage(text='歷史訊息清除成功')

elif text.startswith('/圖像'):
prompt = text[3:].strip()
is_successful, response, error_message = model_management[user_id].image_generations(prompt)
if not is_successful:
raise Exception(error_message)
memory.append(user_id, 'user', prompt)
msg = ImageSendMessage(
original_content_url=response,
preview_image_url=response
)
memory.append(user_id, 'assistant', response)

else:
memory.append(user_id, {
'role': 'user',
'content': text
})
if text.startswith('/圖像'):
text = text[3:].strip()
role = 'assistant'
response, error_message = model_management[user_id].image_generations(text)
if error_message:
msg = TextSendMessage(text=error_message)
memory.remove(user_id)
else:
msg = ImageSendMessage(
original_content_url=response,
preview_image_url=response
)
memory.append(user_id, {
'role': role,
'content': response
})
else:
role, response, error_message = model_management[user_id].chat_completions(memory.get(user_id), os.getenv('OPENAI_MODEL_ENGINE'))
if error_message:
msg = TextSendMessage(text=error_message)
memory.remove(user_id)
else:
msg = TextSendMessage(text=response)
memory.append(user_id, {
'role': role,
'content': response
})
is_successful, response, error_message = model_management[user_id].chat_completions(memory.get(user_id), os.getenv('OPENAI_MODEL_ENGINE'))
if not is_successful:
raise Exception(error_message)
memory.append(user_id, 'user', text)
role, response = get_role_and_content(response)
msg = TextSendMessage(text=response)
memory.append(user_id, role, response)

except ValueError:
msg = TextSendMessage(text='Token 無效,請重新註冊,注意格式有空格,格式為 /註冊 sk-xxxxx')
except Exception as e:
msg = TextSendMessage(text=str(e))
line_bot_api.reply_message(event.reply_token, msg)


Expand All @@ -122,31 +109,25 @@ def handle_audio_message(event):
for chunk in audio_content.iter_content():
fd.write(chunk)

if not model_management.get(user_id):
try:
if not model_management.get(user_id):
raise ValueError('Invalid API token')
else:
transciption, error_message = model_management[user_id].audio_transcriptions(input_audio_path, 'whisper-1')
if error_message:
raise Exception(error_message)
is_successful, response, error_message = model_management[user_id].chat_completions(memory.get(user_id), 'gpt-3.5-turbo')
if not is_successful:
raise Exception(error_message)
memory.append(user_id, 'user', transciption)
role, response = get_role_and_content(response)
memory.append(user_id, role, response)
msg = TextSendMessage(text=response)
except ValueError:
msg = TextSendMessage(text='請先註冊你的 API Token,格式為 /註冊 [API TOKEN]')
else:
transciption, error_message = model_management[user_id].audio_transcriptions(input_audio_path, 'whisper-1')
if error_message:
os.remove(input_audio_path)
line_bot_api.reply_message(event.reply_token, TextSendMessage(text=error_message))
return
memory.append(user_id, {
'role': 'user',
'content': transciption
})

role, response, error_message = model_management[user_id].chat_completions(memory.get(user_id), 'gpt-3.5-turbo')
if error_message:
os.remove(input_audio_path)
line_bot_api.reply_message(event.reply_token, TextSendMessage(text=error_message))
memory.remove(user_id)
return
memory.append(user_id, {
'role': role,
'content': response
})
os.remove(input_audio_path)
msg = TextSendMessage(text=response)
except Exception as e:
msg = TextSendMessage(text=str(e))
os.remove(input_audio_path)
line_bot_api.reply_message(event.reply_token, msg)


Expand Down
7 changes: 5 additions & 2 deletions src/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,13 @@ def change_system_message(self, user_id, system_message):
self.system_messages[user_id] = system_message
self.remove(user_id)

def append(self, user_id: str, message: Dict) -> None:
def append(self, user_id: str, role: str, content: str) -> None:
if self.storage[user_id] == []:
self._initialize(user_id)
self.storage[user_id].append(message)
self.storage[user_id].append({
'role': role,
'content': content
})
self._drop_message(user_id)

def get(self, user_id: str) -> str:
Expand Down
63 changes: 17 additions & 46 deletions src/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
from typing import List, Dict
import requests
import opencc

s2t_converter = opencc.OpenCC('s2t.json')
t2s_converter = opencc.OpenCC('t2s.json')


class ModelInterface:
Expand All @@ -22,73 +18,48 @@ def image_generations(self, prompt: str) -> str:

class OpenAIModel(ModelInterface):
def __init__(self, api_key: str):
self.headers = {
'Authorization': f'Bearer {api_key}'
}
self.api_key = api_key
self.base_url = 'https://api.openai.com/v1'

def check_token_valid(self):
try:
r = requests.get('https://api.openai.com/v1/models', headers=self.headers)
r = r.json()
if r.get('error'):
return False, r.get('error', {}).get('message')
except Exception:
return False, 'OpenAI API 系統不穩定,請稍後再試'
return True, None

def _request(self, endpoint, body):
def _request(self, method, endpoint, body=None, files=None):
self.headers = {
'Authorization': f'Bearer {self.api_key}'
}
try:
self.headers['Content-Type'] = 'application/json'
r = requests.post(f'{self.base_url}{endpoint}', headers=self.headers, json=body)
if method == 'GET':
r = requests.get(f'{self.base_url}{endpoint}', headers=self.headers)
elif method == 'POST':
if body:
self.headers['Content-Type'] = 'application/json'
r = requests.post(f'{self.base_url}{endpoint}', headers=self.headers, json=body, files=files)
r = r.json()
if r.get('error'):
return False, None, r.get('error', {}).get('message')
except Exception:
return False, None, 'OpenAI API 系統不穩定,請稍後再試'
return True, r, None

def _request_with_file(self, endpoint, files):
try:
self.headers.pop('Content-Type', None)
r = requests.post(f'{self.base_url}{endpoint}', headers=self.headers, files=files)
r = r.json()
if r.get('error'):
return False, None, r.get('error', {}).get('message')
except Exception:
return False, None, 'OpenAI API 系統不穩定,請稍後再試'
return True, r, None
def check_token_valid(self):
return self._request('GET', '/models')

def chat_completions(self, messages, model_engine) -> str:
json_body = {
'model': model_engine,
'messages': messages
}
is_successful, r, error_message = self._request('/chat/completions', body=json_body)
if not is_successful:
return None, None, error_message
role = r['choices'][0]['message']['role']
content = r['choices'][0]['message']['content'].strip()
response = s2t_converter.convert(content)
return role, response, None
return self._request('POST', '/chat/completions', body=json_body)

def audio_transcriptions(self, file_path, model_engine) -> str:
files = {
'file': open(file_path, 'rb'),
'model': (None, 'whisper-1'),
'model': (None, model_engine),
}
is_successful, r, error_message = self._request_with_file('/audio/transcriptions', files)
if not is_successful:
return None, error_message
return r['text'], None
return self._request('POST', '/audio/transcriptions', files=files)

def image_generations(self, prompt: str) -> str:
json_body = {
"prompt": prompt,
"n": 1,
"size": "512x512"
}
is_successful, r, error_message = self._request('/images/generations', json_body)
if not is_successful:
return None, error_message
return r['data'][0]['url'], None
return self._request('/images/generations', body=json_body)
11 changes: 11 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import opencc

s2t_converter = opencc.OpenCC('s2t.json')
t2s_converter = opencc.OpenCC('t2s.json')


def get_role_and_content(response):
role = response['choices'][0]['message']['role']
content = response['choices'][0]['message']['content'].strip()
response = s2t_converter.convert(content)
return role, response

0 comments on commit 30fecb8

Please sign in to comment.