Skip to content

Commit

Permalink
Add an argument to toggle 8-bit (low-resource) model loading
Browse files Browse the repository at this point in the history
  • Loading branch information
TsuTikgiau committed Apr 17, 2023
1 parent 3e03c83 commit b76d5c5
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 9 deletions.
1 change: 1 addition & 0 deletions eval_configs/minigpt4_eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ model:
freeze_qformer: True
max_txt_len: 160
end_sym: "###"
low_resource: True
prompt_path: "prompts/alignment.txt"
prompt_template: '###Human: {} ###Assistant: '
ckpt: '/path/to/pretrained/ckpt/'
Expand Down
35 changes: 26 additions & 9 deletions minigpt4/models/mini_gpt4.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@ def __init__(
prompt_path="",
prompt_template="",
max_txt_len=32,
low_resource=False, # use 8 bit and put vit in cpu
end_sym='\n',
):
super().__init__()

self.tokenizer = self.init_tokenizer()
self.low_resource = low_resource

print('Loading VIT')
self.visual_encoder, self.ln_vision = self.init_vision_encoder(
Expand Down Expand Up @@ -83,10 +85,19 @@ def __init__(
self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token

self.llama_model = LlamaForCausalLM.from_pretrained(
llama_model, torch_dtype=torch.float16,
load_in_8bit=True, device_map="auto"
)
if self.low_resource:
self.llama_model = LlamaForCausalLM.from_pretrained(
llama_model,
torch_dtype=torch.float16,
load_in_8bit=True,
device_map="auto"
)
else:
self.llama_model = LlamaForCausalLM.from_pretrained(
llama_model,
torch_dtype=torch.float16,
)

for name, param in self.llama_model.named_parameters():
param.requires_grad = False
print('Loading LLAMA Done')
Expand All @@ -107,18 +118,22 @@ def __init__(
else:
self.prompt_list = []

def encode_img(self, image):
device = image.device
def vit_to_cpu(self):
    """Relocate the ViT encoder and its layer norm to CPU in full precision.

    Used in low-resource mode: each module is moved to the CPU device and
    cast to float32 (CPU inference does not support fp16 well).
    """
    for module in (self.ln_vision, self.visual_encoder):
        module.to("cpu")
        module.float()
image = image.to("cpu")

image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
def encode_img(self, image):
device = image.device
if self.low_resource:
self.vit_to_cpu()
image = image.to("cpu")

with self.maybe_autocast():
image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
query_output = self.Qformer.bert(
query_embeds=query_tokens,
Expand Down Expand Up @@ -216,6 +231,7 @@ def from_config(cls, cfg):
vit_precision = cfg.get("vit_precision", "fp16")
freeze_vit = cfg.get("freeze_vit", True)
freeze_qformer = cfg.get("freeze_qformer", True)
low_resource = cfg.get("low_resource", False)

prompt_path = cfg.get("prompt_path", "")
prompt_template = cfg.get("prompt_template", "")
Expand All @@ -236,6 +252,7 @@ def from_config(cls, cfg):
prompt_path=prompt_path,
prompt_template=prompt_template,
max_txt_len=max_txt_len,
low_resource=low_resource,
end_sym=end_sym
)

Expand Down

0 comments on commit b76d5c5

Please sign in to comment.