diff --git a/web_demo.py b/web_demo.py index a9a0727..668dcf4 100644 --- a/web_demo.py +++ b/web_demo.py @@ -26,7 +26,11 @@ device = args.device assert device in ['cuda', 'mps'] if args.dtype == 'bf16': - dtype = torch.bfloat16 + if device == 'mps': + print('Warning: MPS does not support bf16, will use fp16 instead') + dtype = torch.float16 + else: + dtype = torch.bfloat16 else: dtype = torch.float16