Add distributed checkpointing support for GroupedGEMM
hxbai authored and ericharper committed Jun 28, 2024
1 parent 69d7d5b commit 9aa7ce6
Showing 3 changed files with 332 additions and 9 deletions.
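For context: the new GroupedMLP.sharded_state_dict added below exposes the fused GroupedGEMM weights under the same global keys as SequentialMLP (experts.linear_fc1.weight and experts.linear_fc2.weight), so GroupedGEMM checkpoints become saveable and reloadable through Megatron's distributed checkpointing, including across different expert/tensor-parallel layouts. A minimal usage sketch (not part of this commit) follows; it assumes an already-initialized torch.distributed / parallel_state environment, an already-constructed GroupedMLP instance named `experts` (hypothetical), and the megatron.core.dist_checkpointing save/load entry points as they existed around this commit.

# Hedged sketch: save/load a GroupedMLP through distributed checkpointing.
# Assumes torch.distributed and megatron.core.parallel_state are initialized elsewhere
# and that `experts` is a constructed GroupedMLP; the checkpoint path is hypothetical.
from megatron.core import dist_checkpointing

ckpt_dir = '/tmp/moe_groupedgemm_ckpt'

# Each rank contributes only its local expert shards under the global
# 'experts.linear_fc1.weight' / 'experts.linear_fc2.weight' keys.
dist_checkpointing.save(experts.sharded_state_dict(prefix=''), ckpt_dir)

# On load, each ShardedTensorFactory's merge function reassembles the per-expert
# shards back into the fused weight1/weight2 layout of GroupedMLP.
state_dict = dist_checkpointing.load(experts.sharded_state_dict(prefix=''), ckpt_dir)
experts.load_state_dict(state_dict, strict=False)  # fake _extra_state entries are ignored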
172 changes: 165 additions & 7 deletions megatron/core/transformer/moe/experts.py
@@ -1,14 +1,17 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from typing import Tuple
from copy import deepcopy
from functools import partial
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from megatron.core import parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.dist_checkpointing import ShardedTensor
from megatron.core.dist_checkpointing.mapping import ReplicaId, ShardedTensorFactory
from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding
from megatron.core.jit import jit_fuser
from megatron.core.tensor_parallel.layers import (
@@ -20,11 +23,12 @@
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe import grouped_gemm_util as gg
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.utils import make_sharded_object_for_checkpoint


class GroupedMLP(MegatronModule):
"""An efficient implementation of the Experts layer using CUTLASS GroupedGEMM.
This class is designed to execute multiple experts in parallel, thereby maximizing computational efficiency.
"""

@@ -52,6 +56,7 @@ def glu(x):
        self.activation_func = self.config.activation_func

        # How many features each rank holds for fc1 and fc2, respectively.
        self.moe_extended_tp = config.moe_extended_tp
        if config.moe_extended_tp:
            tp_size = parallel_state.get_tensor_and_expert_parallel_world_size()
        else:
@@ -139,6 +144,18 @@ def glu(x):
        setattr(self.weight1, 'allreduce', not self.expert_parallel)
        setattr(self.weight2, 'allreduce', not self.expert_parallel)

        def remove_extra_states_check(self, incompatible_keys):
            """
            Remove _extra_state from unexpected keys.
            These keys are for dist ckpt compatibility with SequentialMLP.
            """
            keys = deepcopy(incompatible_keys.unexpected_keys)
            for key in keys:
                if '_extra_state' in key:
                    incompatible_keys.unexpected_keys.remove(key)

        self.register_load_state_dict_post_hook(remove_extra_states_check)

    def forward(self, permuted_local_hidden_states, tokens_per_expert):
        if permuted_local_hidden_states.nelement() != 0:
            # Reshape the weights for the grouped GEMMs.
@@ -168,14 +185,155 @@ def forward(self, permuted_local_hidden_states, tokens_per_expert):
        return fc2_output, None

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        raise NotImplementedError(
            'Currently distributed checkpointing is not supported for GroupedMLP'
        )
        """Maps local expert to global experts."""
        if self.moe_extended_tp:
            raise NotImplementedError(
                'Currently distributed checkpointing is not supported for moe_extended_tp'
            )

        sharded_state_dict = {}
        num_global_experts = (
            parallel_state.get_expert_model_parallel_world_size() * self.num_local_experts
        )
        local_expert_indices_offset = (
            parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
        )
        tp_size = parallel_state.get_tensor_model_parallel_world_size()
        tp_rank = parallel_state.get_tensor_model_parallel_rank()

        prepend_axis_num = len(sharded_offsets)
        replica_id = (0, 0, parallel_state.get_data_modulo_expert_parallel_rank())

        @torch.no_grad()
        def sh_ten_build_fn(
            key: str,
            t: torch.Tensor,
            replica_id: ReplicaId,
            flattened_range: Optional[slice],
            tp_axis: int,
            with_glu: bool,
        ):
            if tp_axis == 0:
                real_shape = (self.num_local_experts, self.config.hidden_size, -1)
            elif tp_axis == 1:
                real_shape = (self.num_local_experts, -1, self.config.hidden_size)
                assert with_glu == False
            else:
                raise ValueError("tp_axis should be 0 or 1.")
            if flattened_range is None:
                t = t.view(real_shape).transpose(-1, -2)
                if with_glu:
                    local_tensors = torch.chunk(t, 2, -2)
                    sub_states = [
                        ShardedTensor.from_rank_offsets(
                            key,
                            local_tensors[0].contiguous(),
                            *sharded_offsets,
                            (
                                prepend_axis_num,
                                parallel_state.get_expert_model_parallel_rank(),
                                parallel_state.get_expert_model_parallel_world_size(),
                            ),
                            (prepend_axis_num + 1, tp_rank, tp_size * 2),
                            replica_id=replica_id,
                            prepend_axis_num=prepend_axis_num,
                        ),
                        ShardedTensor.from_rank_offsets(
                            key,
                            local_tensors[1].contiguous(),
                            *sharded_offsets,
                            (
                                prepend_axis_num,
                                parallel_state.get_expert_model_parallel_rank(),
                                parallel_state.get_expert_model_parallel_world_size(),
                            ),
                            (prepend_axis_num + 1, tp_size + tp_rank, tp_size * 2),
                            replica_id=replica_id,
                            prepend_axis_num=prepend_axis_num,
                        ),
                    ]
                else:
                    sub_states = ShardedTensor.from_rank_offsets(
                        key,
                        t.contiguous(),
                        *sharded_offsets,
                        (
                            prepend_axis_num,
                            parallel_state.get_expert_model_parallel_rank(),
                            parallel_state.get_expert_model_parallel_world_size(),
                        ),
                        (prepend_axis_num + 1 + tp_axis, tp_rank, tp_size),
                        replica_id=replica_id,
                        prepend_axis_num=prepend_axis_num,
                    )
            else:
                raise NotImplementedError(
                    'Currently GroupedMLP does not support distributed checkpointing '
                    'with the distributed optimizer.'
                )
            return sub_states

        @torch.no_grad()
        def sh_ten_merge_fn(sub_state_dict, tp_axis: int, with_glu: bool):
            if tp_axis == 0:
                weight_shape = (self.config.hidden_size, -1)
            elif tp_axis == 1:
                weight_shape = (-1, self.config.hidden_size)
                assert with_glu == False
            else:
                raise ValueError("tp_axis should be 0 or 1.")
            if with_glu:
                sub_state_dict = torch.cat(sub_state_dict, -2)
            return sub_state_dict.transpose(-1, -2).reshape(weight_shape)

        state_dict = self.state_dict(prefix='', keep_vars=True)
        # To align with SequentialMLP, the weight tensors are transposed,
        # and the tp_axis is also for the transposed tensors
        for name, tensor in state_dict.items():
            if name == 'weight1':
                tp_axis = 0
                with_glu = self.config.gated_linear_unit
                wkey = f'{prefix}experts.linear_fc1.weight'
            else:
                tp_axis = 1
                with_glu = False
                wkey = f'{prefix}experts.linear_fc2.weight'
            sharded_state_dict[f'{prefix}{name}'] = ShardedTensorFactory(
                wkey,
                tensor,
                partial(sh_ten_build_fn, tp_axis=tp_axis, with_glu=with_glu),
                partial(sh_ten_merge_fn, tp_axis=tp_axis, with_glu=with_glu),
                replica_id,
            )

        replica_id = (
            0,
            parallel_state.get_tensor_model_parallel_rank(),
            parallel_state.get_data_modulo_expert_parallel_rank(),
        )
        # Add fake _extra_state to be compatible with SequentialMLP
        for expert_local_idx in range(self.num_local_experts):
            expert_global_idx = local_expert_indices_offset + expert_local_idx
            expert_sharded_offsets = (
                *sharded_offsets,
                (len(sharded_offsets), expert_global_idx, num_global_experts),
            )
            for mod in ['linear_fc1', 'linear_fc2']:
                sharded_state_dict[f'{prefix}expert{expert_global_idx}.{mod}._extra_state'] = (
                    make_sharded_object_for_checkpoint(
                        None,
                        f'{prefix}experts.{mod}._extra_state',
                        expert_sharded_offsets,
                        replica_id,
                    )
                )

        return sharded_state_dict


class SequentialMLP(MegatronModule):
"""An implementation of the Experts layer using a sequence of MLP layers.
This class executes each expert sequentially.
"""

@@ -214,7 +372,7 @@ def forward(self, permuted_local_hidden_states, tokens_per_expert):
        return output_local, output_bias_local

    def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
        """ Maps local expert to global experts. """
        """Maps local expert to global experts."""
        if self.moe_extended_tp:
            raise NotImplementedError(
                'Currently distributed checkpointing is not supported for moe_extended_tp'
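For intuition, the sh_ten_build_fn / sh_ten_merge_fn pair above is a pure layout conversion: the fused weight1/weight2 storage is reinterpreted as per-expert matrices in SequentialMLP's transposed orientation (with the gate and up halves split when gated_linear_unit is enabled), and the merge function inverts that exactly. The following self-contained sketch (not part of this commit; the sizes are hypothetical) demonstrates the round trip for the GLU case of weight1:

import torch

# Hypothetical toy sizes (not from the commit): 2 local experts, hidden=4, per-expert fc1 output=6.
num_local_experts, hidden, fc1_per_expert = 2, 4, 6

# GroupedMLP stores weight1 as one fused 2-D parameter whose memory is expert-major,
# exactly as the forward pass re-views it.
weight1 = torch.arange(num_local_experts * hidden * fc1_per_expert, dtype=torch.float32)
weight1 = weight1.view(hidden, num_local_experts * fc1_per_expert)

# Build step (mirrors sh_ten_build_fn for tp_axis=0 with with_glu=True):
# view as (experts, hidden, fc1), transpose to the per-expert (fc1, hidden) orientation,
# then split the gate and up projections of the GLU.
t = weight1.view(num_local_experts, hidden, -1).transpose(-1, -2)
gate, up = torch.chunk(t, 2, dim=-2)

# Merge step (mirrors sh_ten_merge_fn): concatenate the halves, transpose back,
# and reshape into the fused 2-D parameter shape.
merged = torch.cat([gate, up], dim=-2).transpose(-1, -2).reshape(hidden, -1)

assert torch.equal(merged, weight1)  # the round trip reproduces the original fused layout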
4 changes: 2 additions & 2 deletions tests/functional_tests/jet_recipes/MR-gpt.yaml
@@ -84,8 +84,8 @@ products:
- {tp_size: [2], pp_size: [1], ep_size: [2], ckpt_resume: [0, 1], extra_args: ['"--sequence-parallel --num-experts 8 --moe-router-load-balancing-type sinkhorn --moe-router-topk 1 --ckpt-fully-parallel-save --ckpt-fully-parallel-load"'], args_meta: ["te_8experts2parallel"]}
- {tp_size: [2], pp_size: [1], ep_size: [2], ckpt_resume: [0, 1], extra_args: ['"--sequence-parallel --num-experts 8 --use-distributed-optimizer --moe-router-load-balancing-type sinkhorn --moe-router-topk 1 --ckpt-fully-parallel-save --ckpt-fully-parallel-load"'], args_meta: ["te_8experts2parallel_dist_optimizer"]}
## TODO: MoE GroupedMLP dist-ckpt not supported, so must use 'torch' ckpt format
- {tp_size: [2], pp_size: [1], ep_size: [2], ckpt_resume: [0, 1], ckpt_format: [torch], extra_args: ['"--moe-grouped-gemm --disable-bias-linear --sequence-parallel --num-experts 8 --use-distributed-optimizer --moe-router-load-balancing-type sinkhorn --moe-router-topk 1 --overlap-grad-reduce --overlap-param-gather"'], moe_grouped_gemm: [1], args_meta: ["te_8experts2parallel_overlap_grad_reduce_param_gather_groupedGEMM"]}
- {tp_size: [2], pp_size: [1], ep_size: [2], ckpt_resume: [0, 1], ckpt_format: [torch], extra_args: ['"--moe-grouped-gemm --disable-bias-linear --sequence-parallel --num-experts 8 --moe-router-load-balancing-type sinkhorn --moe-router-topk 1"'], moe_grouped_gemm: [1], args_meta: ["te_8experts2parallel_groupedGEMM"]}
- {tp_size: [2], pp_size: [1], ep_size: [2], ckpt_resume: [0, 1], extra_args: ['"--moe-grouped-gemm --disable-bias-linear --sequence-parallel --num-experts 8 --use-distributed-optimizer --moe-router-load-balancing-type sinkhorn --moe-router-topk 1 --overlap-grad-reduce --overlap-param-gather"'], moe_grouped_gemm: [1], args_meta: ["te_8experts2parallel_overlap_grad_reduce_param_gather_groupedGEMM"]}
- {tp_size: [2], pp_size: [1], ep_size: [2], ckpt_resume: [0, 1], extra_args: ['"--moe-grouped-gemm --disable-bias-linear --sequence-parallel --num-experts 8 --moe-router-load-balancing-type sinkhorn --moe-router-topk 1"'], moe_grouped_gemm: [1], args_meta: ["te_8experts2parallel_groupedGEMM"]}
- {tp_size: [2], pp_size: [1], ep_size: [2], ckpt_resume: [0, 1], extra_args: ['"--disable-bias-linear --sequence-parallel --num-experts 8 --moe-router-load-balancing-type aux_loss --moe-router-topk 2 --moe-aux-loss-coeff 1e-2"'], moe_grouped_gemm: [1], args_meta: ["te_8experts2parallel_top2router"]}
- {tp_size: [1], pp_size: [1], ckpt_resume: [0, 1], extra_args: ['"--use-distributed-optimizer --async-save"'], args_meta: ["dist_optimizer"]}
- {tp_size: [1], pp_size: [1], ckpt_resume: [0, 1], extra_args: ['"--use-distributed-optimizer --no-mmap-bin-files"'], args_meta: ["dist_optimizer_no_mmap_bin_files"]}
