loss divergence when CP>1 and MBS>1 #12210

Open
hawkoli1987 opened this issue Feb 17, 2025 · 0 comments
Labels
bug Something isn't working


Describe the bug

When we use context parallelism (CP) > 1 together with a micro batch size (MBS) > 1 in LLM pretraining, the loss curve starts to diverge.

Image

Steps/Code to reproduce bug


  1. On the H100 cluster login node, run "sweep.sh". It repeatedly calls sbatch on "nemo-launch.slurm" to submit the training jobs to Slurm. Each job uses 2 nodes with 1 GPU per node and the same TP=1, PP=1, CP=2 configuration. Only the MBS differs between jobs (1, 2 and 4).
#!/bin/bash

TP="1"
PP="1" # 2
CP="2" # 2
MBS="1 2 4"
NODES="2"

node_id="[0-4]"

for n_nodes in $NODES; do
    for tp in $TP; do
        for pp in $PP ; do
            for cp in $CP; do
                for mbs in $MBS; do
                    
                    export tp
                    export pp
                    export cp
                    export mbs
                    export n_nodes
                    export n_gpus=1

                    echo "Current NODES: $n_nodes"
                    echo "Current TP: $tp"
                    echo "Current PP: $pp"
                    echo "Current CP: $cp"
                    echo "Current MBS: $mbs"
                    
                    for i in {1}; do
                        sbatch -J N${n_nodes}-G${n_gpus}-mbs${mbs}-gbs512-tp${tp}-pp${pp}-cp${cp}-llama3.2-1B-resume \
                            --nodes=${n_nodes} \
                            --gres=gpu:${n_gpus} \
                            --nodelist=a3mega-a3meganodeset-${node_id} \
                            nemo-launch.slurm
                    done
                done
            done 
        done
    done
done
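For reference, the job matrix that sweep.sh submits can be enumerated with the minimal Python sketch below. It only mirrors the nested loops and the `-J` job-name template above and does not call sbatch. The helper name `sweep_jobs` is hypothetical. Note that the inner `for i in {1}` loop in sweep.sh runs exactly once, because `{1}` is not a valid bash brace expansion and is kept as a literal word.

```python
from itertools import product

# Values copied from sweep.sh (space-separated lists in the shell script).
NODES = [2]
TP = [1]
PP = [1]
CP = [2]
MBS = [1, 2, 4]
N_GPUS = 1  # sweep.sh exports n_gpus=1 for every job


def sweep_jobs():
    """Return one dict per sbatch submission, in the same loop order as sweep.sh."""
    jobs = []
    # product() iterates the last list fastest, like the innermost shell loop.
    for n_nodes, tp, pp, cp, mbs in product(NODES, TP, PP, CP, MBS):
        name = (
            f"N{n_nodes}-G{N_GPUS}-mbs{mbs}-gbs512"
            f"-tp{tp}-pp{pp}-cp{cp}-llama3.2-1B-resume"
        )
        jobs.append({"name": name, "nodes": n_nodes, "gpus_per_node": N_GPUS,
                     "tp": tp, "pp": pp, "cp": cp, "mbs": mbs})
    return jobs


if __name__ == "__main__":
    for job in sweep_jobs():
        print(job["name"])
```

The three submitted jobs share every setting except `mbs`, which is what makes them directly comparable.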
  2. As defined in nemo-launch.slurm, each Slurm job runs inside an Enroot container, with the shared drive directory mounted into the container.
#!/bin/bash

#SBATCH --ntasks-per-node=1
#SBATCH --time=1440:00:00
#SBATCH --output=log/%j-%x/output.log
#SBATCH --error=log/%j-%x/error.log

# Usage: sbatch nemo-launch.slurm

set -eoux pipefail

# AWS config
export TZ='Asia/Singapore'
export SHARED_FS_DIR='/shared/all/'
export log_dir="$(pwd)/log/${SLURM_JOB_ID}-${SLURM_JOB_NAME}"
export NEMO_HOME=${SHARED_FS_DIR}/.cache/nemo


[ -d $log_dir ] || mkdir -p $log_dir

cd ${SLURM_SUBMIT_DIR}
# NCCL Variables
# export NCCL_IB_HCA="=mlx5_0,mlx5_3,mlx5_4,mlx5_5,mlx5_6,mlx5_7,mlx5_8,mlx5_9"

# nemo
export CURRENT_TIME=$(date '+%Y.%m.%d_%H.%M.%S')
export container_name="nvidia_${SLURM_JOBID}"
export sqsh_dir="${SHARED_FS_DIR}/sqsh/megagpu-nemo-24.12.sqsh"
export MAX_JOBS=$(nproc)
export nemo_dir="${SHARED_FS_DIR}/source_files/NeMo/"
export TOKENIZERS_PARALLELISM=false
export NCCL_DEBUG=INFO
export NCCL_DEBUG_FILE="${log_dir}/nccl/$(hostname).log"
# export GLOO_SOCKET_IFNAME="bond0.400,bond0.502"

WANBD_TOKEN_PATH=$HOME/.wandb

# check if token present
if [ -f $WANBD_TOKEN_PATH ]; then
    export WANDB_API_KEY=$(cat $WANBD_TOKEN_PATH)
else
    echo "Wandb token not found"
    exit 1
fi
echo $PWD
bash_script="$(pwd)/torchrun_yuli.sh"
python_script="$(pwd)/launch_yuli3.py"

# archive scripts
cp $python_script $log_dir/launch_yuli3.py
cp $bash_script $log_dir/torchrun_yuli.sh
cp $0 $log_dir
cp sweep.sh $log_dir/sweep.sh
chmod +rx $log_dir/*

# use the copied script to prevent issues when accidentally modifying the original script
bash_script="${log_dir}/torchrun_yuli.sh"
export python_script="${log_dir}/launch_yuli3.py"

num_gpus_pernode=${SLURM_GPUS_ON_NODE}
num_node=${SLURM_JOB_NUM_NODES}
num_gpus=$(( $num_gpus_pernode * $num_node ))
master_addr=$(scontrol show hostnames ${SLURM_JOB_NODELIST} | head -n 1)
master_port=$((10000 + $RANDOM % 9000))

export PYTHONPATH="${nemo_dir}"

container_mounts=(
    "${SHARED_FS_DIR}"
)

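# Comma-terminated list of NCCL*/FI* variable names, so appending PYTHONPATH in --container-env below always yields a valid list
HOST_VARS=$(printf '%s,' ${!NCCL*} ${!FI*})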
container_mounts_str=$(IFS=,; echo "${container_mounts[*]}")

srun_args=" \
    --export=ALL \
    --container-mount-home
    --container-image=${sqsh_dir} \
    --container-mounts=${container_mounts_str} \
    --container-env=${HOST_VARS}PYTHONPATH \
    --container-writable \
    --container-workdir=$(pwd) \
    --wait=60 \
    --kill-on-bad-exit=1
    "

launch_cmd="\
    bash ${bash_script} \
        ${num_gpus_pernode} \
        ${num_gpus} \
        ${num_node} \
        ${master_addr} \
        ${master_port} \
        \$SLURM_PROCID \
        "
echo ${launch_cmd[@]}
# for debugging
> "${log_dir}/main.log"

srun ${srun_args} --jobid ${SLURM_JOB_ID} bash -c "${launch_cmd}" \
    2>&1 | tee -a "${log_dir}/main.log"
  3. The following shell script, torchrun_yuli.sh, starts the multi-node distributed training:
#!/bin/bash
set -euox pipefail

pip install wandb==0.19.1 > /dev/null 2>&1

export PYTORCH_CUDA_ALLOC_CONF="max_split_size_mb:128"
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NVTE_FLASH_ATTN=0
export NVTE_FUSED_ATTN=1
export NVTE_UNFUSED_ATTN=0
# export CUDNN_LOGERR_DBG=1
# export CUDNN_LOGDEST_DBG=stderr
# debug transformer engine, also to check if using flash attention for gemma
export NVTE_DEBUG_LEVEL=0
export NVTE_DEBUG=0
export HF_HOME="${SHARED_FS_DIR}/.cache/huggingface"
export HF_HUB_OFFLINE=0
export HF_TOKEN=$(cat $HOME/.hf_token)
echo $HF_TOKEN
export TORCH_HOME=/shared/all/.cache/torch

mkdir -p "${log_dir}/nccl"

env | grep NEMO

echo "NCCL ENVS:"
env | grep NCCL
env | grep ${log_dir}

num_gpus_pernode=${1:-$(nvidia-smi -L | wc -l)}
num_gpus=${2:-$num_gpus_pernode}
num_node=${3:-1}
master_addr=${4:-'127.0.0.1'}
master_port=${5:-$((10000 + $RANDOM % 9000))}
node_rank=${6:-0}

distributed_args=(
    --nnodes ${num_node}
    --node_rank ${node_rank}
    --nproc_per_node ${num_gpus_pernode}
    --rdzv_endpoint "${master_addr}:${master_port}"
    --rdzv_backend c10d
)

script_args=(
    "${python_script}"
    --num-nodes ${num_node}
    --gpus-per-node ${num_gpus_pernode}
)

# TODO: Move all into sbatch script

log_dir=${log_dir:-"./log/debug-cpt"}
mkdir -p ${log_dir}
echo ${distributed_args[@]}
# The trailing backslash keeps the redirection and tee attached to the torchrun command
torchrun ${distributed_args[@]} ${script_args[@]} \
    2>&1 | tee ${log_dir}/${node_rank}.log
  4. The detailed training configuration is defined in launch_yuli3.py. All runs use the same dataset and start continued pretraining from the same base model, each for 50 steps:
import os
import argparse
import nemo_run as run
from nemo.collections import llm
from nemo import lightning as nl

from nemo.collections.llm.gpt.model.llama import *
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
from nemo.collections.llm.gpt.data import PreTrainingDataModule
from megatron.core.optimizer import OptimizerConfig
from nemo.collections.common.metrics.perf_metrics import FLOPsMeasurementCallback
from nemo.lightning.pytorch.callbacks import ModelCheckpoint
from nemo.utils.callbacks import NeMoModelCheckpoint

from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from lightning.pytorch import seed_everything
import torch

'''
from nemo.collections import llm
llm.import_ckpt(model=llm.LlamaModel(llm.Llama31Config8B()), source='hf://meta-llama/Meta-Llama-3.1-8B')
'''

if __name__ == "__main__":

    seed_everything(1234)

    parser = argparse.ArgumentParser(description="NeMo2.0 Pretraining")
    parser.add_argument(
        "--num-nodes", type=int, help="Number of nodes to use for training", default=1,
    )
    parser.add_argument(
        "--gpus-per-node", type=int, help="Number of GPUs per node to use for training", default=8,
    )
    args = parser.parse_args()

    # 
    SHARED_FS_DIR = os.path.abspath(os.environ["SHARED_FS_DIR"])
    DATA_PATH = "data/mega_l3/en-hq/books/text_document"
    # hf://meta-llama/Llama-3.1-8B meta-llama/Llama-3.3-70B-Instruct
    tokenizer = get_nmt_tokenizer("huggingface", "meta-llama/Llama-3.2-1B")
    global_batch_size = 512
    micro_batch_size = int(os.getenv("mbs", "2"))
    cp_size = int(os.getenv("cp", "1"))
    tp_size = int(os.getenv("tp", "2"))
    pp_size = int(os.getenv("pp", "1"))
    job_id = os.getenv("SLURM_JOB_ID")

    data = PreTrainingDataModule(
        paths=os.path.join(
            SHARED_FS_DIR,
            DATA_PATH
        ),
        global_batch_size=global_batch_size,
        micro_batch_size=micro_batch_size,
        num_workers=8,
        pin_memory=True,
        seq_length=2048,
        tokenizer=tokenizer
    )
    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=tp_size,
        pipeline_model_parallel_size=pp_size,
        virtual_pipeline_model_parallel_size=None,
        pipeline_dtype=torch.bfloat16,
        context_parallel_size=cp_size,
        sequence_parallel=tp_size > 1, # only enable if tp_size > 1
        replace_progress_bar=True,
        progress_interval=1,
        enable_nemo_ckpt_io=True,
    )
    # not too familiar with exp manager and using the new format for NeMo
    callbacks = [
        FLOPsMeasurementCallback({
        }
        ),
        ModelCheckpoint(
            # always_save_nemo = True,
            every_n_train_steps = 100,
            save_last = True,
            save_top_k = -1,
            monitor = None,
        )
    ]
    
    trainer = nl.Trainer(
        num_nodes= args.num_nodes,
        devices=args.gpus_per_node,
        accelerator="gpu",
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
        max_epochs=None,
        max_steps=50,
        max_time="05:00:00:00",
        log_every_n_steps=1,
        # accumulate_grad_batches=1, # do not modify, grad acc is automatic for training megatron models
        # gradient_clip_val=1.0,
        limit_val_batches = 2,
        limit_test_batches = 2,
        # callbacks=callbacks,
        )

    opt = nl.MegatronOptimizerModule(
        config=OptimizerConfig(
            optimizer="adam",
            lr=1e-5,
            use_distributed_optimizer=True,
            bf16 = True,
        ),
    )

    model = llm.GPTModel(
        config=Llama32Config1B(),
        tokenizer=tokenizer,
    )

    resume = nl.AutoResume(
        restore_config=nl.RestoreConfig(path="nemo://meta-llama/Llama-3.2-1B"),
        # restore_config=nl.RestoreConfig(path="/shared/aisingapore/source_files/NeMo/model/hf_format"),
        resume_if_exists=False,
    )

    nemo_logger = nl.NeMoLogger(
        name="llama3_cpt",
        wandb=WandbLogger(
            project="finding-nemo",
            name=os.getenv("SLURM_JOB_NAME"),
            entity="aisg-arf",
        ),
        tensorboard=TensorBoardLogger(
            save_dir="tensorboard_logs",
            name="llama3_baseline",
            version=os.getenv("SLURM_JOB_ID", "0"),
        ),
    )

    llm.train(
        model=model,
        data=data,
        trainer=trainer,
        log=nemo_logger,
        tokenizer=tokenizer,
        optim=opt,
        resume=resume,
        )
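The sketch below makes the per-rank data layout concrete. As we understand it, Megatron-Core applies a load-balanced sequence split under context parallelism: the sequence dimension is cut into `2 * cp_size` chunks, and CP rank `r` keeps chunks `r` and `2 * cp_size - 1 - r`. This is a plain-Python mimic on token-index lists, not the Megatron-Core or NeMo implementation, and the function name `split_for_cp_rank` is hypothetical. With `seq_length=2048` and CP=2 as configured above, each CP rank should hold a `[MBS, 1024]` slice, so MBS changes only the leading batch dimension.

```python
def split_for_cp_rank(batch, cp_size, cp_rank):
    """Return the part of each sequence in `batch` that CP rank `cp_rank` keeps.

    Assumption: mirrors the load-balanced CP split described for Megatron-Core,
    where each sequence is cut into 2 * cp_size equal chunks and rank r keeps
    chunks r and (2 * cp_size - 1 - r) to balance causal-attention work.
    """
    n_chunks = 2 * cp_size
    out = []
    for seq in batch:  # the same split is applied to every sample in the micro-batch
        if len(seq) % n_chunks != 0:
            raise ValueError("sequence length must be divisible by 2 * cp_size")
        chunk = len(seq) // n_chunks
        first, second = cp_rank, n_chunks - 1 - cp_rank
        out.append(seq[first * chunk:(first + 1) * chunk]
                   + seq[second * chunk:(second + 1) * chunk])
    return out


if __name__ == "__main__":
    seq_length, cp_size = 2048, 2  # values from launch_yuli3.py and sweep.sh
    for mbs in (1, 2, 4):
        batch = [list(range(seq_length)) for _ in range(mbs)]
        shards = [split_for_cp_rank(batch, cp_size, r) for r in range(cp_size)]
        shapes = [(len(s), len(s[0])) for s in shards]
        print(f"MBS={mbs}: per-rank shard shapes {shapes}")
```

For a toy 8-token sequence with CP=2, rank 0 keeps tokens 0, 1, 6, 7 and rank 1 keeps tokens 2 to 5. Together the two ranks cover the sequence exactly once.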

Expected behavior

All three training configurations differ only in MBS. Because the global batch size is fixed at 512, they are expected to produce the same loss curve.
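The expectation follows from the batch arithmetic: with a fixed global batch size, MBS only changes how many micro-batches are accumulated per optimizer step. The minimal sketch below works this out for the runs above. It assumes the usual Megatron relations `DP = world_size / (TP * PP * CP)` and `num_microbatches = GBS / (MBS * DP)`. The helper name `accumulation_plan` is hypothetical.

```python
def accumulation_plan(world_size, tp, pp, cp, gbs, mbs):
    """Return (data_parallel_size, micro-batches per optimizer step).

    Assumes DP = world_size / (TP * PP * CP) and
    num_microbatches = GBS / (MBS * DP), as in Megatron-style training.
    """
    model_parallel = tp * pp * cp
    if world_size % model_parallel:
        raise ValueError("world size must be divisible by TP * PP * CP")
    dp = world_size // model_parallel
    if gbs % (mbs * dp):
        raise ValueError("GBS must be divisible by MBS * DP")
    return dp, gbs // (mbs * dp)


if __name__ == "__main__":
    # 2 nodes x 1 GPU, TP=1, PP=1, CP=2, GBS=512 (from sweep.sh and launch_yuli3.py)
    for mbs in (1, 2, 4):
        dp, n_micro = accumulation_plan(2, 1, 1, 2, 512, mbs)
        print(f"MBS={mbs}: DP={dp}, micro-batches/step={n_micro}, "
              f"samples/step={mbs * dp * n_micro}")
```

In all three runs each optimizer step still consumes 512 sequences of 2048 tokens, using 512, 256 or 128 accumulated micro-batches respectively. Apart from small bf16 numerical differences, one would therefore expect the curves to overlap.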

Environment overview (please complete the following information)

  • Environment location: GCP, multi-node H100 cluster with the Slurm scheduler, using Enroot containers for the training jobs.
  • Method of NeMo install: using the sqsh image "megagpu-nemo-24.12.sqsh"
  • If method of install is [Docker], provide docker pull & docker run commands used
    <pending, will confirm this by 20 Feb>

Environment details

If NVIDIA docker image is used you don't need to specify these.

Additional context

  • The parallelism configuration above (TP=1, PP=1, CP=2) was run on 2 nodes with 1 GPU per node. We also tested it on 1 node with 2 GPUs and observed similar divergence.
Image
  • We also tested a range of other configurations with different degrees of parallelism and cluster setups (see the job names in the graph). None of them showed this divergence.
Image