From 5fffdfc737f14297bc3781dfc9e273199d1df52e Mon Sep 17 00:00:00 2001
From: Tuomas Rintamaki
Date: Mon, 29 Apr 2024 16:37:48 -0700
Subject: [PATCH] LinearWithFrozenWeight backward fix when TP > 1

---
 megatron/core/model_parallel_config.py        |  9 +-
 megatron/core/tensor_parallel/layers.py       | 92 +++++++++++++------
 megatron/legacy/model/language_model.py       | 17 ++--
 megatron/training/arguments.py                |  6 +-
 tests/unit_tests/tensor_parallel/__init__.py  |  0
 .../unit_tests/tensor_parallel/test_layers.py | 52 +++++++++++
 6 files changed, 128 insertions(+), 48 deletions(-)
 create mode 100644 tests/unit_tests/tensor_parallel/__init__.py
 create mode 100644 tests/unit_tests/tensor_parallel/test_layers.py

diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py
index ac06c76b56..d4312b9fdf 100644
--- a/megatron/core/model_parallel_config.py
+++ b/megatron/core/model_parallel_config.py
@@ -126,9 +126,7 @@ class ModelParallelConfig:
     """

     async_tensor_model_parallel_allreduce: bool = False
-    """If true, enables asynchronous execution of tensor-model-parallel all-reduce with weight
-    gradient compuation of a column-linear layer.
-    """
+    """NOTE: Deprecated. This flag is ignored."""

     use_te_rng_tracker: bool = False
     """If true, uses RNG state tracker in TransformerEngine if exists.
@@ -227,7 +225,7 @@ class ModelParallelConfig:
     """

     defer_embedding_wgrad_compute: bool = False
-    """If true, defers the embedding WGRAD GEMMs while pipeline flush is
+    """If true, defers the embedding WGRAD GEMMs while pipeline flush is taking place
     enabling us to hide pipeline flush latency. Defaults to False.
     """

@@ -270,9 +268,6 @@ def __post_init__(self):
         if self.sequence_parallel:
             if self.tensor_model_parallel_size <= 1:
                 raise ValueError("Can not use sequence paralllelism without tensor parallelism")
-            if self.async_tensor_model_parallel_allreduce:
-                # sequence_parallelism already does this async
-                self.async_tensor_model_parallel_allreduce = False

         if self.pipeline_model_parallel_size > 1:
             if self.pipeline_dtype is None:
diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py
index 177efc30b5..727af87564 100644
--- a/megatron/core/tensor_parallel/layers.py
+++ b/megatron/core/tensor_parallel/layers.py
@@ -258,9 +258,10 @@ class LinearWithFrozenWeight(torch.autograd.Function):
     @staticmethod
     @custom_fwd
     def forward(
-        ctx, input, weight, bias,
+        ctx, input, weight, bias, allreduce_dgrad,
     ):
         ctx.save_for_backward(weight)
+        ctx.allreduce_dgrad = allreduce_dgrad
         output = torch.matmul(input, weight.t())
         if bias is not None:
             output = output + bias
@@ -271,7 +272,12 @@ def forward(
     def backward(ctx, grad_output):
         (weight,) = ctx.saved_tensors
         grad_input = grad_output.matmul(weight)
-        return grad_input, None, None
+
+        if ctx.allreduce_dgrad:
+            # All-reduce. Note: here async and sync are effectively the same.
+            torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group())
+
+        return grad_input, None, None, None


 def linear_with_frozen_weight(
@@ -282,6 +288,7 @@ def linear_with_frozen_weight(
     async_grad_allreduce: bool,
     sequence_parallel: bool,
     grad_output_buffer: Optional[List[torch.Tensor]] = None,
+    allreduce_dgrad: bool = None,
 ) -> torch.Tensor:
     """Linear layer execution with weight.requires_grad == False.

@@ -312,6 +319,10 @@ def linear_with_frozen_weight(
     grad_output_buffer (List[torch.Tensor] optional): dummy argument, used to
         keep the API unified between all forward implementation functions.

+    allreduce_dgrad (bool): Do the allreduce of input gradients.
+        Here, async and sync allreduce are the same. If sequence_parallel is
+        True, this must be False, as no all reduce is performed.
+
     """

     assert grad_output_buffer is None, (
@@ -324,10 +335,17 @@ def linear_with_frozen_weight(
     else:
         input = input

+    if allreduce_dgrad is None:
+        warnings.warn(
+            "async_grad_allreduce is deprecated and will be removed in a future release. use allreduce_dgrad instead."
+        )
+        allreduce_dgrad = async_grad_allreduce
+
     args = [
         input,
         weight,
         bias,
+        allreduce_dgrad,
     ]

     return LinearWithFrozenWeight.apply(*args)
@@ -344,14 +362,14 @@ def forward(
         weight,
         bias,
         gradient_accumulation_fusion,
-        async_grad_allreduce,
+        allreduce_dgrad,
         sequence_parallel,
         grad_output_buffer,
     ):
         ctx.save_for_backward(input, weight)
         ctx.use_bias = bias is not None
         ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
-        ctx.async_grad_allreduce = async_grad_allreduce
+        ctx.allreduce_dgrad = allreduce_dgrad
         ctx.sequence_parallel = sequence_parallel
         ctx.grad_output_buffer = grad_output_buffer

@@ -413,7 +431,7 @@ def backward(ctx, grad_output):
                 grad_output, total_input
             )

-        if ctx.async_grad_allreduce:
+        if ctx.allreduce_dgrad:
             # Asynchronous all-reduce
             handle = torch.distributed.all_reduce(
                 grad_input, group=get_tensor_model_parallel_group(), async_op=True
@@ -422,7 +440,7 @@ def backward(ctx, grad_output):
             # all-reduce is scheduled before the weight gradient computation

         if ctx.sequence_parallel:
-            assert not ctx.async_grad_allreduce
+            assert not ctx.allreduce_dgrad
             dim_size = list(input.size())
             sub_grad_input = torch.empty(
                 dim_size, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False
@@ -479,7 +497,7 @@ def backward(ctx, grad_output):
             # provided during forward
             return sub_grad_input, grad_weight, grad_bias, None, None, None, None

-        if ctx.async_grad_allreduce:
+        if ctx.allreduce_dgrad:
             handle.wait()

         return grad_input, grad_weight, grad_bias, None, None, None, None
@@ -493,6 +511,7 @@ def linear_with_grad_accumulation_and_async_allreduce(
     async_grad_allreduce: bool,
     sequence_parallel: bool,
     grad_output_buffer: Optional[List[torch.Tensor]] = None,
+    allreduce_dgrad: bool = None,
 ) -> torch.Tensor:
     """Linear layer execution with asynchronous communication and
     gradient accumulation fusion in backprop.
@@ -520,7 +539,6 @@ def linear_with_grad_accumulation_and_async_allreduce(
     in the order they are called.

     Args:
-
     input (torch.Tensor required): input like torch.nn.functional.linear

     weight (torch.Tensor required): weight like torch.nn.functional.linear
@@ -536,26 +554,39 @@ def linear_with_grad_accumulation_and_async_allreduce(
         " Note that the extension requires CUDA>=11. Otherwise, you must turn off gradient accumulation fusion."

+
     async_grad_allreduce (bool required): Do the allreduce of input gradients
         asyncronously with the computation of weight gradients. If sequence_parallel is
         True, this must be False, as no all reduce is performed.

-    sequence_parallel (bool required): Indicates that sequence
-        parallelism is used and thus in the forward pass the input is
-        all gathered, and the backward pass the input gradients are
-        reduce scattered.
-    grad_output_buffer (List[torch.Tensor] optional): Buffer used to save
-        output gradients when embedding table wgrad compute is deferred.
-        Defaults to None.
+    sequence_parallel (bool required): Indicates that sequence
+        parallelism is used and thus in the forward pass the input is
+        all gathered, and the backward pass the input gradients are
+        reduce scattered.
+
+    grad_output_buffer (List[torch.Tensor] optional): Buffer used to save
+        output gradients when embedding table wgrad compute is deferred.
+        Defaults to None.
+
+    allreduce_dgrad (bool): Do the allreduce of input gradients.
+        The allreduce is done asynchronously with the computation of weight
+        gradients. If sequence_parallel is True, this must be
+        False, as no all reduce is performed.
     """

+    if allreduce_dgrad is None:
+        warnings.warn(
+            "async_grad_allreduce is deprecated and will be removed in a future release. use allreduce_dgrad instead."
+        )
+        allreduce_dgrad = async_grad_allreduce
+
     args = [
         input,
         weight,
         bias,
         gradient_accumulation_fusion,
-        async_grad_allreduce,
+        allreduce_dgrad,
         sequence_parallel,
         grad_output_buffer,
     ]

@@ -570,7 +601,7 @@ def linear_with_grad_accumulation_and_async_allreduce(
                 )
                 linear_with_grad_accumulation_and_async_allreduce.warned = True

-            if async_grad_allreduce:
+            if allreduce_dgrad:
                 warnings.warn(
                     "When using async grad allreduce it is recommended to set the "
                     "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
@@ -710,10 +741,6 @@ def __init__(
         else:
             self.register_parameter('bias', None)

-        self.async_tensor_model_parallel_allreduce = (
-            config.async_tensor_model_parallel_allreduce and world_size > 1
-        )
-
         self.sequence_parallel = config.sequence_parallel
         if self.sequence_parallel and world_size <= 1:
             warnings.warn(
             )
             self.sequence_parallel = False

+        self.allreduce_dgrad = world_size > 1 and not self.sequence_parallel
+
         if config.gradient_accumulation_fusion and not _grad_accum_fusion_available:
             raise RuntimeError(
                 "ColumnParallelLinear was called with gradient_accumulation_fusion set "
@@ -734,10 +763,9 @@ def __init__(
             )
         self.gradient_accumulation_fusion = config.gradient_accumulation_fusion

-        if self.async_tensor_model_parallel_allreduce and self.sequence_parallel:
+        if self.allreduce_dgrad and self.sequence_parallel:
             raise RuntimeError(
-                "`async_tensor_model_parallel_allreduce` and `sequence_parallel` "
-                "cannot be enabled at the same time."
+                "`allreduce_dgrad` and `sequence_parallel` cannot be enabled at the same time."
             )

         self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
@@ -791,7 +819,7 @@ def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None):
         bias = self.bias if not self.skip_bias_add else None

         if (
-            self.async_tensor_model_parallel_allreduce
+            self.allreduce_dgrad
             or self.sequence_parallel
             or self.explicit_expert_comm
             or self.disable_grad_reduce
@@ -809,18 +837,19 @@ def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None):
         else:
             self._forward_impl = linear_with_grad_accumulation_and_async_allreduce

+        allreduce_dgrad = False if self.explicit_expert_comm else self.allreduce_dgrad
+
         output_parallel = self._forward_impl(
             input=input_parallel,
             weight=weight,
             bias=bias,
             gradient_accumulation_fusion=self.gradient_accumulation_fusion,
-            async_grad_allreduce=False
-            if self.explicit_expert_comm
-            else self.async_tensor_model_parallel_allreduce,
+            async_grad_allreduce=allreduce_dgrad,
             sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
             grad_output_buffer=self.grad_output_buffer
             if self.config.defer_embedding_wgrad_compute
             else None,
+            allreduce_dgrad=allreduce_dgrad,
         )

         if self.gather_output:
             # All-gather across the partitions.
@@ -1002,13 +1031,18 @@ def forward(self, input_):
             self._forward_impl = linear_with_frozen_weight
         else:
             self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
+
+        allreduce_dgrad = False
+
         output_parallel = self._forward_impl(
             input=input_parallel,
             weight=self.weight,
             bias=None,
             gradient_accumulation_fusion=self.gradient_accumulation_fusion,
-            async_grad_allreduce=False,
+            async_grad_allreduce=allreduce_dgrad,
             sequence_parallel=False,
+            grad_output_buffer=None,
+            allreduce_dgrad=allreduce_dgrad,
         )

         # All-reduce across all the partitions.
diff --git a/megatron/legacy/model/language_model.py b/megatron/legacy/model/language_model.py
index 4fb5ae0dd5..1beb5f9e87 100644
--- a/megatron/legacy/model/language_model.py
+++ b/megatron/legacy/model/language_model.py
@@ -22,15 +22,13 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
     """LM logits using word embedding weights."""
     args = get_args()
     # Parallel logits.
-    if args.async_tensor_model_parallel_allreduce or\
-       args.sequence_parallel:
+    model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
+    if model_parallel or args.sequence_parallel:
         input_parallel = input_
-        model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
-        async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
-            model_parallel and not args.sequence_parallel
+        allreduce_dgrad = model_parallel and not args.sequence_parallel
     else:
         input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
-        async_grad_allreduce = False
+        allreduce_dgrad = False

     # Matrix multiply.
     logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
@@ -38,8 +36,11 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
         weight=word_embeddings_weight,
         bias=bias,
         gradient_accumulation_fusion=args.gradient_accumulation_fusion,
-        async_grad_allreduce=async_grad_allreduce,
-        sequence_parallel=args.sequence_parallel)
+        async_grad_allreduce=allreduce_dgrad,
+        sequence_parallel=args.sequence_parallel,
+        grad_output_buffer=None,
+        allreduce_dgrad=allreduce_dgrad,
+    )

     # Gather if needed.
     if parallel_output:
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index dbbae053bc..c6206496f7 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -982,7 +982,7 @@ def _add_training_args(parser):
                        ' overlap of Tensor parallel communication and GEMM kernels.')
     group.add_argument('--tp-comm-overlap-cfg', type=str, default=None,
                        help='Config file when tp_comm_overlap is enabled.')
-    group.add_argument('--disable-tp-comm-overlap-ag', action='store_false', 
+    group.add_argument('--disable-tp-comm-overlap-ag', action='store_false',
                        help=('Disables the All-Gather overlap with GEMM by '
                              'pipelining the GEMM and All-Gather.'),
                        dest='tp_comm_overlap_ag')
@@ -1070,9 +1070,7 @@ def _add_training_args(parser):
                        help='Single pass vs multiple pass data loader')
     group.add_argument('--no-async-tensor-model-parallel-allreduce',
                        action='store_false',
-                       help='Disable asynchronous execution of '
-                       'tensor-model-parallel all-reduce with weight '
-                       'gradient compuation of a column-linear layer.',
+                       help='DEPRECATED. This flag is ignored.',
                        dest='async_tensor_model_parallel_allreduce')
     group.add_argument('--no-persist-layer-norm', action='store_true',
                        help='Disable using persistent fused layer norm kernel. '
diff --git a/tests/unit_tests/tensor_parallel/__init__.py b/tests/unit_tests/tensor_parallel/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/unit_tests/tensor_parallel/test_layers.py b/tests/unit_tests/tensor_parallel/test_layers.py
new file mode 100644
index 0000000000..4ed6b16fa3
--- /dev/null
+++ b/tests/unit_tests/tensor_parallel/test_layers.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+import pytest
+import torch
+
+from megatron.core.tensor_parallel.layers import linear_with_frozen_weight
+from megatron.core.tensor_parallel.mappings import gather_from_tensor_model_parallel_region
+from tests.unit_tests.test_utilities import Utils
+
+
+@pytest.mark.parametrize("tensor_parallel,allreduce_dgrad", [(1, False), (8, True)])
+def test_LinearWithFrozenWeight(tensor_parallel, allreduce_dgrad):
+    Utils.initialize_model_parallel(tensor_parallel, 1)
+
+    size_per_partition = int(8 / tensor_parallel)
+
+    # Input is an 8x8 identity matrix.
+    input_data = torch.eye(8).cuda()
+    input_data.requires_grad = True
+
+    # Weight is an 8x8 matrix of all ones. If tensor parallelism > 1, the weight is partitioned evenly across GPUs.
+    weight = torch.ones((size_per_partition, 8)).cuda()
+
+    # Bias is a vector of length 8 of all zeros. If tensor parallelism > 1, the bias is partitioned evenly across GPUs
+    bias = torch.zeros((size_per_partition)).cuda()
+
+    gradient_accumulation_fusion = False
+    async_grad_allreduce = allreduce_dgrad
+    sequence_parallel = False
+    grad_output_buffer = None
+
+    output_parallel = linear_with_frozen_weight(
+        input_data,
+        weight,
+        bias,
+        gradient_accumulation_fusion,
+        async_grad_allreduce,
+        sequence_parallel,
+        grad_output_buffer,
+        allreduce_dgrad,
+    )
+    output = gather_from_tensor_model_parallel_region(
+        output_parallel
+    )  # no-op if tensor_parallel == 1.
+    output.sum().backward()
+
+    expected_output = torch.ones(8).cuda()
+    expected_grad = 8 * torch.ones(8).cuda()
+
+    assert torch.allclose(output, expected_output)
+    assert torch.allclose(input_data.grad, expected_grad)
+
+    Utils.destroy_model_parallel()
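A note on the expected values in test_LinearWithFrozenWeight, since they are also a compact statement of the bug being fixed: with an 8x8 identity input and an all-ones 8x8 weight, the unpartitioned product is an all-ones matrix, and the gradient of output.sum() with respect to the input is 8 * ones. A single TP=8 rank only contributes 1 * ones of that gradient from its 1x8 weight shard, so without the new all-reduce in LinearWithFrozenWeight.backward the expected_grad assertion cannot pass. The following is a plain single-process PyTorch check of the same arithmetic (no Megatron or distributed setup needed); it is illustrative only and not part of the patch.

# Single-process check of the numbers asserted in test_LinearWithFrozenWeight.
import torch

inp = torch.eye(8, requires_grad=True)   # same input as the test
weight = torch.ones(8, 8)                # frozen weight: requires_grad stays False
bias = torch.zeros(8)

output = torch.matmul(inp, weight.t()) + bias   # identity @ ones^T -> all ones
output.sum().backward()

assert torch.allclose(output, torch.ones(8, 8))
# The full dgrad is 8 * ones; one TP=8 rank alone would only see 1 * ones, which
# is why LinearWithFrozenWeight.backward now all-reduces grad_input over the TP group.
assert torch.allclose(inp.grad, 8 * torch.ones(8, 8))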
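For reviewers who want the intuition without reading the Megatron internals: a column-parallel linear copies the full activation to every tensor-parallel rank in the forward pass, so the matching backward must sum (all-reduce) the per-rank input gradients. LinearWithGradAccumulationAndAsyncCommunication already did this; LinearWithFrozenWeight returned partial dgrads when TP > 1. The sketch below is illustrative code, not Megatron's implementation: the names are made up, it uses the default process group instead of get_tensor_model_parallel_group(), and it assumes it is launched under torchrun (which supplies the rendezvous environment) with a world size that divides 4.

# Sketch: identity forward / all-reduce backward, the pattern the patch adds to
# LinearWithFrozenWeight. Run with e.g. `torchrun --nproc_per_node=2 sketch.py`.
import torch
import torch.distributed as dist


class CopyToTPRegionSketch(torch.autograd.Function):
    """Identity in the forward pass; all-reduce of the gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        # Each rank only sees the contribution of its own weight shard, so the
        # true input gradient is the sum over the tensor-parallel group.
        dist.all_reduce(grad_input)
        return grad_input


def frozen_linear_sketch(x, weight_shard):
    # weight_shard.requires_grad is False, so only dgrad flows back through x.
    return torch.matmul(CopyToTPRegionSketch.apply(x), weight_shard.t())


if __name__ == "__main__":
    dist.init_process_group("gloo")
    world = dist.get_world_size()
    x = torch.eye(4, requires_grad=True)
    weight_shard = torch.ones(4 // world, 4)  # this rank's slice of the frozen weight
    frozen_linear_sketch(x, weight_shard).sum().backward()
    # Without the all-reduce each rank would see only (4 // world) * ones here.
    assert torch.allclose(x.grad, 4 * torch.ones(4, 4))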
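Call-site migration is additive: callers that still pass only async_grad_allreduce keep working but now hit the deprecation warning (allreduce_dgrad stays None and falls back to the async_grad_allreduce value), while updated call sites pass allreduce_dgrad explicitly, as language_model.py and the new unit test do. A rough sketch of the two call patterns for linear_with_frozen_weight, assuming Megatron-Core's model-parallel state is already initialized (as in the unit test above) and a GPU is available; the tensor shapes are made up.

# Call-pattern sketch for the updated linear_with_frozen_weight signature.
import torch
from megatron.core.tensor_parallel.layers import linear_with_frozen_weight

inp = torch.randn(4, 8).cuda()
frozen_weight = torch.ones(2, 8).cuda()  # this rank's shard of a frozen weight

# Old style: still accepted, but allreduce_dgrad defaults to None, so the call
# emits the "async_grad_allreduce is deprecated" warning and falls back to the
# async_grad_allreduce value (True here).
out = linear_with_frozen_weight(inp, frozen_weight, None, False, True, False, None)

# New style: set allreduce_dgrad explicitly (True whenever TP > 1 and
# sequence_parallel is off); async_grad_allreduce is then ignored.
out = linear_with_frozen_weight(
    input=inp,
    weight=frozen_weight,
    bias=None,
    gradient_accumulation_fusion=False,
    async_grad_allreduce=True,  # retained only for backward compatibility
    sequence_parallel=False,
    grad_output_buffer=None,
    allreduce_dgrad=True,
)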