Support fp32 gradaccum for bf16 model (microsoft#2566)
* allow bf16 model with fp32 gradient accumulation datatype

* allow fp32 gradient accumulation and bfloat16 model in amp mode

* alternative fix for the grad accumulation type mismatch: with the ZeRO optimizer, the grad accum dtype should equal the model dtype

Co-authored-by: Olatunji Ruwase <[email protected]>
delock and tjruwase authored Dec 6, 2022
1 parent 2d8f3f5 commit 0693883
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion deepspeed/runtime/engine.py
@@ -807,7 +807,7 @@ def get_data_types(self):
             model_dtype = torch.bfloat16

         if self._config.grad_accum_dtype == None:
-            if model_dtype == torch.bfloat16:
+            if model_dtype == torch.bfloat16 and not self.zero_optimization():
                 grad_accum_dtype = torch.float32
             else:
                 grad_accum_dtype = model_dtype
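Below is a minimal, self-contained sketch of the dtype-selection behavior after this change. The helper name select_grad_accum_dtype and its signature are illustrative, not part of DeepSpeed's API; in the actual source this logic lives inside DeepSpeedEngine.get_data_types in deepspeed/runtime/engine.py.

import torch

# Illustrative helper (hypothetical name/signature); it mirrors the logic of
# the patched get_data_types shown in the diff above.
def select_grad_accum_dtype(model_dtype, configured_dtype, zero_enabled):
    if configured_dtype is None:
        # Default: accumulate bf16 gradients in fp32 for extra precision,
        # except under ZeRO, where the accumulation dtype must match the
        # model dtype (the mismatch this commit fixes).
        if model_dtype == torch.bfloat16 and not zero_enabled:
            return torch.float32
        return model_dtype
    return configured_dtype

# Example: bf16 model with no grad_accum_dtype configured explicitly.
assert select_grad_accum_dtype(torch.bfloat16, None, zero_enabled=False) == torch.float32
assert select_grad_accum_dtype(torch.bfloat16, None, zero_enabled=True) == torch.bfloat16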
