Commit

Merge branch 'lora-grad-output-buffer-bugfix' into 'core_r0.6.0'
Make sure APIs are consistent between linear layer forward impls

See merge request ADLR/megatron-lm!1307
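
For context, the API consistency referred to above means a caller can pick either forward implementation and hand both the same keyword arguments. Below is a minimal sketch of that dispatch pattern; it is not code from this commit, the helper run_linear_forward is hypothetical, and it assumes linear_with_grad_accumulation_and_async_allreduce already exposes the matching signature including grad_output_buffer:

from typing import List, Optional

import torch

from megatron.core.tensor_parallel.layers import (
    linear_with_frozen_weight,
    linear_with_grad_accumulation_and_async_allreduce,
)


def run_linear_forward(
    input_: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    grad_output_buffer: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
    # Hypothetical helper: take the frozen-weight path when the weight is not
    # trainable. Because both impls accept the same kwargs (including the
    # dummy grad_output_buffer), the call site needs no special-casing.
    forward_impl = (
        linear_with_frozen_weight
        if not weight.requires_grad
        else linear_with_grad_accumulation_and_async_allreduce
    )
    return forward_impl(
        input=input_,
        weight=weight,
        bias=bias,
        gradient_accumulation_fusion=False,
        async_grad_allreduce=False,
        sequence_parallel=False,
        grad_output_buffer=grad_output_buffer,
    )

This mirrors how the parallel linear layers in megatron.core select between the two implementations based on weight.requires_grad, which is why the frozen-weight path carries dummy arguments at all.
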
jaredcasper committed Apr 4, 2024
2 parents 3dd97ed + 4e7240e commit d4fa4dc
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions megatron/core/tensor_parallel/layers.py
@@ -248,11 +248,11 @@ def sharded_state_dict(

class LinearWithFrozenWeight(torch.autograd.Function):
"""Linear operator that does not calculate gradient for weight.
This op and LinearWithGradAccumulationAndAsyncCommunication performs
mathematically-identical forward and DGRAD.
This op and LinearWithGradAccumulationAndAsyncCommunication performs
mathematically-identical forward and DGRAD.
Conceptually this op is the same as torch.nn.functional.linear with
weight.requires_grad==False, but in experiments they are not identical
weight.requires_grad==False, but in experiments they are not identical
mathematically. """

    @staticmethod
@@ -281,13 +281,14 @@ def linear_with_frozen_weight(
    gradient_accumulation_fusion: bool,
    async_grad_allreduce: bool,
    sequence_parallel: bool,
    grad_output_buffer: Optional[List[torch.Tensor]] = None,
) -> torch.Tensor:
"""Linear layer execution with weight.requires_grad == False.
This function handles linear layers with weight frozen (untrainable).
This function handles linear layers with weight frozen (untrainable).
In the forward, it only saves weight and does not save input activations.
In the backward, it does not perform weight gradient calculation, or
weight gradient allreduce.
In the backward, it does not perform weight gradient calculation, or
weight gradient allreduce.
Args:
@@ -297,18 +298,27 @@ def linear_with_frozen_weight(
    bias (torch.Tensor optional): bias like torch.nn.functional.linear
    gradient_accumulation_fusion (bool required): dummy argument, used to
        keep the API unified between all forward implementation functions.
    async_grad_allreduce (bool required): dummy argument, used to
        keep the API unified between all forward implementation functions.
    sequence_parallel (bool required): Indicates that sequence
        parallelism is used and thus in the forward pass the input is
        all gathered, and the backward pass the input gradients are
        reduce scattered.
    grad_output_buffer (List[torch.Tensor] optional): dummy argument, used to
        keep the API unified between all forward implementation functions.
    """

    assert grad_output_buffer is None, (
        "grad_output_buffer kwarg is only supported with "
        "linear_with_grad_accumulation_and_async_allreduce"
    )

    if sequence_parallel:
        input = gather_from_sequence_parallel_region(input, tensor_parallel_output_grad=True)
    else:
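
To see what the new guard does in practice, here is a small usage sketch (an assumption-laden illustration, not verified against a particular release: it assumes megatron.core is importable in a single process so the sequence-parallel and async paths stay disabled, and the tensor shapes are arbitrary). Passing grad_output_buffer=None is accepted by the frozen-weight path, while a non-None buffer trips the assert above, since only linear_with_grad_accumulation_and_async_allreduce actually consumes the buffer.

import torch

from megatron.core.tensor_parallel.layers import linear_with_frozen_weight

inp = torch.randn(4, 8)
weight = torch.randn(16, 8, requires_grad=False)

# The kwarg may always be passed; None is the only value this impl accepts.
out = linear_with_frozen_weight(
    input=inp,
    weight=weight,
    bias=None,
    gradient_accumulation_fusion=False,
    async_grad_allreduce=False,
    sequence_parallel=False,
    grad_output_buffer=None,
)
print(out.shape)  # torch.Size([4, 16])

# A non-None buffer is rejected by the assert added in this commit.
try:
    linear_with_frozen_weight(
        input=inp,
        weight=weight,
        bias=None,
        gradient_accumulation_fusion=False,
        async_grad_allreduce=False,
        sequence_parallel=False,
        grad_output_buffer=[],
    )
except AssertionError as err:
    print(err)
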
