Skip to content

Commit

Permalink
enabled auto torchdtype detecting
Browse files Browse the repository at this point in the history
Signed-off-by: ftgreat <[email protected]>
  • Loading branch information
ftgreat committed Oct 9, 2023
1 parent df9c0c2 commit 3c0cb59
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion flagai/auto_model/auto_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,13 @@ def __init__(self,
if task_name == "aquila2":
from flagai.model.aquila2.modeling_aquila import AquilaForCausalLM
download_path = os.path.join(model_dir, model_name)


if not torch_dtype:
if model_name.lower() == "aquilachat2-34b":
torch_dtype = torch.bfloat16
else:
torch_dtype = torch.float16

if not os.path.exists(download_path):
# Try to download from ModelHub
try:
Expand Down

0 comments on commit 3c0cb59

Please sign in to comment.