diff --git a/llava/eval/run_llava.py b/llava/eval/run_llava.py
index 7865e6f3b..b522518a4 100644
--- a/llava/eval/run_llava.py
+++ b/llava/eval/run_llava.py
@@ -72,6 +72,10 @@ def eval_model(args):
 
     if "llama-2" in model_name.lower():
         conv_mode = "llava_llama_2"
+    elif "mistral" in model_name.lower():
+        conv_mode = "mistral_instruct"
+    elif "v1.6-34b" in model_name.lower():
+        conv_mode = "chatml_direct"
     elif "v1" in model_name.lower():
         conv_mode = "llava_v1"
     elif "mpt" in model_name.lower():
diff --git a/llava/model/builder.py b/llava/model/builder.py
index 33edbd100..394da6a5d 100644
--- a/llava/model/builder.py
+++ b/llava/model/builder.py
@@ -147,8 +147,9 @@ def load_from_hf(repo_id, filename, subfolder=None):
 
         vision_tower = model.get_vision_tower()
         if not vision_tower.is_loaded:
-            vision_tower.load_model()
-        vision_tower.to(device=device, dtype=torch.float16)
+            vision_tower.load_model(device_map=device_map)
+        if device_map != 'auto':
+            vision_tower.to(device=device_map, dtype=torch.float16)
         image_processor = vision_tower.image_processor
 
     if hasattr(model.config, "max_sequence_length"):
diff --git a/llava/model/llava_arch.py b/llava/model/llava_arch.py
index 5a5e62dfc..f07e5b5ea 100644
--- a/llava/model/llava_arch.py
+++ b/llava/model/llava_arch.py
@@ -180,7 +180,7 @@ def prepare_inputs_labels_for_multimodal(
                             image_feature = unpad_image(image_feature, image_sizes[image_idx])
                             image_feature = torch.cat((
                                 image_feature,
-                                self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1)
+                                self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device)
                             ), dim=-1)
                             image_feature = image_feature.flatten(1, 2).transpose(0, 1)
                         else:
diff --git a/llava/model/multimodal_encoder/clip_encoder.py b/llava/model/multimodal_encoder/clip_encoder.py
index 211781135..97dbea3de 100644
--- a/llava/model/multimodal_encoder/clip_encoder.py
+++ b/llava/model/multimodal_encoder/clip_encoder.py
@@ -21,9 +21,13 @@ def __init__(self, vision_tower, args, delay_load=False):
         else:
             self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
 
-    def load_model(self):
+    def load_model(self, device_map=None):
+        if self.is_loaded:
+            print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
+            return
+
         self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
-        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
+        self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
         self.vision_tower.requires_grad_(False)
 
         self.is_loaded = True
diff --git a/llava/serve/cli.py b/llava/serve/cli.py
index 589edd5c7..9ba4e4a73 100644
--- a/llava/serve/cli.py
+++ b/llava/serve/cli.py
@@ -31,8 +31,12 @@ def main(args):
     model_name = get_model_name_from_path(args.model_path)
     tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
 
-    if 'llama-2' in model_name.lower():
+    if "llama-2" in model_name.lower():
         conv_mode = "llava_llama_2"
+    elif "mistral" in model_name.lower():
+        conv_mode = "mistral_instruct"
+    elif "v1.6-34b" in model_name.lower():
+        conv_mode = "chatml_direct"
     elif "v1" in model_name.lower():
         conv_mode = "llava_v1"
     elif "mpt" in model_name.lower():
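
For reference, a minimal usage sketch of the device_map-aware loading path touched above. The call order mirrors the call site in llava/serve/cli.py; the import paths and the checkpoint name are assumptions based on the LLaVA repository layout, not part of this patch.

    from llava.mm_utils import get_model_name_from_path    # import paths assumed
    from llava.model.builder import load_pretrained_model  # from the LLaVA repo layout

    # Illustrative checkpoint name; substitute any LLaVA checkpoint you use.
    model_path = "liuhaotian/llava-v1.6-34b"
    model_name = get_model_name_from_path(model_path)

    # Arguments follow the call order shown in llava/serve/cli.py:
    # (model_path, model_base, model_name, load_8bit, load_4bit).
    # With this patch, the builder forwards its device_map to
    # vision_tower.load_model(), so the CLIP tower follows the same device
    # placement as the language model instead of always being moved to one device.
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, None, model_name, False, False
    )

    # Model names containing "mistral" now map to the "mistral_instruct"
    # conversation template, and "v1.6-34b" maps to "chatml_direct"
    # (see the run_llava.py and cli.py hunks above).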