Fix for te v2.0 #12273

Open · wants to merge 13 commits into base: main
Use context instead of directly setting FP8GlobalStateManager
Signed-off-by: Guyue Huang <[email protected]>
guyueh1 committed Feb 20, 2025
commit aa01aca65ecde97916e82269300d5e92a4097500
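For reference, the previous code changed the FP8 recipe by mutating Transformer Engine's global state (FP8GlobalStateManager.FP8_RECIPE); this commit instead wraps model construction in the fp8_model_init context manager and, on TE >= 2.0, passes an explicit DelayedScaling recipe. A minimal standalone sketch of that pattern, assuming Transformer Engine is installed and that fp8_model_init accepts a recipe argument in TE 2.0 (as the diff below relies on); the helper name make_build_model_context and the layer sizes are only illustrative:

# Minimal sketch of the version-gated context pattern used by this commit.
# Assumes transformer_engine is installed; the layer and sizes are illustrative only.
from functools import partial

import transformer_engine.pytorch as te_pytorch
from transformer_engine.common import recipe as te_recipe


def make_build_model_context(te_major_minor):
    """Return a context-manager factory that initializes module weights in FP8."""
    if te_major_minor >= (2, 0):
        # TE 2.0 defaults to MXFP8BlockScaling, so request DelayedScaling explicitly.
        return partial(te_pytorch.fp8_model_init, recipe=te_recipe.DelayedScaling())
    # Older TE versions do not take a recipe argument here.
    return te_pytorch.fp8_model_init


with make_build_model_context((2, 0))():
    layer = te_pytorch.Linear(1024, 1024)  # weights are created in FP8

Scoping the recipe to a context manager keeps the change local to model construction, whereas assigning FP8GlobalStateManager.FP8_RECIPE directly leaks the recipe into any later TE code in the same process.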
51 changes: 29 additions & 22 deletions nemo/collections/llm/gpt/model/base.py
@@ -14,6 +14,7 @@

from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, Literal, Optional, Union
from functools import partial

import lightning.pytorch as L
import torch
@@ -32,6 +33,7 @@
from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule
from nemo.utils import logging
from nemo.utils.import_utils import safe_import
from nemo.utils.te_utils import te_version

_, HAVE_TE = safe_import("transformer_engine")

@@ -219,30 +221,35 @@ def configure_model(self, tokenizer, pre_process=None, post_process=None) -> "MC
vocab_size = get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by)

# Set FP8 recipe to DelayedScaling to initialize model with float8 precision.
# If not set, the default recipe is MXFP8BlockScaling which will initialize
# model with mxfp8 precision.
if self.fp8 is not None:
assert HAVE_TE, "Transformer Engine is required for FP8 training."
te_fp8, _ = safe_import("transformer_engine.pytorch.fp8")
te_recipe, _ = safe_import("transformer_engine.common.recipe")
te_fp8.FP8GlobalStateManager.FP8_RECIPE = te_recipe.DelayedScaling()

model = MCoreGPTModel(
self,
transformer_layer_spec=transformer_layer_spec,
vocab_size=vocab_size,
max_sequence_length=self.seq_length,
fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
parallel_output=self.parallel_output,
share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
position_embedding_type=self.position_embedding_type,
rotary_percent=self.rotary_percent,
rotary_base=self.rotary_base,
seq_len_interpolation_factor=self.seq_len_interpolation_factor,
pre_process=pre_process or parallel_state.is_pipeline_first_stage(),
post_process=post_process or parallel_state.is_pipeline_last_stage(),
scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel,
)
te_pytorch, _ = safe_import("transformer_engine.pytorch")
fp8_model_init = te_pytorch.fp8_model_init
if te_version() >= (2, 0):
# In TE 2.0, the default recipe is MXFP8BlockScaling, need to change it to DelayedScaling
te_recipe, _ = safe_import("transformer_engine.common.recipe")
recipe = te_recipe.DelayedScaling()
build_model_context = partial(fp8_model_init, recipe=recipe)
else:
build_model_context = fp8_model_init

with build_model_context():
model = MCoreGPTModel(
self,
transformer_layer_spec=transformer_layer_spec,
vocab_size=vocab_size,
max_sequence_length=self.seq_length,
fp16_lm_cross_entropy=self.fp16_lm_cross_entropy,
parallel_output=self.parallel_output,
share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
position_embedding_type=self.position_embedding_type,
rotary_percent=self.rotary_percent,
rotary_base=self.rotary_base,
seq_len_interpolation_factor=self.seq_len_interpolation_factor,
pre_process=pre_process or parallel_state.is_pipeline_first_stage(),
post_process=post_process or parallel_state.is_pipeline_last_stage(),
scatter_embedding_sequence_parallel=self.scatter_embedding_sequence_parallel,
)

# If using full TE layer, need to set TP, CP group since the module call
# is not routed through megatron core, which normally handles passing the
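The version gate above depends on a te_version() helper imported from nemo.utils.te_utils, whose implementation is not part of this diff. A minimal sketch of what such a helper might look like, assuming Transformer Engine exposes a PEP 440 style __version__ string (the actual NeMo implementation may differ):

# Hypothetical sketch only; the real nemo.utils.te_utils.te_version is not shown in this diff.
from packaging.version import Version

import transformer_engine as te  # assumed to expose __version__


def te_version():
    """Return the installed Transformer Engine version as a (major, minor) tuple."""
    parsed = Version(te.__version__)
    return (parsed.major, parsed.minor)

A tuple comparison such as te_version() >= (2, 0) then matches the gate used in configure_model above.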