From dadc0d7e69a0ccd228fdd9cc356c2d3800d008eb Mon Sep 17 00:00:00 2001
From: Deyao Zhu
Date: Tue, 18 Apr 2023 22:04:50 +0300
Subject: [PATCH] adding length control and change the default hyperparameter
 of the conversation to avoid OOM in 3090.

---
 demo.py                               |  7 ++++++-
 minigpt4/conversation/conversation.py | 13 +++++++++++--
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/demo.py b/demo.py
index b074ed27..a79b3a88 100644
--- a/demo.py
+++ b/demo.py
@@ -89,7 +89,12 @@ def gradio_ask(user_message, chatbot, chat_state):
 
 
 def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
-    llm_message = chat.answer(conv=chat_state, img_list=img_list, max_new_tokens=1000, num_beams=num_beams, temperature=temperature)[0]
+    llm_message = chat.answer(conv=chat_state,
+                              img_list=img_list,
+                              num_beams=num_beams,
+                              temperature=temperature,
+                              max_new_tokens=300,
+                              max_length=2000)[0]
     chatbot[-1][1] = llm_message
     return chatbot, chat_state, img_list
 

diff --git a/minigpt4/conversation/conversation.py b/minigpt4/conversation/conversation.py
index 84aea9a3..7cd50bbf 100644
--- a/minigpt4/conversation/conversation.py
+++ b/minigpt4/conversation/conversation.py
@@ -134,10 +134,19 @@ def ask(self, text, conv):
         else:
             conv.append_message(conv.roles[0], text)
 
-    def answer(self, conv, img_list, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
-               repetition_penalty=1.0, length_penalty=1, temperature=1.0):
+    def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
+               repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000):
         conv.append_message(conv.roles[1], None)
         embs = self.get_context_emb(conv, img_list)
+
+        current_max_len = embs.shape[1] + max_new_tokens
+        if current_max_len - max_length > 0:
+            print('Warning: The number of tokens in current conversation exceeds the max length. '
+                  'The model will not see the contexts outside the range.')
+        begin_idx = max(0, current_max_len - max_length)
+
+        embs = embs[:, begin_idx:]
+
         outputs = self.model.llama_model.generate(
             inputs_embeds=embs,
             max_new_tokens=max_new_tokens,