Describe the bug
When we use context parallelism (CP) > 1 together with micro batch size (MBS) > 1 in LLM pretraining, the loss curves of otherwise identical runs start to diverge from each other.
Steps/Code to reproduce bug
Please list minimal steps or code snippet for us to be able to reproduce the bug.
On the H100 compute cluster login node, run "sweep.sh", which calls the 'slurm_script' iteratively to submit training jobs to Slurm (a sketch of the submission commands is included after the script below). Each job uses 2 nodes with 1 GPU per node and the same TP=1, PP=1, CP=2 configuration, but the MBS varies from 1 to 2 to 4 across jobs.
The detailed training configuration is defined in launch_yuli3.py, which uses the same dataset and starts continued pretraining from the base model, running each job for 50 steps:
import os
import argparse

import torch
import nemo_run as run
from nemo.collections import llm
from nemo import lightning as nl
from nemo.collections.llm.gpt.model.llama import *
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from nemo.collections.llm.gpt.data import PreTrainingDataModule
from megatron.core.optimizer import OptimizerConfig
from nemo.collections.common.metrics.perf_metrics import FLOPsMeasurementCallback
from nemo.lightning.pytorch.callbacks import ModelCheckpoint
from nemo.utils.callbacks import NeMoModelCheckpoint
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from lightning.pytorch import seed_everything

'''
from nemo.collections import llm
llm.import_ckpt(model=llm.LlamaModel(llm.Llama31Config8B()), source='hf://meta-llama/Meta-Llama-3.1-8B')
'''

if __name__ == "__main__":
    seed_everything(1234)

    parser = argparse.ArgumentParser(description="NeMo 2.0 Pretraining")
    parser.add_argument(
        "--num-nodes", type=int, help="Number of nodes to use for training", default=1,
    )
    parser.add_argument(
        "--gpus-per-node", type=int, help="Number of GPUs per node to use for training", default=8,
    )
    args = parser.parse_args()

    # Shared filesystem root and dataset path
    SHARED_FS_DIR = os.path.abspath(os.environ["SHARED_FS_DIR"])
    DATA_PATH = "data/mega_l3/en-hq/books/text_document"  # hf://meta-llama/Llama-3.1-8B, meta-llama/Llama-3.3-70B-Instruct
    tokenizer = get_nmt_tokenizer("huggingface", "meta-llama/Llama-3.2-1B")

    # Parallelism and micro batch size are passed in as environment variables by the Slurm job
    global_batch_size = 512
    micro_batch_size = int(os.getenv("mbs", "2"))
    cp_size = int(os.getenv("cp", "1"))
    tp_size = int(os.getenv("tp", "2"))
    pp_size = int(os.getenv("pp", "1"))
    job_id = os.getenv("SLURM_JOB_ID")

    data = PreTrainingDataModule(
        paths=os.path.join(SHARED_FS_DIR, DATA_PATH),
        global_batch_size=global_batch_size,
        micro_batch_size=micro_batch_size,
        num_workers=8,
        pin_memory=True,
        seq_length=2048,
        tokenizer=tokenizer,
    )

    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=tp_size,
        pipeline_model_parallel_size=pp_size,
        virtual_pipeline_model_parallel_size=None,
        pipeline_dtype=torch.bfloat16,
        context_parallel_size=cp_size,
        sequence_parallel=tp_size > 1,  # only enable if tp_size > 1
        replace_progress_bar=True,
        progress_interval=1,
        enable_nemo_ckpt_io=True,
    )

    # not too familiar with exp manager and using the new format for NeMo
    callbacks = [
        FLOPsMeasurementCallback({}),
        ModelCheckpoint(
            # always_save_nemo=True,
            every_n_train_steps=100,
            save_last=True,
            save_top_k=-1,
            monitor=None,
        ),
    ]

    trainer = nl.Trainer(
        num_nodes=args.num_nodes,
        devices=args.gpus_per_node,
        accelerator="gpu",
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
        max_epochs=None,
        max_steps=50,
        max_time="05:00:00:00",
        log_every_n_steps=1,
        # accumulate_grad_batches=1,  # do not modify, grad acc is automatic for training megatron models
        # gradient_clip_val=1.0,
        limit_val_batches=2,
        limit_test_batches=2,
        # callbacks=callbacks,
    )

    opt = nl.MegatronOptimizerModule(
        config=OptimizerConfig(
            optimizer="adam",
            lr=1e-5,
            use_distributed_optimizer=True,
            bf16=True,
        ),
    )

    model = llm.GPTModel(
        config=Llama32Config1B(),
        tokenizer=tokenizer,
    )

    resume = nl.AutoResume(
        restore_config=nl.RestoreConfig(path="nemo://meta-llama/Llama-3.2-1B"),
        # restore_config=nl.RestoreConfig(path="/shared/aisingapore/source_files/NeMo/model/hf_format"),
        resume_if_exists=False,
    )

    nemo_logger = nl.NeMoLogger(
        name="llama3_cpt",
        wandb=WandbLogger(
            project="finding-nemo",
            name=os.getenv("SLURM_JOB_NAME"),
            entity="aisg-arf",
        ),
        tensorboard=TensorBoardLogger(
            save_dir="tensorboard_logs",
            name="llama3_baseline",
            version=os.getenv("SLURM_JOB_ID", "0"),
        ),
    )

    llm.train(
        model=model,
        data=data,
        trainer=trainer,
        log=nemo_logger,
        tokenizer=tokenizer,
        optim=opt,
        resume=resume,
    )
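For reference, the sketch below is a hypothetical Python equivalent of what sweep.sh does, assuming it sets the mbs/tp/pp/cp environment variables and submits 'slurm_script' via sbatch; the job names and exact Slurm flags are placeholders, not necessarily the ones used in our runs.

import subprocess

# Hypothetical stand-in for sweep.sh: one job per MBS value, same TP/PP/CP.
# Flag values mirror the setup described above; the real script may differ.
for mbs in (1, 2, 4):
    subprocess.run(
        [
            "sbatch",
            "--nodes=2",
            "--ntasks-per-node=1",
            "--gpus-per-node=1",
            f"--job-name=tp1_pp1_cp2_mbs{mbs}",
            f"--export=ALL,mbs={mbs},tp=1,pp=1,cp=2",
            "slurm_script",
        ],
        check=True,
    )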
Expected behavior
All three training configurations, which differ only in the MBS, are expected to produce the same loss curve.
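As background for this expectation: the global batch size is fixed at 512, so changing MBS only changes how many micro-batches are accumulated per optimizer step, and in exact arithmetic the averaged per-step gradient is identical. The following toy PyTorch sketch (a hypothetical linear model, not the NeMo training loop) illustrates that gradient accumulation over different micro-batch sizes reproduces the full-batch gradient up to floating-point error.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 4)   # one "global batch" of 8 samples
y = torch.randn(8, 1)

def accumulated_grad(micro_batch_size):
    # Re-create the same model for every MBS so only the accumulation pattern differs.
    torch.manual_seed(1234)
    model = torch.nn.Linear(4, 1)
    n_micro = x.shape[0] // micro_batch_size
    for i in range(n_micro):
        xb = x[i * micro_batch_size:(i + 1) * micro_batch_size]
        yb = y[i * micro_batch_size:(i + 1) * micro_batch_size]
        loss = F.mse_loss(model(xb), yb)
        (loss / n_micro).backward()  # scale so the accumulated sum equals the full-batch mean-loss gradient
    return model.weight.grad.clone()

g1, g2, g4 = (accumulated_grad(m) for m in (1, 2, 4))
print(torch.allclose(g1, g2, atol=1e-6), torch.allclose(g1, g4, atol=1e-6))  # expected: True True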
Environment overview (please complete the following information)
Environment location: GCP cloud with multi-node H100 cluster, Slurm scheduler, Enroot container for the training job.
Method of NeMo install: using the sqsh image "megagpu-nemo-24.12.sqsh"
If method of install is [Docker], provide docker pull & docker run commands used
<pending, will confirm this by 20 Feb>
Environment details
If NVIDIA docker image is used you don't need to specify these.
Additional context
The above parallelism configuration (TP=1, PP=1, CP=2) was applied with 2 nodes and 1 GPU on each node. However, we have also tested a 1-node, 2-GPU configuration, and similar divergence was observed.
We have tested a range of configurations with different degrees of parallelism and cluster setups (as shown in the job names in the graph); none of the others showed the same divergence.