Experimental Yaml configs
wdykas authored and jaredcasper committed Mar 5, 2024
1 parent 528d7cf commit 47cb630
Showing 7 changed files with 803 additions and 6 deletions.
303 changes: 303 additions & 0 deletions examples/gpt3/gpt_config.yaml
@@ -0,0 +1,303 @@
# WARNING: YAML configs are currently an experimental feature
language_model:
  # model architecture
  num_layers: 24
  hidden_size: 1024
  num_attention_heads: 16
  num_query_groups: null

  ffn_hidden_size: null
  kv_channels: null
  hidden_dropout: 0.0
  attention_dropout: 0.0
  fp32_residual_connection: False

  apply_residual_connection_post_layernorm: False
  layernorm_epsilon: 1.e-5
  layernorm_zero_centered_gamma: True
  add_bias_linear: False
  bias_activation_fusion: False
  add_qkv_bias: False
  gated_linear_unit: False
  activation_func: swiglu
  num_moe_experts: null
  rotary_interleaved: False
  window_size: null

  # initialization
  init_method: null
  init_method_std: 0.02
  output_layer_init_method: null

  # mixed-precision
  apply_query_key_layer_scaling: False
  attention_softmax_in_fp32: False

  # fusion
  bias_swiglu_fusion: True
  masked_softmax_fusion: True
  persist_layer_norm: False
  memory_efficient_layer_norm: False
  bias_dropout_fusion: True
  apply_rope_fusion: True

  # activation recomputation
  recompute_granularity: null
  recompute_method: null
  recompute_num_layers: null
  distribute_saved_activations: null

  # fp8 related
  fp8: null
  fp8_margin: 0
  fp8_interval: 1
  fp8_amax_history_len: 1
  fp8_amax_compute_algo: "most_recent"
  fp8_wgrad: True

  # miscellaneous
  clone_scatter_output_in_embedding: True

  normalization: "LayerNorm" # alt value supported by TE: "RMSNorm"

  # MoE related
  moe_router_load_balancing_type: "aux_loss"
  moe_router_topk: 2
  moe_grouped_gemm: False
  moe_aux_loss_coeff: 0 # 1e-2 would be a good start value for load balance loss.
  moe_z_loss_coeff: null # 1e-3 would be a good start value for z-loss
  moe_input_jitter_eps: null
  moe_token_dropping: False

model_parallel:
  # Model parallelism
  tensor_model_parallel_size: 1
  context_parallel_size: 1
  pipeline_model_parallel_size: 1
  virtual_pipeline_model_parallel_size: null
  sequence_parallel: True
  expert_model_parallel_size: 1

  # Initialization
  perform_initialization: True
  use_cpu_initialization: null

  # Training
  fp16: False
  bf16: True
  params_dtype: null # Set from above arguments for core
  timers: null

  # Optimizations
  gradient_accumulation_fusion: True
  async_tensor_model_parallel_allreduce: True
  tp_comm_overlap: False

  # Debug Options
  tp_comm_split_ag: True
  tp_comm_atomic_ag: True
  tp_comm_split_rs: True
  tp_comm_atomic_rs: True
  tp_comm_bulk_wgrad: True
  tp_comm_bulk_dgrad: True

  # Parallelism
  finalize_model_grads_func: null

  # Pipeline Parallel
  pipeline_dtype: null
  grad_scale_func: null
  enable_autocast: False
  autocast_dtype: null
  variable_seq_lengths: False
  num_microbatches_with_partial_activation_checkpoints: null
  overlap_p2p_comm: False
  batch_p2p_comm: True
  batch_p2p_sync: True
  use_ring_exchange_p2p: False
  deallocate_pipeline_outputs: False
  no_sync_func: null
  grad_sync_func: null
  param_sync_func: null
  pipeline_model_parallel_split_rank: null

  # CPU Offloading
  cpu_offloading: False
  cpu_offloading_num_layers: 0
  _cpu_offloading_context: null
  cpu_offloading_weights: False
  cpu_offloading_activations: True

  # Timing
  barrier_with_L1_time: True

# training:
use_mcore_models: True
spec: null
micro_batch_size: 2
global_batch_size: 128
rampup_batch_size: [32, 32, 65324160]
check_for_nan_in_loss_and_grad: True
num_layers_per_virtual_pipeline_stage: null

encoder_num_layers: null
decoder_num_layers: null
rotary_seq_len_interpolation_factor: null
add_position_embedding: False
make_vocab_size_divisible_by: 128
group_query_attention: False


exit_signal_handler: False
exit_duration_in_mins: null
exit_interval: null

untie_embeddings_and_output_weights: True
position_embedding_type: rope
rotary_percent: 0.5
openai_gelu: False
squared_relu: False
swiglu: True
onnx_safe: null
bert_binary_head: True
max_position_embeddings: 4096

transformer_impl: local
use_flash_attn: False
seed: 1234
data_parallel_random_init: False

# Optimizer
optimizer: adam
lr: 2.5e-4
lr_decay_style: cosine
lr_decay_iters: null
lr_decay_samples: 255126953
lr_warmup_fraction: null
lr_warmup_iters: 0
lr_warmup_samples: 81381
lr_warmup_init: 0.0
min_lr: 2.5e-5
weight_decay: 0.1
start_weight_decay: null
end_weight_decay: null
weight_decay_incr_style: constant
clip_grad: 1.0
adam_beta1: 0.9
adam_beta2: 0.95
adam_eps: 1.e-08
sgd_momentum: 0.9
override_opt_param_scheduler: False
use_checkpoint_opt_param_scheduler: False

# checkpointing arguments
save: null
save_interval: 20000
no_save_optim: null
no_save_rng: null
load: null
no_load_optim: null
no_load_rng: null
finetune: False
use_checkpoint_args: False
exit_on_missing_checkpoint: False

# loss arguments
loss_scale: null
initial_loss_scale: 4294967296
min_loss_scale: 1.0
loss_scale_window: 1000
hysteresis: 2
accumulate_allreduce_grads_in_fp32: False
fp16_lm_cross_entropy: False

# distributed arguments
distributed_backend: nccl
distributed_timeout_minutes: 10
overlap_grad_reduce: False
delay_grad_reduce: True
overlap_param_gather: False
delay_param_gather: False
scatter_gather_tensors_in_pipeline: True
local_rank: null
lazy_mpu_init: null
empty_unused_memory_level: 0
standalone_embedding_stage: False
use_distributed_optimizer: False
nccl_communicator_config_path: null

train_iters: null
eval_iters: 32
eval_interval: 2000
skip_train: False

adlr_autoresume: False
adlr_autoresume_interval: 1000

# garbage collection
manual_gc: False
manual_gc_interval: 0
manual_gc_eval: True

tp_comm_overlap_cfg: null

#data
data_path: null
split: '99,1,0'
train_data_path: null
valid_data_path: null
test_data_path: null
data_cache_path: null
mock_data: False
vocab_size: null
vocab_file: null
merge_file: null
vocab_extra_ids: 0
seq_length: 4096
encoder_seq_length: null
decoder_seq_length: null
retriever_seq_length: 256
sample_rate: 1.0
mask_prob: 0.15
short_seq_prob: 0.1
num_workers: 2
tokenizer_type: GPTSentencePieceTokenizer
tokenizer_model: null
reset_position_ids: False
reset_attention_mask: False
eod_mask_loss: False
train_samples: 268554688
dataloader_type: null

#profile:
profile: False
profile_ranks: [0]
profile_step_end: 12
profile_step_start: 10

#logging:
log_params_norm: True
log_num_zeros_in_grad: True
log_throughput: False
log_progress: False
timing_log_level: 0
timing_log_option: minmax
tensorboard_log_interval: 1
tensorboard_queue_size: 1000
log_timers_to_tensorboard: False
log_batch_size_to_tensorboard: False
log_learning_rate_to_tensorboard: True
log_validation_ppl_to_tensorboard: False
log_memory_to_tensorboard: False
log_world_size_to_tensorboard: False
log_loss_scale_to_tensorboard: True
wandb_project: ''
wandb_exp_name: ''
wandb_save_dir: ''
enable_one_logger: False
one_logger_project: e2e-tracking
one_logger_entity: hwinf_dcm
one_logger_run_name: null
log_interval: 100
tensorboard_dir: null
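
For reference, the config above is plain YAML with two nested sections (language_model and model_parallel) and the remaining training, optimizer, data, and logging keys at the top level. Below is a minimal sketch of how such a file could be read into a single argparse-style namespace; it is an illustration only, not the load_yaml that megatron/arguments.py imports from megatron.yaml_arguments (whose diff is not shown here).

# Illustration only: flatten examples/gpt3/gpt_config.yaml into one namespace.
from argparse import Namespace
import yaml

def read_gpt_config(path="examples/gpt3/gpt_config.yaml"):
    with open(path) as f:
        cfg = yaml.safe_load(f)
    flat = {}
    for section in ("language_model", "model_parallel"):
        flat.update(cfg.pop(section, {}) or {})   # hoist the nested sections
    flat.update(cfg)   # optimizer, data, and logging keys sit at the top level
    return Namespace(**flat)

args = read_gpt_config()
print(args.num_layers, args.tensor_model_parallel_size, args.lr)  # 24 1 0.00025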
11 changes: 10 additions & 1 deletion megatron/arguments.py
@@ -46,13 +46,20 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
    # Custom arguments.
    if extra_args_provider is not None:
        parser = extra_args_provider(parser)

    # Parse.
    if ignore_unknown_args:
        args, _ = parser.parse_known_args()
    else:
        args = parser.parse_args()

    # Experimental yaml
    if args.yaml_cfg is not None:
        from .yaml_arguments import load_yaml
        assert args.yaml_cfg and args.use_mcore_models, "To use yaml, mcore must be enabled"
        args = load_yaml(args.yaml_cfg)


    # Args from environment
    args.rank = int(os.getenv('RANK', '0'))
    args.world_size = int(os.getenv("WORLD_SIZE", '1'))
@@ -1474,5 +1481,7 @@ def _add_experimental_args(parser):
                       'To use local spec specify local as the argument.'
                       'For more details, see the model class, '
                       '`transformer_block.py`, or `transformer_layer.py`')
    group.add_argument('--yaml-cfg', type=str, default=None,
                       help='Config file to add additional arguments')

    return parser
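
Usage-wise, the new flag only takes effect together with --use-mcore-models, as the assert in parse_args above enforces. A toy reproduction of that wiring (a standalone hypothetical parser for illustration, not megatron/arguments.py itself):

# Hypothetical toy parser mirroring the new --yaml-cfg wiring.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use-mcore-models', action='store_true')
parser.add_argument('--yaml-cfg', type=str, default=None,
                    help='Config file to add additional arguments')
args = parser.parse_args(['--use-mcore-models', '--yaml-cfg', 'examples/gpt3/gpt_config.yaml'])

if args.yaml_cfg is not None:
    assert args.yaml_cfg and args.use_mcore_models, "To use yaml, mcore must be enabled"
    # The real code replaces args wholesale here: args = load_yaml(args.yaml_cfg)
    print(f"would load {args.yaml_cfg}")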
1 change: 0 additions & 1 deletion megatron/global_vars.py
@@ -247,4 +247,3 @@ def _ensure_var_is_not_initialized(var, name):
    assert var is None, '{} is already initialized.'.format(name)



7 changes: 6 additions & 1 deletion megatron/initialize.py
@@ -16,6 +16,7 @@
from megatron import get_tensorboard_writer
from megatron.core import mpu, tensor_parallel
from megatron.arguments import parse_args, validate_args
from megatron.yaml_arguments import validate_yaml
from megatron.checkpointing import load_args_from_checkpoint
from megatron.global_vars import set_global_variables
from megatron.model.transformer import bias_dropout_add_fused_train
@@ -47,7 +48,11 @@ def initialize_megatron(
         assert args.load is not None, "--use-checkpoints-args requires --load argument"
         load_args_from_checkpoint(args)

-    validate_args(args, args_defaults)
+    if args.yaml_cfg is not None:
+        args = validate_yaml(args, args_defaults)
+    else:
+        validate_args(args, args_defaults)

     # set global args, build tokenizer, and set adlr-autoresume,
     # tensorboard-writer, and timers.
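
On the YAML path, validate_yaml takes over the role that validate_args plays for argparse namespaces. Its implementation is not shown in this excerpt; the helper below is a hypothetical sketch of one job such a validator has to cover, namely folding the args_defaults that pretrain scripts pass into initialize_megatron into a YAML-derived namespace without overriding values the config sets explicitly.

# Hypothetical helper (not megatron/yaml_arguments.py): apply script-supplied
# defaults only where the YAML config left a field unset or missing.
from types import SimpleNamespace

def apply_defaults(args: SimpleNamespace, args_defaults: dict) -> SimpleNamespace:
    for key, value in (args_defaults or {}).items():
        if getattr(args, key, None) is None:
            setattr(args, key, value)
    return args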
4 changes: 2 additions & 2 deletions megatron/training.py
@@ -548,7 +548,7 @@ def train_step(forward_step_func, data_iterator,
         torch.cuda.empty_cache()

     # Vision gradients.
-    if args.vision_pretraining and args.vision_pretraining_type == "dino":
+    if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino":
         unwrapped_model = unwrap_model(model[0])
         unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)

@@ -558,7 +558,7 @@ def train_step(forward_step_func, data_iterator,
     timers('optimizer').stop()

     # Vision momentum.
-    if args.vision_pretraining and args.vision_pretraining_type == "dino":
+    if getattr(args, 'vision_pretraining', False) and args.vision_pretraining_type == "dino":
         unwrapped_model = unwrap_model(model[0])
         unwrapped_model.update_momentum(args.curr_iteration)
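
The switch from a direct attribute read to getattr matters because an args object built from gpt_config.yaml presumably carries only the keys the file defines, and vision_pretraining is an argparse default that the YAML above never sets. A small illustration with toy namespaces (not Megatron objects):

# Toy namespaces showing why the getattr(..., False) guard is needed.
from types import SimpleNamespace

argparse_args = SimpleNamespace(vision_pretraining=False, vision_pretraining_type="dino")
yaml_args = SimpleNamespace(vision_pretraining_type="dino")  # key absent from the YAML config

print(getattr(argparse_args, 'vision_pretraining', False))  # False
print(getattr(yaml_args, 'vision_pretraining', False))      # False instead of AttributeError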
