Merge branch 'trintamaki/linear-tp' into 'main'
LinearWithFrozenWeight backward fix when TP > 1

See merge request ADLR/megatron-lm!1279
jaredcasper committed Apr 29, 2024
2 parents 0264800 + 5fffdfc commit 20674cc
Showing 6 changed files with 128 additions and 48 deletions.
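
For orientation: the core of the fix is in LinearWithFrozenWeight.backward. When a tensor-parallel linear runs with a frozen weight (weight.requires_grad == False) and TP > 1, each rank computes only a partial input gradient, so the backward pass now all-reduces it over the tensor-model-parallel group whenever allreduce_dgrad is set. The snippet below is a minimal, self-contained sketch of that pattern, not the committed code: it passes a generic tp_group explicitly instead of calling Megatron's get_tensor_model_parallel_group(), and omits the custom_fwd/custom_bwd AMP decorators.

import torch


class FrozenWeightLinearSketch(torch.autograd.Function):
    """Linear with a non-trainable weight and an optional dgrad all-reduce."""

    @staticmethod
    def forward(ctx, input, weight, bias, allreduce_dgrad, tp_group):
        ctx.save_for_backward(weight)
        ctx.allreduce_dgrad = allreduce_dgrad
        ctx.tp_group = tp_group
        output = torch.matmul(input, weight.t())
        return output if bias is None else output + bias

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        grad_input = grad_output.matmul(weight)
        if ctx.allreduce_dgrad:
            # Each tensor-parallel rank holds only a shard of the weight, so
            # grad_input is a partial sum; reduce it across the TP group.
            torch.distributed.all_reduce(grad_input, group=ctx.tp_group)
        # One gradient slot per forward argument; non-tensor args get None.
        return grad_input, None, None, None, None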
9 changes: 2 additions & 7 deletions megatron/core/model_parallel_config.py
@@ -126,9 +126,7 @@ class ModelParallelConfig:
"""

async_tensor_model_parallel_allreduce: bool = False
"""If true, enables asynchronous execution of tensor-model-parallel all-reduce with weight
gradient compuation of a column-linear layer.
"""
"""NOTE: Deprecated. This flag is ignored."""

use_te_rng_tracker: bool = False
"""If true, uses RNG state tracker in TransformerEngine if exists.
@@ -227,7 +225,7 @@ class ModelParallelConfig:
"""

defer_embedding_wgrad_compute: bool = False
"""If true, defers the embedding WGRAD GEMMs while pipeline flush is
"""If true, defers the embedding WGRAD GEMMs while pipeline flush is
taking place enabling us to hide pipeline flush latency. Defaults to False.
"""

@@ -270,9 +268,6 @@ def __post_init__(self):
if self.sequence_parallel:
if self.tensor_model_parallel_size <= 1:
raise ValueError("Can not use sequence paralllelism without tensor parallelism")
if self.async_tensor_model_parallel_allreduce:
# sequence_parallelism already does this async
self.async_tensor_model_parallel_allreduce = False

if self.pipeline_model_parallel_size > 1:
if self.pipeline_dtype is None:
92 changes: 63 additions & 29 deletions megatron/core/tensor_parallel/layers.py
@@ -258,9 +258,10 @@ class LinearWithFrozenWeight(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(
ctx, input, weight, bias,
ctx, input, weight, bias, allreduce_dgrad,
):
ctx.save_for_backward(weight)
ctx.allreduce_dgrad = allreduce_dgrad
output = torch.matmul(input, weight.t())
if bias is not None:
output = output + bias
@@ -271,7 +272,12 @@ def forward(
def backward(ctx, grad_output):
(weight,) = ctx.saved_tensors
grad_input = grad_output.matmul(weight)
return grad_input, None, None

if ctx.allreduce_dgrad:
# All-reduce. Note: here async and sync are effectively the same.
torch.distributed.all_reduce(grad_input, group=get_tensor_model_parallel_group())

return grad_input, None, None, None


def linear_with_frozen_weight(
@@ -282,6 +288,7 @@ def linear_with_frozen_weight(
async_grad_allreduce: bool,
sequence_parallel: bool,
grad_output_buffer: Optional[List[torch.Tensor]] = None,
allreduce_dgrad: bool = None,
) -> torch.Tensor:
"""Linear layer execution with weight.requires_grad == False.
@@ -312,6 +319,10 @@ def linear_with_frozen_weight(
grad_output_buffer (List[torch.Tensor] optional): dummy argument, used to
keep the API unified between all forward implementation functions.
allreduce_dgrad (bool): Do the allreduce of input gradients.
Here, async and sync allreduce are the same. If sequence_parallel is
True, this must be False, as no all reduce is performed.
"""

assert grad_output_buffer is None, (
@@ -324,10 +335,17 @@ def linear_with_frozen_weight(
else:
input = input

if allreduce_dgrad is None:
warnings.warn(
"async_grad_allreduce is deprecated and will be removed in a future release. use allreduce_dgrad instead."
)
allreduce_dgrad = async_grad_allreduce

args = [
input,
weight,
bias,
allreduce_dgrad,
]

return LinearWithFrozenWeight.apply(*args)
@@ -344,14 +362,14 @@ def forward(
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
allreduce_dgrad,
sequence_parallel,
grad_output_buffer,
):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.async_grad_allreduce = async_grad_allreduce
ctx.allreduce_dgrad = allreduce_dgrad
ctx.sequence_parallel = sequence_parallel
ctx.grad_output_buffer = grad_output_buffer

@@ -413,7 +431,7 @@ def backward(ctx, grad_output):
grad_output, total_input
)

if ctx.async_grad_allreduce:
if ctx.allreduce_dgrad:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True
@@ -422,7 +440,7 @@ def backward(ctx, grad_output):
# all-reduce is scheduled before the weight gradient computation

if ctx.sequence_parallel:
assert not ctx.async_grad_allreduce
assert not ctx.allreduce_dgrad
dim_size = list(input.size())
sub_grad_input = torch.empty(
dim_size, dtype=input.dtype, device=torch.cuda.current_device(), requires_grad=False
@@ -479,7 +497,7 @@ def backward(ctx, grad_output):
# provided during forward
return sub_grad_input, grad_weight, grad_bias, None, None, None, None

if ctx.async_grad_allreduce:
if ctx.allreduce_dgrad:
handle.wait()

return grad_input, grad_weight, grad_bias, None, None, None, None
@@ -493,6 +511,7 @@ def linear_with_grad_accumulation_and_async_allreduce(
async_grad_allreduce: bool,
sequence_parallel: bool,
grad_output_buffer: Optional[List[torch.Tensor]] = None,
allreduce_dgrad: bool = None,
) -> torch.Tensor:
"""Linear layer execution with asynchronous communication and
gradient accumulation fusion in backprop.
@@ -520,7 +539,6 @@ def linear_with_grad_accumulation_and_async_allreduce(
in the order they are called.
Args:
input (torch.Tensor required): input like torch.nn.functional.linear
weight (torch.Tensor required): weight like torch.nn.functional.linear
@@ -536,26 +554,39 @@ def linear_with_grad_accumulation_and_async_allreduce(
" Note that the extension requires CUDA>=11. Otherwise, you
must turn off gradient accumulation fusion."
async_grad_allreduce (bool required): Do the allreduce of input
gradients asyncronously with the computation of weight
gradients. If sequence_parallel is True, this must be
False, as no all reduce is performed.
sequence_parallel (bool required): Indicates that sequence
parallelism is used and thus in the forward pass the input is
all gathered, and the backward pass the input gradients are
reduce scattered.
grad_output_buffer (List[torch.Tensor] optional): Buffer used to save
output gradients when embedding table wgrad compute is deferred.
Defaults to None.
sequence_parallel (bool required): Indicates that sequence
parallelism is used and thus in the forward pass the input is
all gathered, and the backward pass the input gradients are
reduce scattered.
grad_output_buffer (List[torch.Tensor] optional): Buffer used to save
output gradients when embedding table wgrad compute is deferred.
Defaults to None.
allreduce_dgrad (bool): Do the allreduce of input gradients.
The allreduce is done asynchronously with the computation of weight
gradients. If sequence_parallel is True, this must be
False, as no all reduce is performed.
"""
if allreduce_dgrad is None:
warnings.warn(
"async_grad_allreduce is deprecated and will be removed in a future release. use allreduce_dgrad instead."
)
allreduce_dgrad = async_grad_allreduce

args = [
input,
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
allreduce_dgrad,
sequence_parallel,
grad_output_buffer,
]
@@ -570,7 +601,7 @@ def linear_with_grad_accumulation_and_async_allreduce(
)
linear_with_grad_accumulation_and_async_allreduce.warned = True

if async_grad_allreduce:
if allreduce_dgrad:
warnings.warn(
"When using async grad allreduce it is recommended to set the "
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
@@ -710,10 +741,6 @@ def __init__(
else:
self.register_parameter('bias', None)

self.async_tensor_model_parallel_allreduce = (
config.async_tensor_model_parallel_allreduce and world_size > 1
)

self.sequence_parallel = config.sequence_parallel
if self.sequence_parallel and world_size <= 1:
warnings.warn(
@@ -722,6 +749,8 @@ def __init__(
)
self.sequence_parallel = False

self.allreduce_dgrad = world_size > 1 and not self.sequence_parallel

if config.gradient_accumulation_fusion and not _grad_accum_fusion_available:
raise RuntimeError(
"ColumnParallelLinear was called with gradient_accumulation_fusion set "
@@ -734,10 +763,9 @@ def __init__(
)
self.gradient_accumulation_fusion = config.gradient_accumulation_fusion

if self.async_tensor_model_parallel_allreduce and self.sequence_parallel:
if self.allreduce_dgrad and self.sequence_parallel:
raise RuntimeError(
"`async_tensor_model_parallel_allreduce` and `sequence_parallel` "
"cannot be enabled at the same time."
"`allreduce_dgrad` and `sequence_parallel` cannot be enabled at the same time."
)

self._forward_impl = linear_with_grad_accumulation_and_async_allreduce
@@ -791,7 +819,7 @@ def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None):
bias = self.bias if not self.skip_bias_add else None

if (
self.async_tensor_model_parallel_allreduce
self.allreduce_dgrad
or self.sequence_parallel
or self.explicit_expert_comm
or self.disable_grad_reduce
@@ -809,18 +837,19 @@ def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None):
else:
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce

allreduce_dgrad = False if self.explicit_expert_comm else self.allreduce_dgrad

output_parallel = self._forward_impl(
input=input_parallel,
weight=weight,
bias=bias,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=False
if self.explicit_expert_comm
else self.async_tensor_model_parallel_allreduce,
async_grad_allreduce=allreduce_dgrad,
sequence_parallel=False if self.explicit_expert_comm else self.sequence_parallel,
grad_output_buffer=self.grad_output_buffer
if self.config.defer_embedding_wgrad_compute
else None,
allreduce_dgrad=allreduce_dgrad,
)
if self.gather_output:
# All-gather across the partitions.
@@ -1002,13 +1031,18 @@ def forward(self, input_):
self._forward_impl = linear_with_frozen_weight
else:
self._forward_impl = linear_with_grad_accumulation_and_async_allreduce

allreduce_dgrad = False

output_parallel = self._forward_impl(
input=input_parallel,
weight=self.weight,
bias=None,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=False,
async_grad_allreduce=allreduce_dgrad,
sequence_parallel=False,
grad_output_buffer=None,
allreduce_dgrad=allreduce_dgrad,
)

# All-reduce across all the partitions.
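The layers.py changes also keep the old keyword working: linear_with_frozen_weight and linear_with_grad_accumulation_and_async_allreduce now accept allreduce_dgrad and, when it is left as None, fall back to the deprecated async_grad_allreduce value while emitting a warning. A minimal sketch of that shim follows; the helper name resolve_allreduce_dgrad is illustrative and not part of the commit.

import warnings


def resolve_allreduce_dgrad(async_grad_allreduce: bool, allreduce_dgrad=None) -> bool:
    """Prefer the new argument; fall back to the deprecated one with a warning."""
    if allreduce_dgrad is None:
        warnings.warn(
            "async_grad_allreduce is deprecated and will be removed in a future "
            "release. use allreduce_dgrad instead."
        )
        allreduce_dgrad = async_grad_allreduce
    return allreduce_dgrad
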
17 changes: 9 additions & 8 deletions megatron/legacy/model/language_model.py
@@ -22,24 +22,25 @@ def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
"""LM logits using word embedding weights."""
args = get_args()
# Parallel logits.
if args.async_tensor_model_parallel_allreduce or\
args.sequence_parallel:
model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
if model_parallel or args.sequence_parallel:
input_parallel = input_
model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
model_parallel and not args.sequence_parallel
allreduce_dgrad = model_parallel and not args.sequence_parallel
else:
input_parallel = tensor_parallel.copy_to_tensor_model_parallel_region(input_)
async_grad_allreduce = False
allreduce_dgrad = False

# Matrix multiply.
logits_parallel = tensor_parallel.linear_with_grad_accumulation_and_async_allreduce(
input=input_parallel,
weight=word_embeddings_weight,
bias=bias,
gradient_accumulation_fusion=args.gradient_accumulation_fusion,
async_grad_allreduce=async_grad_allreduce,
sequence_parallel=args.sequence_parallel)
async_grad_allreduce=allreduce_dgrad,
sequence_parallel=args.sequence_parallel,
grad_output_buffer=None,
allreduce_dgrad=allreduce_dgrad,
)
# Gather if needed.

if parallel_output:
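With the change above, parallel_lm_logits no longer consults args.async_tensor_model_parallel_allreduce; whether the dgrad all-reduce is requested follows from the tensor-model-parallel world size and the sequence-parallel setting alone. A small sketch of that decision, with tp_world_size and sequence_parallel standing in for the real mpu/args lookups:

def needs_allreduce_dgrad(tp_world_size: int, sequence_parallel: bool) -> bool:
    # Under sequence parallelism the input gradients are reduce-scattered in
    # the linear backward instead, so no separate dgrad all-reduce is needed.
    model_parallel = tp_world_size > 1
    return model_parallel and not sequence_parallel
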
6 changes: 2 additions & 4 deletions megatron/training/arguments.py
@@ -982,7 +982,7 @@ def _add_training_args(parser):
' overlap of Tensor parallel communication and GEMM kernels.')
group.add_argument('--tp-comm-overlap-cfg', type=str, default=None,
help='Config file when tp_comm_overlap is enabled.')
group.add_argument('--disable-tp-comm-overlap-ag', action='store_false',
group.add_argument('--disable-tp-comm-overlap-ag', action='store_false',
help=('Disables the All-Gather overlap with GEMM by '
'pipelining the GEMM and All-Gather.'),
dest='tp_comm_overlap_ag')
@@ -1070,9 +1070,7 @@ def _add_training_args(parser):
help='Single pass vs multiple pass data loader')
group.add_argument('--no-async-tensor-model-parallel-allreduce',
action='store_false',
help='Disable asynchronous execution of '
'tensor-model-parallel all-reduce with weight '
'gradient compuation of a column-linear layer.',
help='DEPRECATED. This flag is ignored.',
dest='async_tensor_model_parallel_allreduce')
group.add_argument('--no-persist-layer-norm', action='store_true',
help='Disable using persistent fused layer norm kernel. '
