Skip to content

Commit

Permalink
Merge pull request haotian-liu#851 from l-salewski/kwargs-for-model-l…
Browse files Browse the repository at this point in the history
…oading

🩹 make ``load_pretrained_model`` accept kwargs
  • Loading branch information
haotian-liu authored Nov 24, 2023
2 parents 414cebd + 8fa7f58 commit 2ca20de
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions llava/model/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"):
kwargs = {"device_map": device_map}
def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", **kwargs):
kwargs = {"device_map": device_map, **kwargs}

if device != "cuda":
kwargs['device_map'] = {"": device}
Expand Down

0 comments on commit 2ca20de

Please sign in to comment.