Skip to content

Commit

Permalink
add argument to set the demo device
Browse files Browse the repository at this point in the history
  • Loading branch information
TsuTikgiau committed Apr 19, 2023
1 parent 58e5c71 commit 322ed18
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
9 changes: 6 additions & 3 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
def parse_args():
parser = argparse.ArgumentParser(description="Demo")
parser.add_argument("--cfg-path", required=True, help="path to configuration file.")
parser.add_argument("--gpu_id", type=int, default=0, help="specify the gpu to load the model.")
parser.add_argument(
"--options",
nargs="+",
Expand Down Expand Up @@ -50,15 +51,17 @@ def setup_seeds(config):
# ========================================

print('Initializing Chat')
cfg = Config(parse_args())
args = parse_args()
cfg = Config(args)

model_config = cfg.model_cfg
model_config.device_8bit = args.gpu_id
model_cls = registry.get_model_class(model_config.arch)
model = model_cls.from_config(model_config).to('cuda:0')
model = model_cls.from_config(model_config).to('cuda:{}'.format(args.gpu_id))

vis_processor_cfg = cfg.datasets_cfg.cc_sbu_align.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
chat = Chat(model, vis_processor)
chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
print('Initialization Finished')

# ========================================
Expand Down
9 changes: 6 additions & 3 deletions minigpt4/models/mini_gpt4.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ def __init__(
prompt_path="",
prompt_template="",
max_txt_len=32,
low_resource=False, # use 8 bit and put vit in cpu
end_sym='\n',
low_resource=False, # use 8 bit and put vit in cpu
device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
):
super().__init__()

Expand Down Expand Up @@ -90,7 +91,7 @@ def __init__(
llama_model,
torch_dtype=torch.float16,
load_in_8bit=True,
device_map="auto"
device_map={'': device_8bit}
)
else:
self.llama_model = LlamaForCausalLM.from_pretrained(
Expand Down Expand Up @@ -232,6 +233,7 @@ def from_config(cls, cfg):
freeze_vit = cfg.get("freeze_vit", True)
freeze_qformer = cfg.get("freeze_qformer", True)
low_resource = cfg.get("low_resource", False)
device_8bit = cfg.get("device_8bit", 0)

prompt_path = cfg.get("prompt_path", "")
prompt_template = cfg.get("prompt_template", "")
Expand All @@ -252,8 +254,9 @@ def from_config(cls, cfg):
prompt_path=prompt_path,
prompt_template=prompt_template,
max_txt_len=max_txt_len,
end_sym=end_sym,
low_resource=low_resource,
end_sym=end_sym
device_8bit=device_8bit,
)

ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4
Expand Down

0 comments on commit 322ed18

Please sign in to comment.