Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
TheExplainthis committed Mar 20, 2023
1 parent 30fecb8 commit e93336f
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
10 changes: 6 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,28 +74,29 @@ def handle_text_message(event):

elif text.startswith('/圖像'):
prompt = text[3:].strip()
memory.append(user_id, 'user', prompt)
is_successful, response, error_message = model_management[user_id].image_generations(prompt)
if not is_successful:
raise Exception(error_message)
memory.append(user_id, 'user', prompt)
msg = ImageSendMessage(
original_content_url=response,
preview_image_url=response
)
memory.append(user_id, 'assistant', response)

else:
memory.append(user_id, 'user', text)
is_successful, response, error_message = model_management[user_id].chat_completions(memory.get(user_id), os.getenv('OPENAI_MODEL_ENGINE'))
if not is_successful:
raise Exception(error_message)
memory.append(user_id, 'user', text)
role, response = get_role_and_content(response)
msg = TextSendMessage(text=response)
memory.append(user_id, role, response)

except ValueError:
msg = TextSendMessage(text='Token 無效,請重新註冊,注意格式有空格,格式為 /註冊 sk-xxxxx')
msg = TextSendMessage(text='Token 無效,請重新註冊,格式為 /註冊 sk-xxxxx')
except Exception as e:
memory.remove(user_id)
msg = TextSendMessage(text=str(e))
line_bot_api.reply_message(event.reply_token, msg)

Expand All @@ -114,18 +115,19 @@ def handle_audio_message(event):
raise ValueError('Invalid API token')
else:
transciption, error_message = model_management[user_id].audio_transcriptions(input_audio_path, 'whisper-1')
memory.append(user_id, 'user', transciption)
if error_message:
raise Exception(error_message)
is_successful, response, error_message = model_management[user_id].chat_completions(memory.get(user_id), 'gpt-3.5-turbo')
if not is_successful:
raise Exception(error_message)
memory.append(user_id, 'user', transciption)
role, response = get_role_and_content(response)
memory.append(user_id, role, response)
msg = TextSendMessage(text=response)
except ValueError:
msg = TextSendMessage(text='請先註冊你的 API Token,格式為 /註冊 [API TOKEN]')
except Exception as e:
memory.remove(user_id)
msg = TextSendMessage(text=str(e))
os.remove(input_audio_path)
line_bot_api.reply_message(event.reply_token, msg)
Expand Down
6 changes: 3 additions & 3 deletions src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ class ModelInterface:
def check_token_valid(self) -> bool:
pass

def chat_completions(self, messages: List[Dict]) -> str:
def chat_completions(self, messages: List[Dict], model_engine: str) -> str:
pass

def audio_transcriptions(self, file) -> str:
def audio_transcriptions(self, file, model_engine: str) -> str:
pass

def image_generations(self, prompt: str) -> str:
Expand Down Expand Up @@ -62,4 +62,4 @@ def image_generations(self, prompt: str) -> str:
"n": 1,
"size": "512x512"
}
return self._request('/images/generations', body=json_body)
return self._request('POST', '/images/generations', body=json_body)
6 changes: 3 additions & 3 deletions src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
t2s_converter = opencc.OpenCC('t2s.json')


def get_role_and_content(response):
def get_role_and_content(response: str):
role = response['choices'][0]['message']['role']
content = response['choices'][0]['message']['content'].strip()
response = s2t_converter.convert(content)
return role, response
content = s2t_converter.convert(content)
return role, content

0 comments on commit e93336f

Please sign in to comment.