Avoid NCCL collective coalescing in distopt when not needed (#1847)
Signed-off-by: Tim Moon <[email protected]>
timmoon10 authored Oct 17, 2024
1 parent e13011f commit 6102d2c
Showing 1 changed file with 3 additions and 3 deletions.
apex/contrib/optimizers/distributed_fused_adam.py: 3 additions & 3 deletions
@@ -1966,7 +1966,7 @@ def _start_bucket_grad_sync(self, buckets: List[GradientBucket]) -> None:
         comm_stream.wait_stream(main_stream)

         # Reduce-scatter over distributed process group
-        if self.distributed_size > 1:
+        if buckets and self.distributed_size > 1:
             with torch.cuda.stream(comm_stream):
                 group = self.distributed_process_group
                 with _coalescing_manager(group, self.device, async_ops=True) as cm:
@@ -1986,7 +1986,7 @@ def _start_bucket_grad_sync(self, buckets: List[GradientBucket]) -> None:
                 cm.wait()

         # All-reduce over redundant process group
-        if self.redundant_size > 1:
+        if buckets and self.redundant_size > 1:
             with torch.cuda.stream(comm_stream):
                 group = self.redundant_process_group
                 with _coalescing_manager(group, self.device, async_ops=True) as cm:
@@ -2117,7 +2117,7 @@ def _start_bucket_param_sync(self, buckets: List[ParameterBucket]) -> None:
         comm_stream.wait_stream(main_stream)

         # All-gather over distributed process group
-        if self.distributed_size > 1:
+        if buckets and self.distributed_size > 1:
             with torch.cuda.stream(comm_stream):
                 group = self.distributed_process_group
                 with _coalescing_manager(group, self.device, async_ops=True) as cm:
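For context, a minimal sketch of the guard pattern this commit applies is shown below. It is an illustration only, not the apex implementation: sync_grad_buckets and the buckets list of (gradients, gradient-shard) tensor pairs are hypothetical stand-ins, and it assumes an initialized NCCL process group plus a PyTorch build that exposes torch.distributed.distributed_c10d._coalescing_manager (already used in the patched file) and torch.distributed.reduce_scatter_tensor. The point, as the commit title suggests, is to skip the coalesced NCCL launch entirely when a sync is requested with an empty bucket list.

import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _coalescing_manager

def sync_grad_buckets(buckets, group, device):
    # Hypothetical helper, not part of apex. `buckets` holds
    # (grads, grad_shard) CUDA tensor pairs to reduce-scatter.
    # Skip the coalescing manager entirely when there is nothing
    # to communicate, mirroring the `buckets and ...` guard above.
    if buckets and dist.get_world_size(group) > 1:
        with _coalescing_manager(group, device, async_ops=True) as cm:
            for grads, grad_shard in buckets:
                # Each collective issued inside the manager is batched
                # into a single coalesced NCCL launch.
                dist.reduce_scatter_tensor(grad_shard, grads, group=group)
        cm.wait()

The same guard is applied to the redundant-group all-reduce and to the parameter all-gather in the hunks above.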
