
Commit

Merge branch 'jbaczek/extend_transformer_block_spec' into 'core_r0.7.0.beta'

Add layer norm to TransformerBlockSubmodules

See merge request ADLR/megatron-lm!1350

(cherry picked from commit 4326832)

8fad4687 Add layer norm to TransformerBlockSubmodules
0c042672 Update formatting
60dde170 fix formatting issue
ccb145a1 Define whether to use final layer norm in TransformerBlock from the spec...
4d41aa6c Restore arguments needed for toggling ln off in intermediate layers of PP
8e15168e Remove incorrect warnings
ShriyaPalsamudram authored and jbaczek committed Jul 2, 2024
1 parent 0d7bdd8 commit 561f250
Showing 2 changed files with 14 additions and 7 deletions.
18 changes: 12 additions & 6 deletions megatron/core/transformer/transformer_block.py
@@ -1,9 +1,10 @@
 # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

 import re
 import warnings
 from contextlib import nullcontext
 from dataclasses import dataclass
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import torch
 from torch import Tensor
@@ -65,6 +66,7 @@ def get_num_layers_to_build(config: TransformerConfig) -> int:
 @dataclass
 class TransformerBlockSubmodules:
     layer_specs: List[ModuleSpec] = None
+    layer_norm: Optional[Union[ModuleSpec, torch.nn.Module]] = None


 def _get_block_submodules(
@@ -83,7 +85,7 @@ def _get_block_submodules(
             return spec.submodules
         elif issubclass(spec.module, BaseTransformerLayer):
             num_layers = get_num_layers_to_build(config)
-            return TransformerBlockSubmodules(layer_specs=[spec] * num_layers)
+            return TransformerBlockSubmodules(layer_specs=[spec] * num_layers, layer_norm=TENorm,)
         else:
             raise Exception(f"specialize for {spec.module.__name__}.")
     else:
@@ -176,13 +178,17 @@ def build_layer(layer_spec, layer_number):
         # else:
         #     self.layers = torch.nn.ModuleList([build_layer(i + 1 + offset) for i in range(self.num_layers)])

-        if self.post_process and self.post_layer_norm:
-            # Final layer norm before output.
-            self.final_layernorm = TENorm(
+        # In pipeline parallelism, we want to add this LN only to the last stage of the pipeline
+        # self.post_process and self.post_layer_norm guide this behavior
+        if self.submodules.layer_norm and self.post_process and self.post_layer_norm:
+            self.final_layernorm = build_module(
+                self.submodules.layer_norm,
                 config=self.config,
                 hidden_size=self.config.hidden_size,
                 eps=self.config.layernorm_epsilon,
             )
+        else:
+            self.final_layernorm = None  # Either this or nn.Identity

     def _get_layer(self, layer_number: int):
         return self.layers[layer_number]
@@ -415,7 +421,7 @@ def forward(
                         hidden_states = self.group_prefetch_offload_commit_async(hidden_states)

         # Final layer norm.
-        if self.post_process and self.post_layer_norm:
+        if self.final_layernorm is not None:
             hidden_states = self.final_layernorm(hidden_states)

         return hidden_states
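For illustration only (not part of the diff above): a minimal sketch of how a block spec might now carry its own final layer norm. The import paths are assumed from the Megatron-Core layout at the time of this commit; the helper make_block_submodules and its arguments are placeholders, not project code.

# Hypothetical usage sketch, not from this commit. TENorm, ModuleSpec,
# TransformerBlockSubmodules and TransformerLayer are the real classes touched here;
# `make_block_submodules` and its arguments are made up for illustration.
from megatron.core.transformer.custom_layers.transformer_engine import TENorm
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
from megatron.core.transformer.transformer_layer import TransformerLayer


def make_block_submodules(layer_submodules, num_layers):
    """Build a block spec whose final layer norm is taken from the spec itself."""
    layer_spec = ModuleSpec(module=TransformerLayer, submodules=layer_submodules)
    return TransformerBlockSubmodules(
        layer_specs=[layer_spec] * num_layers,
        layer_norm=TENorm,  # pass None to omit the final layer norm
    )
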
3 changes: 2 additions & 1 deletion tests/unit_tests/transformer/test_spec_customization.py
@@ -209,7 +209,8 @@ def test_transformer_block_custom(self):
             layer_specs=[
                 ModuleSpec(module=TransformerLayer, submodules=layer_local_spec.submodules)
             ]
-            * transformer_config.num_layers
+            * transformer_config.num_layers,
+            layer_norm=TENorm,
         )
         # make sure the model init conditions are identical
         model_parallel_cuda_manual_seed(123)
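A rough sketch of the resulting behaviour, under stated assumptions (a transformer_config and block_submodules like those in the test above, with model-parallel state and RNG seeding already set up): the final layer norm is only materialized when the spec provides one and the block runs the last pipeline stage; otherwise final_layernorm stays None and forward() skips it.

# Behaviour sketch; `transformer_config` and `block_submodules` are assumed to exist
# as in the test above, and parallel state has already been initialized.
from megatron.core.transformer.transformer_block import TransformerBlock

block = TransformerBlock(transformer_config, block_submodules)

# With layer_norm=TENorm in the spec (and post_process=True), a final layer norm is
# built via build_module; with layer_norm=None it stays None and forward() skips it.
built = bool(block.submodules.layer_norm) and block.post_process and block.post_layer_norm
assert (block.final_layernorm is not None) == built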
