
Commit

Use torch.cuda.synchronize() right after calling batch_isend_irecv() communication API
deepakn94 authored and jaredcasper committed Feb 6, 2021
1 parent be473a5 commit 1b3dfa2
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions megatron/training.py
@@ -347,6 +347,8 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
     reqs = torch.distributed.batch_isend_irecv(ops)
     for req in reqs:
         req.wait()
+    # Temporary workaround for batch_isend_irecv() race condition.
+    torch.cuda.synchronize()

     return tensor_recv_prev, tensor_recv_next
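The pattern this commit patches can be sketched in isolation. The helper below is a hypothetical illustration, not the Megatron source: it batches a pipeline-parallel send/receive pair with `torch.distributed.P2POp`, waits on the returned requests, and then calls `torch.cuda.synchronize()` as the workaround requires, since `req.wait()` alone does not guarantee the NCCL communication kernels have completed on the device. The function and parameter names (`exchange`, `next_rank`, `prev_rank`) are assumptions for this sketch, and it presumes `torch.distributed` has already been initialized when tensors are passed in.

```python
import torch
import torch.distributed as dist

def exchange(tensor_send_next, tensor_recv_prev, next_rank, prev_rank):
    """Sketch of the patched communicate() pattern: batch a send to the
    next pipeline rank with a receive from the previous one, then apply
    the commit's synchronize() workaround."""
    ops = []
    if tensor_send_next is not None:
        ops.append(dist.P2POp(dist.isend, tensor_send_next, next_rank))
    if tensor_recv_prev is not None:
        ops.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank))
    if ops:
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
        # Temporary workaround for batch_isend_irecv() race condition:
        # waiting on the requests does not ensure the device-side work
        # has finished, so force a full device synchronization.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
    return tensor_recv_prev

# With no tensors, no collective is issued, so this is safe to call
# even before init_process_group().
exchange(None, None, next_rank=1, prev_rank=0)
```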

