Merge branch 'support_dp_overlap_for_moe' into 'main'
[MoE] Support --overlap-grad-reduce with GroupedMLP

See merge request ADLR/megatron-lm!1194
jaredcasper committed Apr 3, 2024
2 parents d107813 + fa5336a commit d585182
Showing 4 changed files with 31 additions and 3 deletions.
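
For context before the diffs: with --overlap-grad-reduce, Megatron-LM's data-parallel wrapper groups parameter gradients into buckets and launches each bucket's reduction asynchronously during the backward pass, so communication overlaps with the remaining compute. A bucket only becomes ready once every parameter in it has produced a gradient, which is why a GroupedMLP that happens to receive zero tokens must still emit gradients. The sketch below illustrates that pattern only; it is not Megatron-LM's implementation, the class and method names are invented, and it assumes PyTorch >= 2.1 for register_post_accumulate_grad_hook.

```python
import torch
import torch.distributed as dist
import torch.nn as nn


class GradReduceBucket:
    """Illustrative sketch only -- NOT Megatron-LM's distributed-data-parallel code.

    Groups parameters into one bucket and launches an asynchronous all-reduce of
    their gradients as soon as every parameter in the bucket has a gradient.
    """

    def __init__(self, params, process_group=None):
        self.params = list(params)
        self.group = process_group
        self.ready = set()
        self.handles = []
        for p in self.params:
            # Fires after the gradient has been accumulated into p.grad (PyTorch >= 2.1).
            p.register_post_accumulate_grad_hook(self._on_grad_ready)

    def _on_grad_ready(self, param):
        self.ready.add(id(param))
        # The bucket only launches its reduce once *every* parameter has a gradient.
        # A parameter that never produces one (e.g. experts that received no tokens)
        # would leave the bucket stuck -- hence the experts.py change below.
        if len(self.ready) == len(self.params):
            if dist.is_available() and dist.is_initialized():
                for p in self.params:
                    self.handles.append(
                        dist.all_reduce(p.grad, group=self.group, async_op=True)
                    )

    def finish(self):
        # Called before the optimizer step: wait for the overlapped communication.
        for h in self.handles:
            h.wait()
        self.handles.clear()
        self.ready.clear()


if __name__ == "__main__":
    # Single-process demo: without torch.distributed initialized, the hooks still
    # track readiness and the communication step is simply skipped.
    layer = nn.Linear(16, 16)
    bucket = GradReduceBucket(layer.parameters())
    layer(torch.randn(4, 16)).sum().backward()
    bucket.finish()
    print(all(p.grad is not None for p in layer.parameters()))  # True
```
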
12 changes: 10 additions & 2 deletions megatron/core/transformer/moe/experts.py
@@ -149,9 +149,17 @@ def forward(self, permuted_local_hidden_states, tokens_per_expert):

             fc2_output = gg.ops.gmm(intermediate_parallel, w2, tokens_per_expert, trans_b=False)
         else:
-            # None token is allocated for local experts.
+            # No token is allocated for local experts.
             assert torch.count_nonzero(tokens_per_expert) == 0
-            fc2_output = permuted_local_hidden_states
+
+            # Make sure parameters still have gradients when no tokens are routed to this set of experts.
+            w1 = self.weight1.view(self.config.hidden_size, -1)
+            w2 = self.weight2.view(-1, self.config.hidden_size)
+            h = torch.matmul(permuted_local_hidden_states, w1)
+            h = self.activation_func(h)
+            h = torch.matmul(h, w2)
+
+            fc2_output = h

         return fc2_output, None

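To see concretely what the new else-branch buys, here is a minimal standalone sketch (the shapes and the GELU activation are illustrative; the committed code uses self.activation_func and views of weight1/weight2): pushing even a (0, hidden_size) tensor through the expert weights places them in the autograd graph, so backward yields zero-valued gradients instead of leaving .grad as None.

```python
import torch
import torch.nn.functional as F

# Standalone illustration of the pattern added above; shapes and activation are assumptions.
hidden_size, ffn_hidden_size = 16, 32
w1 = torch.nn.Parameter(torch.randn(hidden_size, ffn_hidden_size))
w2 = torch.nn.Parameter(torch.randn(ffn_hidden_size, hidden_size))

# Zero tokens routed to this group of experts: an empty (0, hidden_size) batch.
no_tokens = torch.zeros(0, hidden_size)

# Before the change, the empty input was returned untouched, so the weights never
# entered the graph and their .grad stayed None.
assert w1.grad is None and w2.grad is None

# The added path: run the empty batch through the weights anyway.
h = torch.matmul(no_tokens, w1)   # shape (0, ffn_hidden_size)
h = F.gelu(h)
out = torch.matmul(h, w2)         # shape (0, hidden_size)
out.sum().backward()              # sum over an empty tensor is 0, but it is differentiable

# The parameters now have (all-zero) gradients, so --overlap-grad-reduce sees them.
assert w1.grad is not None and torch.all(w1.grad == 0)
assert w2.grad is not None and torch.all(w2.grad == 0)
```
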
1 change: 1 addition & 0 deletions tests/functional_tests/jet_recipes/MR-gpt.yaml
@@ -70,6 +70,7 @@ products:
 # - {tp_size: [2], pp_size: [1,2], extra_args: ['"--context-parallel-size 2 --sequence-parallel --hidden-dropout 0.0 --attention-dropout 0.0"']} # TODO: need updated container with TE > 1.0.0
 - {tp_size: [2], pp_size: [1], extra_args: ['"--sequence-parallel --num-experts 8 --expert-model-parallel-size 2 --moe-router-load-balancing-type sinkhorn --moe-router-topk 1"'], args_meta: ["te_8experts2parallel"]}
 - {tp_size: [2], pp_size: [1], extra_args: ['"--sequence-parallel --num-experts 8 --expert-model-parallel-size 2 --use-distributed-optimizer --moe-router-load-balancing-type sinkhorn --moe-router-topk 1"'], args_meta: ["te_8experts2parallel_dist_optimizer"]}
+- {tp_size: [2], pp_size: [1], extra_args: ['"--moe-grouped-gemm --disable-bias-linear --sequence-parallel --num-experts 8 --expert-model-parallel-size 2 --use-distributed-optimizer --moe-router-load-balancing-type sinkhorn --moe-router-topk 1 --overlap-grad-reduce"'], moe_grouped_gemm: [1], args_meta: ["te_8experts2parallel_overlap_grad_reduce_groupedGEMM"]}
 - {tp_size: [2], pp_size: [1], extra_args: ['"--moe-grouped-gemm --disable-bias-linear --sequence-parallel --num-experts 8 --expert-model-parallel-size 2 --moe-router-load-balancing-type sinkhorn --moe-router-topk 1"'], moe_grouped_gemm: [1], args_meta: ["te_8experts2parallel_groupedGEMM"]}
 - {tp_size: [2], pp_size: [1], extra_args: ['"--disable-bias-linear --sequence-parallel --num-experts 8 --expert-model-parallel-size 2 --moe-router-load-balancing-type aux_loss --moe-router-topk 2 --moe-aux-loss-coeff 1e-2"'], moe_grouped_gemm: [1], args_meta: ["te_8experts2parallel_top2router"]}
 - {tp_size: [1], pp_size: [1], extra_args: ["--use-distributed-optimizer"], args_meta: ["dist_optimizer"]}
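For orientation, the new recipe line pairs --moe-grouped-gemm and --overlap-grad-reduce with 8 experts split across expert-model-parallel-size 2, so each expert-parallel rank owns 4 local experts; the snippet below is just that arithmetic spelled out (it is not part of the jet harness).

```python
# Sanity arithmetic for the new te_8experts2parallel_overlap_grad_reduce_groupedGEMM
# recipe; illustrative only, not part of the test harness.
num_experts = 8                  # --num-experts 8
expert_model_parallel_size = 2   # --expert-model-parallel-size 2

assert num_experts % expert_model_parallel_size == 0
num_local_experts = num_experts // expert_model_parallel_size
print(num_local_experts)  # 4 experts' weights live on each expert-parallel rank
```
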
1 change: 1 addition & 0 deletions
@@ -0,0 +1 @@
+{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [10.80961, 10.86088, 10.86703, 10.80386, 10.71988, 10.64698, 10.21161, 10.32003, 10.22052, 9.92363]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [31228.0, 37860.0, 38327.0, 36135.0, 33138.0, 34687.0, 30217.0, 34984.0, 35952.0, 37036.0]}, "iteration_timing_avg": 0.18751352941176463}
20 changes: 19 additions & 1 deletion tests/unit_tests/transformer/moe/test_grouped_mlp.py
@@ -27,7 +27,7 @@ def setup_method(self, method, use_cpu_initialization=False, swiglu=True):
         print("============")
         Utils.initialize_model_parallel(1,1)
         num_layers = 1 # 2
-        self.hidden_size = 2 # 12
+        self.hidden_size = 16  # must be a multiple of 16, otherwise it triggers a CUTLASS misaligned-address issue
         self.num_experts = 2
         self.gated_linear_unit = swiglu
         self.activation_func = F.silu if swiglu else F.gelu
@@ -161,6 +161,24 @@ def test_gpu_forward_with_no_tokens_allocated(self):
             print("Expected error message from groupedGEMM:", e)
             assert str(e) == "Input batch_sizes should not be all zeros!"

+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+    @pytest.mark.skipif(
+        not DEVICE_CAPABILITY or DEVICE_CAPABILITY[0] < 8, reason='GroupedGEMM kernels are not supported on this device.'
+    )
+    def test_gradient_with_no_tokens_allocated(self):
+        """Test that when no token is passed in, the parameters of the grouped MLP will also have gradients."""
+        self.grouped_mlp.cuda()
+        num_allocated_tokens = 0
+        tokens_per_expert = torch.zeros(self.num_experts)
+        hidden_states = torch.rand((num_allocated_tokens, self.hidden_size), dtype=torch.bfloat16)
+        hidden_states = hidden_states.cuda()
+        output_gmm, _ = self.grouped_mlp.experts(
+            hidden_states,
+            tokens_per_expert=tokens_per_expert,
+        )
+        output_gmm.mean().backward()
+        assert self.grouped_mlp.experts.weight1.grad is not None
+

 if __name__ == "__main__":
     for use_cpu_unitilization in [True, False]:
