init version of v2
TsuTikgiau committed Oct 12, 2023
1 parent d1367e5 commit 7a575af
Showing 15 changed files with 981 additions and 15 deletions.
765 changes: 765 additions & 0 deletions demo_v2.py

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion eval_configs/minigpt4_eval.yaml
@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_vicuna0
   max_txt_len: 160
   end_sym: "###"
2 changes: 1 addition & 1 deletion eval_configs/minigpt4_llama2_eval.yaml
@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_llama2
   max_txt_len: 160
   end_sym: "</s>"
25 changes: 25 additions & 0 deletions eval_configs/minigptv2_eval.yaml
@@ -0,0 +1,25 @@
model:
  arch: minigpt_v2
  model_type: pretrain
  max_txt_len: 160
  end_sym: "</s>"
  low_resource: True
  prompt_template: '[INST] {} [/INST]'
  ckpt: '/home/zhud/c2090/minigpt4_ckpt/448_conversation_correct_best_v7_ablation1_v5_v6/20231007035/checkpoint_35.pth'
  llama_model: "/ibex/project/c2133/llama_v2/llama-2-7b-chat-pytorch_update"
  lora_r: 64
  lora_alpha: 16


datasets:
  cc_sbu_align:
    vis_processor:
      train:
        name: "blip2_image_eval"
        image_size: 448
    text_processor:
      train:
        name: "blip_caption"

run:
  task: image_text_pretrain
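
Note: the config above is consumed by the demo/eval pipeline added in this commit. As a minimal, hedged sketch (not the repo's own launcher), its model section would be merged over the default model yaml before MiniGPTv2.from_config is called; the use of OmegaConf and the merge flow here are assumptions for illustration only:

from omegaconf import OmegaConf
from minigpt4.models.minigpt_v2 import MiniGPTv2

# Assumed flow: defaults come from the model yaml, overridden by the eval yaml.
base = OmegaConf.load("minigpt4/configs/models/minigpt_v2.yaml")
user = OmegaConf.load("eval_configs/minigptv2_eval.yaml")
model_cfg = OmegaConf.merge(base, user).model   # user keys (ckpt, low_resource, ...) win
model = MiniGPTv2.from_config(model_cfg).eval()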
2 changes: 1 addition & 1 deletion minigpt4/configs/models/minigpt4_llama2.yaml
@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4

   # vit encoder
   image_size: 224
2 changes: 1 addition & 1 deletion minigpt4/configs/models/minigpt4_vicuna0.yaml
@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4

   # vit encoder
   image_size: 224
31 changes: 31 additions & 0 deletions minigpt4/configs/models/minigpt_v2.yaml
@@ -0,0 +1,31 @@
model:
  arch: minigpt_v2

  # vit encoder
  image_size: 448
  drop_path_rate: 0
  use_grad_checkpoint: False
  vit_precision: "fp16"
  freeze_vit: True

  # generation configs
  prompt: ""

  llama_model: "/path/to/llama2/weight"
  lora_r: 64
  lora_alpha: 16


preprocess:
  vis_processor:
    train:
      name: "blip2_image_train"
      image_size: 448
    eval:
      name: "blip2_image_eval"
      image_size: 448
  text_processor:
    train:
      name: "blip_caption"
    eval:
      name: "blip_caption"
6 changes: 5 additions & 1 deletion minigpt4/models/__init__.py
@@ -11,14 +11,18 @@

 from minigpt4.common.registry import registry
 from minigpt4.models.base_model import BaseModel
-from minigpt4.models.minigpt_4 import MiniGPT4
+from minigpt4.models.minigpt_base import MiniGPTBase
+from minigpt4.models.minigpt4 import MiniGPT4
+from minigpt4.models.minigpt_v2 import MiniGPTv2
 from minigpt4.processors.base_processor import BaseProcessor


 __all__ = [
     "load_model",
     "BaseModel",
+    "MiniGPTBase",
     "MiniGPT4",
+    "MiniGPTv2"
 ]


3 changes: 1 addition & 2 deletions minigpt4/models/minigpt_4.py → minigpt4/models/minigpt4.py
@@ -11,8 +11,7 @@
 from minigpt4.models.Qformer import BertConfig, BertLMHeadModel


-
-@registry.register_model("mini_gpt4")
+@registry.register_model("minigpt4")
 class MiniGPT4(MiniGPTBase):
     """
     MiniGPT-4 model
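
Note: the registry key is what yaml configs reference through their arch field, which is why every config in this commit switches from mini_gpt4 to minigpt4. A small, hedged illustration, assuming the LAVIS-style registry API used throughout the repo:

from minigpt4.common.registry import registry

registry.get_model_class("minigpt4")    # resolves to MiniGPT4 after this commit
registry.get_model_class("mini_gpt4")   # old key no longer resolves (returns None in a LAVIS-style registry)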
11 changes: 7 additions & 4 deletions minigpt4/models/minigpt_base.py
@@ -25,10 +25,12 @@ def __init__(
         freeze_vit=True,
         llama_model="",
         max_txt_len=32,
+        max_context_len=3800,
+        prompt_template="",
         end_sym='\n',
         low_resource=False, # use 8 bit and put vit in cpu
         device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
-        lora_r=0, # lora_r means lora is not used
+        lora_r=0,  # lora_r means lora is not used
         lora_target_modules=["q_proj", "v_proj"],
         lora_alpha=16,
         lora_dropout=0.05,
@@ -50,8 +52,10 @@ def __init__(
         )

         self.max_txt_len = max_txt_len
+        self.max_context_len = max_context_len
         self.end_sym = end_sym

+        self.prompt_template = prompt_template
         self.prompt_list = []

     def vit_to_cpu(self):
@@ -129,7 +133,6 @@ def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None):
             wrapped_atts[i, :length] = 1
         return wrapped_embs, wrapped_atts

-
     def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
         """
         Concatenate the batched input embedding and batched output embedding together.
@@ -219,7 +222,7 @@ def preparing_embedding(self, samples):
             conv_q = [q.split(connect_sym)for q in conv_q]
             conv_a = [a.split(connect_sym) for a in conv_a]

-            conv_q = [["[INST] " + item + "[/INST]" for item in items] for items in conv_q]
+            conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q]

             cond_embeds, cond_atts = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q])
             regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a)
@@ -233,7 +236,7 @@ def preparing_embedding(self, samples):
             instruction = None

         if self.chat_template:
-            instruction = ["[INST] " + instruct + "[/INST]" for instruct in instruction]
+            instruction = [self.prompt_template.format(instruct) for instruct in instruction]

         if 'length' in samples:
             # the input is a image train (like videos)
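
Note: the two replacements above mean the instruction wrapper is no longer hard-coded; it comes from the model's prompt_template config. A small illustration in plain Python (the "[vqa]" task tag is only an example of a MiniGPT-v2-style instruction):

prompt_template = '[INST] {} [/INST]'
instruction = "[vqa] what is the man holding?"
print(prompt_template.format(instruction))
# [INST] [vqa] what is the man holding? [/INST]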
139 changes: 139 additions & 0 deletions minigpt4/models/minigpt_v2.py
@@ -0,0 +1,139 @@
import logging
import random

import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn

from minigpt4.common.registry import registry
from minigpt4.models.base_model import disabled_train
from minigpt4.models.minigpt_base import MiniGPTBase
from minigpt4.models.Qformer import BertConfig, BertLMHeadModel


@registry.register_model("minigpt_v2")
class MiniGPTv2(MiniGPTBase):
    """
    MiniGPT-v2 model
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain": "configs/models/minigpt_v2.yaml",
    }

    def __init__(
            self,
            vit_model="eva_clip_g",
            img_size=448,
            drop_path_rate=0,
            use_grad_checkpoint=False,
            vit_precision="fp16",
            freeze_vit=True,
            llama_model="",
            prompt_template='[INST] {} [/INST]',
            max_txt_len=300,
            end_sym='\n',
            lora_r=64,
            lora_target_modules=["q_proj", "v_proj"],
            lora_alpha=16,
            lora_dropout=0.05,
            chat_template=False,
            use_grad_checkpoint_llm=False,
            max_context_len=3800,
            low_resource=False,  # use 8 bit and put vit in cpu
            device_8bit=0,  # the device of 8bit model should be set when loading and cannot be changed anymore.
    ):
        super().__init__(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            llama_model=llama_model,
            max_txt_len=max_txt_len,
            max_context_len=max_context_len,
            end_sym=end_sym,
            prompt_template=prompt_template,
            low_resource=low_resource,
            device_8bit=device_8bit,
            lora_r=lora_r,
            lora_target_modules=lora_target_modules,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )

        img_f_dim = self.visual_encoder.num_features * 4
        self.llama_proj = nn.Linear(
            img_f_dim, self.llama_model.config.hidden_size
        )
        self.chat_template = chat_template

        if use_grad_checkpoint_llm:
            self.llama_model.gradient_checkpointing_enable()

    def encode_img(self, image):
        device = image.device

        if len(image.shape) > 4:
            image = image.reshape(-1, *image.shape[-3:])

        with self.maybe_autocast():
            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
            image_embeds = image_embeds[:, 1:, :]
            bs, pn, hs = image_embeds.shape
            image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4))

            inputs_llama = self.llama_proj(image_embeds)
            atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
        return inputs_llama, atts_llama

    @classmethod
    def from_config(cls, cfg):
        vit_model = cfg.get("vit_model", "eva_clip_g")
        img_size = cfg.get("image_size")
        llama_model = cfg.get("llama_model")

        drop_path_rate = cfg.get("drop_path_rate", 0)
        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
        vit_precision = cfg.get("vit_precision", "fp16")
        freeze_vit = cfg.get("freeze_vit", True)
        low_resource = cfg.get("low_resource", False)

        prompt_template = cfg.get("prompt_template", '[INST] {} [/INST]')
        max_txt_len = cfg.get("max_txt_len", 300)
        end_sym = cfg.get("end_sym", '\n')

        lora_r = cfg.get("lora_r", 64)
        lora_alpha = cfg.get("lora_alpha", 16)
        chat_template = cfg.get("chat_template", False)

        use_grad_checkpoint_llm = cfg.get("use_grad_checkpoint_llm", False)
        max_context_len = cfg.get("max_context_len", 3800)

        model = cls(
            vit_model=vit_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            llama_model=llama_model,
            prompt_template=prompt_template,
            max_txt_len=max_txt_len,
            low_resource=low_resource,
            end_sym=end_sym,
            lora_r=lora_r,
            lora_alpha=lora_alpha,
            chat_template=chat_template,
            use_grad_checkpoint_llm=use_grad_checkpoint_llm,
            max_context_len=max_context_len,
        )

        ckpt_path = cfg.get("ckpt", "")  # load weights of MiniGPT-4
        if ckpt_path:
            print("Load Minigpt-4-LLM Checkpoint: {}".format(ckpt_path))
            ckpt = torch.load(ckpt_path, map_location="cpu")
            msg = model.load_state_dict(ckpt['model'], strict=False)

        return model
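
Note on the reshape inside encode_img: with a 448x448 input and the EVA-CLIP ViT-g/14 encoder this works out to 32 x 32 = 1024 patch tokens (after dropping the CLS token), which the view call folds into 256 tokens of four times the feature width before the linear projection into the LLaMA embedding space. A standalone sketch with assumed shapes (1408 is the EVA-CLIP-g feature dim, an assumption for illustration):

import torch

bs, pn, hs = 2, 1024, 1408                      # batch, patch tokens, ViT feature dim (assumed)
image_embeds = torch.randn(bs, pn, hs)
merged = image_embeds.view(bs, pn // 4, hs * 4)
print(merged.shape)                             # torch.Size([2, 256, 5632])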
2 changes: 1 addition & 1 deletion train_configs/minigpt4_llama2_stage1_pretrain.yaml
@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_llama2


2 changes: 1 addition & 1 deletion train_configs/minigpt4_llama2_stage2_finetune.yaml
@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_llama2

   max_txt_len: 160
2 changes: 1 addition & 1 deletion train_configs/minigpt4_stage1_pretrain.yaml
@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_vicuna0


2 changes: 1 addition & 1 deletion train_configs/minigpt4_stage2_finetune.yaml
@@ -1,5 +1,5 @@
 model:
-  arch: mini_gpt4
+  arch: minigpt4
   model_type: pretrain_vicuna0

   max_txt_len: 160
