From 6102d2c300fd26f4c312d5136a7b85296286c7a3 Mon Sep 17 00:00:00 2001
From: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Date: Wed, 16 Oct 2024 21:06:27 -0700
Subject: [PATCH] Avoid NCCL collective coalescing in distopt when not needed
 (#1847)

Signed-off-by: Tim Moon
---
 apex/contrib/optimizers/distributed_fused_adam.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py
index ae3590d8..314f54e0 100644
--- a/apex/contrib/optimizers/distributed_fused_adam.py
+++ b/apex/contrib/optimizers/distributed_fused_adam.py
@@ -1966,7 +1966,7 @@ def _start_bucket_grad_sync(self, buckets: List[GradientBucket]) -> None:
         comm_stream.wait_stream(main_stream)
 
         # Reduce-scatter over distributed process group
-        if self.distributed_size > 1:
+        if buckets and self.distributed_size > 1:
             with torch.cuda.stream(comm_stream):
                 group = self.distributed_process_group
                 with _coalescing_manager(group, self.device, async_ops=True) as cm:
@@ -1986,7 +1986,7 @@ def _start_bucket_grad_sync(self, buckets: List[GradientBucket]) -> None:
                 cm.wait()
 
         # All-reduce over redundant process group
-        if self.redundant_size > 1:
+        if buckets and self.redundant_size > 1:
             with torch.cuda.stream(comm_stream):
                 group = self.redundant_process_group
                 with _coalescing_manager(group, self.device, async_ops=True) as cm:
@@ -2117,7 +2117,7 @@ def _start_bucket_param_sync(self, buckets: List[ParameterBucket]) -> None:
         comm_stream.wait_stream(main_stream)
 
         # All-gather over distributed process group
-        if self.distributed_size > 1:
+        if buckets and self.distributed_size > 1:
             with torch.cuda.stream(comm_stream):
                 group = self.distributed_process_group
                 with _coalescing_manager(group, self.device, async_ops=True) as cm:
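
Note on the change: each hunk adds a `buckets and` guard so the coalescing context is only entered when there is at least one bucket to communicate; previously an empty bucket list still entered `_coalescing_manager` and issued an empty coalesced NCCL collective. The sketch below is a minimal, self-contained illustration of that control flow only. The `coalescing_manager` and `start_bucket_grad_sync` names here are hypothetical stand-ins (not the real apex/PyTorch APIs), used solely to show why the guard short-circuits the empty case.

    # Hypothetical sketch of the guard added in this patch (not the real apex code).
    from contextlib import contextmanager
    from typing import List


    @contextmanager
    def coalescing_manager(tag: str):
        # Stand-in for a NCCL collective coalescing context.
        print(f"enter coalescing ({tag})")
        yield
        print(f"exit coalescing ({tag})")


    def start_bucket_grad_sync(buckets: List[str], distributed_size: int) -> None:
        # Before the patch: `if distributed_size > 1:` alone, so an empty bucket
        # list still entered the coalescing context and launched an empty
        # coalesced collective. After the patch: skip it when there is no work.
        if buckets and distributed_size > 1:
            with coalescing_manager("reduce-scatter"):
                for bucket in buckets:
                    print(f"  reduce-scatter {bucket}")


    if __name__ == "__main__":
        start_bucket_grad_sync([], distributed_size=8)        # prints nothing: coalescing skipped
        start_bucket_grad_sync(["bucket0", "bucket1"], 8)     # coalesced reduce-scatters launched

The same guard is applied to the redundant-group all-reduce and the parameter all-gather paths in the patch above.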