
Commit

Merge branch 'transformer_engine_rebase' into 'main'
Transformer Engine Integration Rebase

See merge request ADLR/megatron-lm!487
John Kamalu committed Dec 22, 2022
2 parents 52e6368 + 3499542 commit e1c334b
Showing 5 changed files with 206 additions and 47 deletions.
40 changes: 40 additions & 0 deletions megatron/arguments.py
@@ -28,6 +28,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
parser = _add_vision_args(parser)
parser = _add_logging_args(parser)
parser = _add_inference_args(parser)
parser = _add_transformer_engine_args(parser)

# Custom arguments.
if extra_args_provider is not None:
@@ -304,6 +305,18 @@ def validate_args(args, defaults={}):
'v1.10 and above (Nvidia Pytorch container >= 21.07). Current ' \
'pytorch version is v%s.%s.' % (TORCH_MAJOR, TORCH_MINOR)

# Transformer-Engine/FP8 related checking
if args.fp8_e4m3 or args.fp8_hybrid:
assert args.transformer_impl == 'transformer_engine', \
'transformer-engine required for fp8 training and inference'

assert not (args.fp8_e4m3 and args.fp8_hybrid), \
'cannot train with both fp8 e4m3 and hybrid formatting'

if args.fp16:
assert args.transformer_impl == 'local', \
'transformer-engine not yet approved for fp16 training and inference'

if args.recompute_granularity == 'selective':
assert args.recompute_method is None, \
'recompute method is not yet supported for ' \
@@ -355,6 +368,33 @@ def _check_arg_is_not_none(args, arg):
assert getattr(args, arg) is not None, '{} argument is None'.format(arg)


def _add_transformer_engine_args(parser):
group = parser.add_argument_group(title='Transformer-Engine')

group.add_argument('--fp8-e4m3', action='store_true',
help='E4M3 TransformerLayer', dest='fp8_e4m3')
group.add_argument('--fp8-hybrid', action='store_true',
help='Hybrid FP8 TransformerLayer', dest='fp8_hybrid')
group.add_argument('--no-fp8-wgrad', action='store_false',
help='Execute wgrad in higher precision even for FP8 runs', dest='fp8_wgrad')
group.add_argument('--fp8-margin', type=int, default=0,
help='Scaling margin for fp8', dest='fp8_margin')
group.add_argument('--fp8-interval', type=int, default=1,
help='Scaling update interval for fp8', dest='fp8_interval')
group.add_argument('--transformer-impl', default='local',
choices=['local', 'transformer_engine'],
help='Which Transformer implementation to use.',
dest='transformer_impl')
group.add_argument('--fp8-amax-history-len', type=int, default=1,
help='Number of steps for which amax history is recorded per tensor',
dest='fp8_amax_history_len')
group.add_argument('--fp8-amax-compute-algo', default='most_recent',
choices=['most_recent', 'max'],
help='Algorithm for computing amax from history',
dest='fp8_amax_compute_algo')

return parser

def _add_inference_args(parser):
group = parser.add_argument_group(title='inference')

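These flags feed the new checks in validate_args above: the two FP8 formats are mutually exclusive, either one requires --transformer-impl transformer_engine, and --no-fp8-wgrad uses action='store_false', so args.fp8_wgrad defaults to True. A minimal standalone argparse sketch of how the flags resolve (editor illustration only, not part of the diff):

import argparse

# Stripped-down mirror of _add_transformer_engine_args (illustration only).
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='Transformer-Engine')
group.add_argument('--fp8-e4m3', action='store_true', dest='fp8_e4m3')
group.add_argument('--fp8-hybrid', action='store_true', dest='fp8_hybrid')
group.add_argument('--no-fp8-wgrad', action='store_false', dest='fp8_wgrad')
group.add_argument('--transformer-impl', default='local',
                   choices=['local', 'transformer_engine'], dest='transformer_impl')

args = parser.parse_args(['--fp8-hybrid', '--transformer-impl', 'transformer_engine'])
assert args.fp8_hybrid and not args.fp8_e4m3  # formats are mutually exclusive in validate_args
assert args.fp8_wgrad                         # wgrad stays in FP8 unless --no-fp8-wgrad is passed
assert args.transformer_impl == 'transformer_engine'
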
12 changes: 9 additions & 3 deletions megatron/fused_kernels/__init__.py
@@ -18,11 +18,14 @@ def load(args):

# Check if cuda 11 is installed for compute capability 8.0
cc_flag = []
_, bare_metal_major, _ = _get_cuda_bare_metal_version(
_, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
cpp_extension.CUDA_HOME)
if int(bare_metal_major) >= 11:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_80,code=sm_80')
if int(bare_metal_minor) >= 7:
cc_flag.append('-gencode')
cc_flag.append('arch=compute_90,code=sm_90')

# Build path
srcpath = pathlib.Path(__file__).parent.absolute()
@@ -75,11 +78,14 @@ def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
# Mixed precision fused layer norm.
# =================================

extra_hopper_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__']

extra_cuda_flags = ['-maxrregcount=50']
sources=[srcpath / 'layer_norm_cuda.cpp',
srcpath / 'layer_norm_cuda_kernel.cu']
fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
"fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)
"fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags + extra_hopper_flags)

# =================================
# Fused gradient accumulation to weight gradient computation of linear layer
@@ -89,7 +95,7 @@ def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
sources=[srcpath / 'fused_weight_gradient_dense.cpp',
srcpath / 'fused_weight_gradient_dense.cu']
fused_dense_cuda = _cpp_extention_load_helper(
"fused_dense_cuda", sources, [])
"fused_dense_cuda", sources, extra_hopper_flags)


def _get_cuda_bare_metal_version(cuda_dir):
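The new branch above generates sm_80 code for any CUDA 11+ toolkit and additionally sm_90 (Hopper) once the bare-metal version reaches 11.7. A small standalone sketch of the gencode selection (build_cc_flags is a hypothetical helper, not part of the module):

def build_cc_flags(bare_metal_major: int, bare_metal_minor: int) -> list:
    # Mirror of the gencode logic in megatron/fused_kernels/__init__.py (illustration only).
    cc_flag = []
    if bare_metal_major >= 11:
        cc_flag += ['-gencode', 'arch=compute_80,code=sm_80']      # Ampere
        if bare_metal_minor >= 7:
            cc_flag += ['-gencode', 'arch=compute_90,code=sm_90']  # Hopper, CUDA 11.7+
    return cc_flag

assert build_cc_flags(11, 3) == ['-gencode', 'arch=compute_80,code=sm_80']
assert build_cc_flags(11, 8)[-1] == 'arch=compute_90,code=sm_90'
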
190 changes: 147 additions & 43 deletions megatron/model/transformer.py
@@ -6,7 +6,7 @@
import torch
import torch.nn.functional as F

from megatron import get_timers, get_args, core
from megatron import get_timers, get_args, core, get_num_microbatches
from .module import MegatronModule
from megatron.core import mpu, tensor_parallel
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
@@ -15,7 +15,6 @@
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu


""" We use the following notation throughout this file:
h: hidden size
n: number of attention heads
@@ -810,6 +809,7 @@ def __init__(self, init_method, output_layer_init_method,
self.post_process = post_process
self.input_tensor = None
self.drop_path_rate = drop_path_rate
self.transformer_impl = args.transformer_impl

# Store activation checkpointing flag.
self.recompute_granularity = args.recompute_granularity
@@ -820,6 +820,31 @@

self.sequence_parallel = args.sequence_parallel

# Transformer Engine Init.
if self.transformer_impl == 'transformer_engine':
global transformer_engine
import transformer_engine
self.use_fp8 = args.fp8_e4m3 or args.fp8_hybrid
self.fp8_recipe = None
self.fp8_group = mpu.get_data_parallel_group()
if self.use_fp8:
if args.fp8_e4m3:
fp8_format = transformer_engine.common.recipe.Format.E4M3
elif args.fp8_hybrid:
fp8_format = transformer_engine.common.recipe.Format.HYBRID
self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
margin=args.fp8_margin,
interval=args.fp8_interval,
fp8_format=fp8_format,
amax_history_len=args.fp8_amax_history_len,
amax_compute_algo=args.fp8_amax_compute_algo,
override_linear_precision=(False, False, not args.fp8_wgrad),
)

self.num_microbatches_in_previous_step = -1
self.microbatch_count = 0
self.checkpoint_core_attention = args.recompute_granularity == 'selective'

# Number of layers.
self.num_layers = _get_num_layers(
args,
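
For reference, the recipe setup above in isolation: Format.E4M3 keeps a single FP8 format for forward and backward, Format.HYBRID uses E4M3 in the forward pass and E5M2 in the backward pass, and the third entry of override_linear_precision forces wgrad back to higher precision when --no-fp8-wgrad is given. A hedged standalone sketch (assumes transformer_engine is installed; make_fp8_recipe is a hypothetical helper, not the module's own code):

import transformer_engine

def make_fp8_recipe(args):
    # Illustrative stand-in for the recipe construction in ParallelTransformer.__init__.
    if not (args.fp8_e4m3 or args.fp8_hybrid):
        return None                                   # FP8 disabled
    fp8_format = (transformer_engine.common.recipe.Format.E4M3 if args.fp8_e4m3
                  else transformer_engine.common.recipe.Format.HYBRID)
    return transformer_engine.common.recipe.DelayedScaling(
        margin=args.fp8_margin,                       # scaling-factor margin
        interval=args.fp8_interval,                   # steps between scale updates
        fp8_format=fp8_format,                        # E4M3 everywhere, or E4M3 fwd / E5M2 bwd
        amax_history_len=args.fp8_amax_history_len,   # amax history window per tensor
        amax_compute_algo=args.fp8_amax_compute_algo, # 'most_recent' or 'max'
        override_linear_precision=(False, False, not args.fp8_wgrad),  # (fprop, dgrad, wgrad)
    )
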
@@ -830,13 +855,43 @@

# Transformer layers.
def build_layer(layer_number):
return ParallelTransformerLayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1])
if args.transformer_impl == 'local':
return ParallelTransformerLayer(
init_method,
output_layer_init_method,
layer_number,
layer_type=layer_type,
self_attn_mask_type=self_attn_mask_type,
drop_path_rate=self.drop_path_rates[layer_number - 1])
else:
return transformer_engine.pytorch.TransformerLayer(
args.hidden_size,
args.ffn_hidden_size,
args.num_attention_heads,
layernorm_epsilon=args.layernorm_epsilon,
hidden_dropout=args.hidden_dropout,
attention_dropout=args.attention_dropout,
init_method=init_method,
output_layer_init_method=output_layer_init_method,
layer_number=layer_number,
kv_channels=args.kv_channels,
self_attn_mask_type=self_attn_mask_type.name,
tp_group=mpu.get_tensor_model_parallel_group(),
get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker,
fuse_wgrad_accumulation=args.gradient_accumulation_fusion,
apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
attention_softmax_in_fp32=args.attention_softmax_in_fp32,
seq_length=args.seq_length,
micro_batch_size=args.micro_batch_size,
sequence_parallel=args.sequence_parallel,
params_dtype=args.params_dtype,
apply_residual_connection_post_layernorm=args.apply_residual_connection_post_layernorm,
output_layernorm=False,
layer_type="encoder",
drop_path_rate=self.drop_path_rates[layer_number - 1],
set_parallel_mode=True,
fuse_qkv_params=True)

if args.virtual_pipeline_model_parallel_size is not None:
assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
'num_layers_per_stage must be divisible by ' \
@@ -896,30 +951,40 @@ def _get_layer(self, layer_number):
return self.layers[layer_number]

def _checkpointed_forward(self, hidden_states, attention_mask,
encoder_output, enc_dec_attn_mask):
encoder_output, enc_dec_attn_mask, is_first_microbatch):
"""Forward method with activation checkpointing."""
def custom(start, end):
def custom_forward(*inputs):
x_ = inputs[0]
attention_mask = inputs[1]
encoder_output = inputs[2]
enc_dec_attn_mask = inputs[3]
def custom(start, end, is_transformer_engine=False):
def custom_forward(*args, **kwargs):
for index in range(start, end):
layer = self._get_layer(index)
x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
x_ = layer(*args, **kwargs)
return x_
return custom_forward
def custom_forward_transformer_engine(*args, **kwargs):
return custom_forward(*args, is_first_microbatch=is_first_microbatch, **kwargs)
if not is_transformer_engine:
return custom_forward
else:
return custom_forward_transformer_engine

if self.recompute_method == 'uniform':
# Uniformly divide the total number of Transformer layers and checkpoint
# the input activation of each divided chunk.
# A method to further reduce memory usage by storing fewer checkpoints.
l = 0
while l < self.num_layers:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + self.recompute_num_layers),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
if self.transformer_impl == 'transformer_engine':
hidden_states = transformer_engine.pytorch.distributed.checkpoint(
custom(l, l + self.recompute_num_layers, is_transformer_engine=True),
self.distribute_saved_activations,
tensor_parallel.get_cuda_rng_tracker,
mpu.get_tensor_model_parallel_group(),
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + self.recompute_num_layers),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)

l += self.recompute_num_layers

elif self.recompute_method == 'block':
@@ -928,13 +993,25 @@ def custom_forward(*inputs):
# A method that makes full use of device memory while removing redundant re-computation.
for l in range(self.num_layers):
if l < self.recompute_num_layers:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + 1),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
if self.transformer_impl == 'transformer_engine':
hidden_states = transformer_engine.pytorch.distributed.checkpoint(
custom(l, l + 1, is_transformer_engine=True),
self.distribute_saved_activations,
tensor_parallel.get_cuda_rng_tracker,
mpu.get_tensor_model_parallel_group(),
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
hidden_states = tensor_parallel.checkpoint(
custom(l, l + 1),
self.distribute_saved_activations,
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
hidden_states = custom(l, l + 1)(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
if self.transformer_impl == 'transformer_engine':
hidden_states = custom(l, l + 1, is_transformer_engine=True)(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
hidden_states = custom(l, l + 1)(
hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
else:
raise ValueError("Invalid activation recompute method.")
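
The closure rework above lets both checkpoint paths call the returned function with the same positional tensors: the Transformer-Engine variant simply curries is_first_microbatch in as a keyword argument. A toy illustration of the pattern (not Megatron's code; the hidden states are threaded through the chunk explicitly here for clarity):

def make_custom(layers, start, end, is_first_microbatch, is_transformer_engine=False):
    def custom_forward(*args, **kwargs):
        x = args[0]
        for layer in layers[start:end]:
            x = layer(x, *args[1:], **kwargs)   # run the chunk of layers
        return x

    def custom_forward_te(*args, **kwargs):
        # Inject the keyword without changing the positional signature
        # that the checkpoint function forwards.
        return custom_forward(*args, is_first_microbatch=is_first_microbatch, **kwargs)

    return custom_forward_te if is_transformer_engine else custom_forward

# Stand-in "layers" that simply record whether they saw the flag.
layers = [lambda x, mask, is_first_microbatch=False: x + int(is_first_microbatch)
          for _ in range(4)]
fn = make_custom(layers, 0, 2, is_first_microbatch=True, is_transformer_engine=True)
print(fn(0, None))   # 2 -> both layers in the chunk received is_first_microbatch=True
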

@@ -991,21 +1068,48 @@ def forward(self, hidden_states, attention_mask,
rng_context = nullcontext()

with rng_context:
# Forward pass.
if self.recompute_granularity == 'full':
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask)
else:
for index in range(self.num_layers):
layer = self._get_layer(index)
hidden_states = layer(
hidden_states,
attention_mask,
encoder_output=encoder_output,
enc_dec_attn_mask=enc_dec_attn_mask,
inference_params=inference_params)
# The fp8_autocast context manager is a no-op when enabled=False
# The if...else serves to short-circuit name resolution for fp8_autocast
with transformer_engine.pytorch.fp8_autocast(
enabled=self.use_fp8,
fp8_recipe=self.fp8_recipe,
fp8_group=self.fp8_group
) if self.use_fp8 else nullcontext():
# Determine if the current iteration is first microbatch
if self.num_microbatches_in_previous_step != get_num_microbatches():
self.microbatch_count = 0 # Reset count on new batch size rampup interval
self.num_microbatches_in_previous_step = get_num_microbatches()
is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0

# Forward pass.
if self.recompute_granularity == 'full':
hidden_states = self._checkpointed_forward(hidden_states,
attention_mask,
encoder_output,
enc_dec_attn_mask,
is_first_microbatch)
else:
forward_kwargs = {
'encoder_output': encoder_output,
'enc_dec_attn_mask': enc_dec_attn_mask,
'inference_params': inference_params,
}

if self.transformer_impl == 'transformer_engine':
forward_kwargs['is_first_microbatch'] = is_first_microbatch
forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention

for index in range(self.num_layers):
layer = self._get_layer(index)

hidden_states = layer(
hidden_states,
attention_mask,
**forward_kwargs)

# Skip counter update for eval and activation checkpointing
if torch.is_grad_enabled() and self.training:
self.microbatch_count += 1

# Final layer norm.
if self.post_process and self.post_layer_norm:
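The forward-pass changes above do two things: the layer loop runs inside fp8_autocast only when FP8 is active, and a per-instance counter tracks which microbatch of the global batch is being processed so the is_first_microbatch hint can be passed to Transformer Engine (which can then reuse FP8-cast weights for the remaining microbatches of a gradient-accumulation step). A minimal sketch of just the counting logic (standalone toy, not the ParallelTransformer class):

class MicrobatchTracker:
    # Toy stand-in for the counters kept on ParallelTransformer.
    def __init__(self):
        self.num_microbatches_in_previous_step = -1
        self.microbatch_count = 0

    def is_first_microbatch(self, num_microbatches):
        if self.num_microbatches_in_previous_step != num_microbatches:
            self.microbatch_count = 0   # reset on a batch-size rampup boundary
            self.num_microbatches_in_previous_step = num_microbatches
        return self.microbatch_count % num_microbatches == 0

tracker = MicrobatchTracker()
flags = []
for _ in range(4):                      # four microbatches per global batch
    flags.append(tracker.is_first_microbatch(4))
    tracker.microbatch_count += 1       # incremented only when grads are enabled
print(flags)                            # [True, False, False, False]
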
7 changes: 7 additions & 0 deletions megatron/training.py
@@ -26,6 +26,7 @@
from megatron.checkpointing import save_checkpoint
from megatron.model import Float16Module
from megatron.model import ModelType
from megatron.model import GPTModel
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
@@ -251,6 +252,12 @@ def get_model(model_provider_func, model_type=ModelType.encoder_or_decoder, wrap
if not isinstance(model, list):
model = [model]

# Disallow training and inference with Transformer Engine
# for non-GPT models
args.allow_transformer_engine = all([type(m) == GPTModel for m in model])
assert args.allow_transformer_engine or args.transformer_impl == 'local', \
'Transformer Engine is only approved for GPT models'

# Set tensor model parallel attributes if not set.
# Only parameters that are already tensor model parallel have these
# attributes set for them. We should make sure the default attributes
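A quick illustration of the new guard: every model chunk must be a GPTModel before the transformer_engine implementation is accepted; otherwise only the local implementation is allowed (toy stand-in classes below, not megatron.model):

class GPTModel: ...
class BertModel: ...

def check_transformer_engine_allowed(model, transformer_impl):
    # Mirrors the assertion added to get_model().
    allow = all(type(m) == GPTModel for m in model)
    assert allow or transformer_impl == 'local', \
        'Transformer Engine is only approved for GPT models'
    return allow

assert check_transformer_engine_allowed([GPTModel()], 'transformer_engine')
assert not check_transformer_engine_allowed([BertModel()], 'local')   # falls back to local impl
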
4 changes: 3 additions & 1 deletion pretrain_gpt.py
@@ -113,4 +113,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):

pretrain(train_valid_test_datasets_provider, model_provider,
ModelType.encoder_or_decoder,
forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
forward_step,
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}
)
