diff --git a/eval_configs/minigpt4_eval.yaml b/eval_configs/minigpt4_eval.yaml
index 5ac9fad4..f9e55a30 100644
--- a/eval_configs/minigpt4_eval.yaml
+++ b/eval_configs/minigpt4_eval.yaml
@@ -5,6 +5,7 @@ model:
   freeze_qformer: True
   max_txt_len: 160
   end_sym: "###"
+  low_resource: True
   prompt_path: "prompts/alignment.txt"
   prompt_template: '###Human: {} ###Assistant: '
   ckpt: '/path/to/pretrained/ckpt/'
diff --git a/minigpt4/models/mini_gpt4.py b/minigpt4/models/mini_gpt4.py
index 5c223fe4..db889168 100644
--- a/minigpt4/models/mini_gpt4.py
+++ b/minigpt4/models/mini_gpt4.py
@@ -36,11 +36,13 @@ def __init__(
         prompt_path="",
         prompt_template="",
         max_txt_len=32,
+        low_resource=False,  # use 8 bit and put vit in cpu
         end_sym='\n',
     ):
         super().__init__()

         self.tokenizer = self.init_tokenizer()
+        self.low_resource = low_resource

         print('Loading VIT')
         self.visual_encoder, self.ln_vision = self.init_vision_encoder(
@@ -83,10 +85,19 @@ def __init__(
         self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
         self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token

-        self.llama_model = LlamaForCausalLM.from_pretrained(
-            llama_model, torch_dtype=torch.float16,
-            load_in_8bit=True, device_map="auto"
-        )
+        if self.low_resource:
+            self.llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model,
+                torch_dtype=torch.float16,
+                load_in_8bit=True,
+                device_map="auto"
+            )
+        else:
+            self.llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model,
+                torch_dtype=torch.float16,
+            )
+
         for name, param in self.llama_model.named_parameters():
             param.requires_grad = False
         print('Loading LLAMA Done')
@@ -107,18 +118,22 @@ def __init__(
         else:
             self.prompt_list = []

-    def encode_img(self, image):
-        device = image.device
+    def vit_to_cpu(self):
         self.ln_vision.to("cpu")
         self.ln_vision.float()
         self.visual_encoder.to("cpu")
         self.visual_encoder.float()
-        image = image.to("cpu")
-        image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
-        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

+    def encode_img(self, image):
+        device = image.device
+        if self.low_resource:
+            self.vit_to_cpu()
+            image = image.to("cpu")

         with self.maybe_autocast():
+            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
+            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+
             query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
             query_output = self.Qformer.bert(
                 query_embeds=query_tokens,
@@ -216,6 +231,7 @@ def from_config(cls, cfg):
         vit_precision = cfg.get("vit_precision", "fp16")
         freeze_vit = cfg.get("freeze_vit", True)
         freeze_qformer = cfg.get("freeze_qformer", True)
+        low_resource = cfg.get("low_resource", False)

         prompt_path = cfg.get("prompt_path", "")
         prompt_template = cfg.get("prompt_template", "")
@@ -236,6 +252,7 @@ def from_config(cls, cfg):
             prompt_path=prompt_path,
             prompt_template=prompt_template,
             max_txt_len=max_txt_len,
+            low_resource=low_resource,
             end_sym=end_sym
         )
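
Note for reviewers (not part of the patch): the loading pattern the new low_resource flag toggles is sketched below as a standalone snippet. The load_llama helper and the checkpoint path are illustrative only; just the branching mirrors the change to mini_gpt4.py. The 8-bit path requires bitsandbytes and roughly halves the GPU memory taken by the LLaMA weights, typically at some cost in speed.

    import torch
    from transformers import LlamaForCausalLM

    def load_llama(path, low_resource=False):
        """Sketch of the two loading branches added by this patch (hypothetical helper)."""
        if low_resource:
            # 8-bit weights via bitsandbytes, placed automatically across available devices.
            return LlamaForCausalLM.from_pretrained(
                path,
                torch_dtype=torch.float16,
                load_in_8bit=True,
                device_map="auto",
            )
        # Default path: plain fp16 weights; the caller decides which device to move them to.
        return LlamaForCausalLM.from_pretrained(path, torch_dtype=torch.float16)

    # llama_model = load_llama("/path/to/vicuna/weights", low_resource=True)  # placeholder path

With low_resource: True the patch also moves the ViT and ln_vision to the CPU in fp32 (vit_to_cpu()), so only the 8-bit LLaMA and the Q-Former stay on the GPU; with low_resource: False everything runs in fp16 on the GPU as before.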