Skip to content

Commit

Permalink
Merge pull request Vision-CAIR#5 from 152334H/int8
Browse files Browse the repository at this point in the history
consumer gpu inference
  • Loading branch information
TsuTikgiau authored Apr 17, 2023
2 parents 378508b + 700f05d commit 3e03c83
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 11 deletions.
4 changes: 2 additions & 2 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):

num_beams = gr.Slider(
minimum=1,
maximum=16,
value=5,
maximum=10,
value=1,
step=1,
interactive=True,
label="beam search numbers)",
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dependencies:
- filelock==3.9.0
- fonttools==4.38.0
- frozenlist==1.3.3
- huggingface-hub==0.12.1
- huggingface-hub==0.13.4
- importlib-resources==5.12.0
- kiwisolver==1.4.4
- matplotlib==3.7.0
Expand Down
5 changes: 3 additions & 2 deletions minigpt4/conversation/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,16 @@ def ask(self, text, conv):
else:
conv.append_message(conv.roles[0], text)

def answer(self, conv, img_list, max_new_tokens=200, num_beams=5, min_length=1, top_p=0.9,
repetition_penalty=1.0, length_penalty=1, temperature=1):
def answer(self, conv, img_list, max_new_tokens=200, num_beams=1, min_length=1, top_p=0.9,
repetition_penalty=1.0, length_penalty=1, temperature=1.0):
conv.append_message(conv.roles[1], None)
embs = self.get_context_emb(conv, img_list)
outputs = self.model.llama_model.generate(
inputs_embeds=embs,
max_new_tokens=max_new_tokens,
stopping_criteria=self.stopping_criteria,
num_beams=num_beams,
do_sample=True,
min_length=min_length,
top_p=top_p,
repetition_penalty=repetition_penalty,
Expand Down
18 changes: 12 additions & 6 deletions minigpt4/models/mini_gpt4.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def __init__(
self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token

self.llama_model = LlamaForCausalLM.from_pretrained(
llama_model, torch_dtype=torch.float16
llama_model, torch_dtype=torch.float16,
load_in_8bit=True, device_map="auto"
)
for name, param in self.llama_model.named_parameters():
param.requires_grad = False
Expand All @@ -107,12 +108,17 @@ def __init__(
self.prompt_list = []

def encode_img(self, image):
with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image))
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
image.device
)
device = image.device
self.ln_vision.to("cpu")
self.ln_vision.float()
self.visual_encoder.to("cpu")
self.visual_encoder.float()
image = image.to("cpu")

image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

with self.maybe_autocast():
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_output = self.Qformer.bert(
query_embeds=query_tokens,
Expand Down

0 comments on commit 3e03c83

Please sign in to comment.