Skip to content

Commit

Permalink
Add warning when inference with mps and bf16 on Mac
Browse files Browse the repository at this point in the history
  • Loading branch information
iceflame89 committed Jun 4, 2024
1 parent dfbc321 commit 1c5fadf
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion web_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@
device = args.device
assert device in ['cuda', 'mps']
if args.dtype == 'bf16':
dtype = torch.bfloat16
if device == 'mps':
print('Warning: MPS does not support bf16, will use fp16 instead')
dtype = torch.float16
else:
dtype = torch.bfloat16
else:
dtype = torch.float16

Expand Down

0 comments on commit 1c5fadf

Please sign in to comment.