Skip to content

Commit

Permalink
Merge pull request haotian-liu#1057 from haotian-liu/fix-multiple-gpu
Browse files Browse the repository at this point in the history
Fix multiple GPU inference.
  • Loading branch information
haotian-liu authored Feb 2, 2024
2 parents 498e18d + 4c92669 commit 1564e67
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 6 deletions.
4 changes: 4 additions & 0 deletions llava/eval/run_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@ def eval_model(args):

  if "llama-2" in model_name.lower():
  conv_mode = "llava_llama_2"
+ elif "mistral" in model_name.lower():
+ conv_mode = "mistral_instruct"
+ elif "v1.6-34b" in model_name.lower():
+ conv_mode = "chatml_direct"
  elif "v1" in model_name.lower():
  conv_mode = "llava_v1"
  elif "mpt" in model_name.lower():
Expand Down
5 changes: 3 additions & 2 deletions llava/model/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,9 @@ def load_from_hf(repo_id, filename, subfolder=None):

  vision_tower = model.get_vision_tower()
  if not vision_tower.is_loaded:
- vision_tower.load_model()
- vision_tower.to(device=device, dtype=torch.float16)
+ vision_tower.load_model(device_map=device_map)
+ if device_map != 'auto':
+ vision_tower.to(device=device_map, dtype=torch.float16)
  image_processor = vision_tower.image_processor

  if hasattr(model.config, "max_sequence_length"):
Expand Down
2 changes: 1 addition & 1 deletion llava/model/llava_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def prepare_inputs_labels_for_multimodal(
  image_feature = unpad_image(image_feature, image_sizes[image_idx])
  image_feature = torch.cat((
  image_feature,
- self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1)
+ self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
  ), dim=-1)
  image_feature = image_feature.flatten(1, 2).transpose(0, 1)
  else:
Expand Down
8 changes: 6 additions & 2 deletions llava/model/multimodal_encoder/clip_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@ def __init__(self, vision_tower, args, delay_load=False):
  else:
  self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

- def load_model(self):
+ def load_model(self, device_map=None):
+ if self.is_loaded:
+ print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
+ return
+
  self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
- self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
  self.vision_tower.requires_grad_(False)

  self.is_loaded = True
Expand Down
6 changes: 5 additions & 1 deletion llava/serve/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,12 @@ def main(args):
  model_name = get_model_name_from_path(args.model_path)
  tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

- if 'llama-2' in model_name.lower():
+ if "llama-2" in model_name.lower():
  conv_mode = "llava_llama_2"
+ elif "mistral" in model_name.lower():
+ conv_mode = "mistral_instruct"
+ elif "v1.6-34b" in model_name.lower():
+ conv_mode = "chatml_direct"
  elif "v1" in model_name.lower():
  conv_mode = "llava_v1"
  elif "mpt" in model_name.lower():
Expand Down

0 comments on commit 1564e67

Please sign in to comment.