Commit

Merge branch 'main' into mblaz/add-ckpt-resume-functional-tests
mikolajblaz committed Apr 19, 2024
2 parents 0597c90 + ccfeda4 commit 7278310
Showing 3 changed files with 18 additions and 12 deletions.
26 changes: 16 additions & 10 deletions megatron/core/optimizer/optimizer.py
@@ -754,21 +754,27 @@ def load_state_dict(self, state_dict):
             self.param_groups += optimizer.param_groups

     def disable_pre_hook(self):
-        if not self.config.use_distributed_optimizer or not self.config.overlap_param_gather:
-            raise ValueError(
-                "disable_pre_hook should only be called with 'use_distributed_optimizer' "
-                "and 'overlap_param_gather' are both enabled."
-            )
         for optimizer in self.chained_optimizers:
+            if (
+                not optimizer.config.use_distributed_optimizer
+                or not optimizer.config.overlap_param_gather
+            ):
+                raise ValueError(
+                    "disable_pre_hook should only be called with 'use_distributed_optimizer' "
+                    "and 'overlap_param_gather' both enabled."
+                )
             optimizer.disable_pre_hook()

     def enable_pre_hook(self):
-        if not self.config.use_distributed_optimizer or not self.config.overlap_param_gather:
-            raise ValueError(
-                "enable_pre_hook should only be called with 'use_distributed_optimizer' "
-                "and 'overlap_param_gather' are both enabled."
-            )
         for optimizer in self.chained_optimizers:
+            if (
+                not optimizer.config.use_distributed_optimizer
+                or not optimizer.config.overlap_param_gather
+            ):
+                raise ValueError(
+                    "enable_pre_hook should only be called with 'use_distributed_optimizer' "
+                    "and 'overlap_param_gather' both enabled."
+                )
             optimizer.enable_pre_hook()

     def step(self):
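This hunk is the substantive code change in the merge: the validity check now runs against each chained sub-optimizer's own config instead of once against the wrapper's config. As a rough illustration, here is a hedged usage sketch (hypothetical caller code, not Megatron-LM source) of one plausible way the two hooks are paired around a checkpoint write; `checkpoint_with_hooks_paused` and `save_checkpoint` are stand-in names, and the sketch assumes every sub-optimizer was built with use_distributed_optimizer and overlap_param_gather enabled, as the checks above require.

# Hedged sketch: hypothetical caller code, not Megatron-LM source.
def checkpoint_with_hooks_paused(optimizer, model, iteration, save_checkpoint):
    # Pause the overlapped param-gather pre-hooks so parameter buffers are not
    # being re-gathered while the checkpoint is written, then restore them.
    optimizer.disable_pre_hook()
    try:
        save_checkpoint(iteration, model, optimizer)
    finally:
        optimizer.enable_pre_hook()
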
2 changes: 1 addition & 1 deletion tests/functional_tests/jet_recipes/MR-gpt.yaml
@@ -71,7 +71,7 @@ products:
 - {tp_size: [2], pp_size: [1], ckpt_resume: [0, 1], extra_args: ['"--sequence-parallel --num-experts 8 --expert-model-parallel-size 2 --moe-router-load-balancing-type sinkhorn --moe-router-topk 1"'], args_meta: ["te_8experts2parallel"]}
 - {tp_size: [2], pp_size: [1], ckpt_resume: [0, 1], extra_args: ['"--sequence-parallel --num-experts 8 --expert-model-parallel-size 2 --use-distributed-optimizer --moe-router-load-balancing-type sinkhorn --moe-router-topk 1"'], args_meta: ["te_8experts2parallel_dist_optimizer"]}
 ## TODO: MoE GroupedMLP dist-ckpt not supported, so must use 'torch' ckpt format
-- {tp_size: [2], pp_size: [1], ckpt_resume: [0, 1], ckpt_format: [torch], extra_args: ['"--moe-grouped-gemm --disable-bias-linear --sequence-parallel --num-experts 8 --expert-model-parallel-size 2 --use-distributed-optimizer --moe-router-load-balancing-type sinkhorn --moe-router-topk 1 --overlap-grad-reduce"'], moe_grouped_gemm: [1], args_meta: ["te_8experts2parallel_overlap_grad_reduce_groupedGEMM"]}
+- {tp_size: [2], pp_size: [1], ckpt_resume: [0, 1], ckpt_format: [torch], extra_args: ['"--moe-grouped-gemm --disable-bias-linear --sequence-parallel --num-experts 8 --expert-model-parallel-size 2 --use-distributed-optimizer --moe-router-load-balancing-type sinkhorn --moe-router-topk 1 --overlap-grad-reduce --overlap-param-gather"'], moe_grouped_gemm: [1], args_meta: ["te_8experts2parallel_overlap_grad_reduce_param_gather_groupedGEMM"]}
 - {tp_size: [2], pp_size: [1], ckpt_resume: [0, 1], ckpt_format: [torch], extra_args: ['"--moe-grouped-gemm --disable-bias-linear --sequence-parallel --num-experts 8 --expert-model-parallel-size 2 --moe-router-load-balancing-type sinkhorn --moe-router-topk 1"'], moe_grouped_gemm: [1], args_meta: ["te_8experts2parallel_groupedGEMM"]}
 - {tp_size: [2], pp_size: [1], ckpt_resume: [0, 1], extra_args: ['"--disable-bias-linear --sequence-parallel --num-experts 8 --expert-model-parallel-size 2 --moe-router-load-balancing-type aux_loss --moe-router-topk 2 --moe-aux-loss-coeff 1e-2"'], moe_grouped_gemm: [1], args_meta: ["te_8experts2parallel_top2router"]}
 - {tp_size: [1], pp_size: [1], ckpt_resume: [0, 1], extra_args: ["--use-distributed-optimizer"], args_meta: ["dist_optimizer"]}
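For readers unfamiliar with these recipes, here is a hedged sketch (hypothetical, not the JET test harness itself) of how a single product entry such as the "dist_optimizer" line above might expand into launcher flags; the dict keys mirror the recipe fields, and the tp/pp flag names are standard Megatron-LM arguments.

# Hedged sketch: hypothetical expansion of the "dist_optimizer" product entry above.
product = {
    "tp_size": 1,
    "pp_size": 1,
    "ckpt_resume": 1,
    "extra_args": "--use-distributed-optimizer",
}

launch_flags = " ".join(
    [
        f"--tensor-model-parallel-size {product['tp_size']}",
        f"--pipeline-model-parallel-size {product['pp_size']}",
        product["extra_args"],
    ]
)
print(launch_flags)
# --tensor-model-parallel-size 1 --pipeline-model-parallel-size 1 --use-distributed-optimizer
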
Third changed file: a functional-test golden-values JSON (path not shown in this view).
@@ -1 +1 @@
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.80961, 10.86088, 10.86703, 10.80386, 10.71988, 10.64698, 10.21161, 10.32003, 10.22052, 9.92363]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [31228.0, 37860.0, 38327.0, 36135.0, 33138.0, 34687.0, 30217.0, 34984.0, 35952.0, 37036.0]}, "iteration_timing_avg": 0.18751352941176463}
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.80961, 10.86088, 10.86703, 10.80386, 10.71988, 10.64698, 10.21161, 10.32003, 10.22052, 9.92363]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [31228.0, 37860.0, 38327.0, 36135.0, 33138.0, 34687.0, 30217.0, 34984.0, 35952.0, 37036.0]}, "iteration_timing_avg": 0.17911029411764712}
