diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index b089ec420d47..7dedaf4f2ec6 100755
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -17,8 +17,8 @@
 from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
 from deepspeed.utils import groups
 from deepspeed.module_inject.layers import is_autotp_training_mode
-
-
+import os
+import ast
 def move(tensor, device, copy=True):
     if tensor.is_meta:
         return torch.empty_like(tensor, device=device)
@@ -191,6 +191,7 @@ def load(module, state_dict, prefix, mp_group=None):
 
 
 class AutoTP():
+    moe_experts_reduce_once = False
 
     def __init__(self,
                  module,
@@ -214,6 +215,8 @@ def __init__(self,
         self.keep_module_on_host = keep_module_on_host
 
     def in_module_list(module, module_list):
+        if 'DeepseekV2' in str(type(module)):
+            return False
         for item in module_list:
             if type(item).__name__ == type(module).__name__:
                 return True
@@ -261,7 +264,7 @@ def update_policy_list(policy_list, new_module, new_gems):
         for i, policy in enumerate(policy_list):
             # if module already exists in policy, combine gems and remove duplicates
             if policy[0] == type(new_module):
-                new_gems = set(new_gems + policy[1])
+                new_gems = list(set(new_gems + policy[1]))
                 policy_list[i] = tuple([type(new_module), new_gems])
                 return policy_list
         policy_list.append(tuple([type(new_module), new_gems]))
@@ -287,6 +290,12 @@ def tp_parser(model):
         module_list = []
         layer_list = []
         gem_list = []
+        # 'DS_MOE_TP_SINGLE_ALLREDUCE' is an environment variable that indicates
+        # whether the MoE experts adopt the reduce-once optimization.
+        if not AutoTP.moe_experts_reduce_once:
+            ds_moe_experts_reduce_once = os.environ.get('DS_MOE_TP_SINGLE_ALLREDUCE')
+            if ds_moe_experts_reduce_once:
+                AutoTP.moe_experts_reduce_once = ast.literal_eval(ds_moe_experts_reduce_once)
 
         module_list = AutoTP.get_module_list(model)
         assert AutoTP.supported(model), "AutoTP not supported for model. Please use kernel injection since container policy for model exists." \
@@ -309,7 +318,16 @@ def tp_parser(model):
                     gem_list = gem_list + [layer]
                 elif 'o_proj' in layer:
                     gem_list = gem_list + [layer]
-                elif 'down_proj' in layer:
+                elif 'down_proj' in layer and not (('DeepseekV2' in str(type(module))) or \
+                                                   ('qwen2_moe' in str(type(module))) or \
+                                                   not AutoTP.moe_experts_reduce_once):
+                    gem_list = gem_list + [layer]
+                elif 'shared_experts.down_proj' in layer and (('DeepseekV2' in str(type(module))) or \
+                                                              ('qwen2_moe' in str(type(module)))) \
+                        and AutoTP.moe_experts_reduce_once:
+                    gem_list = gem_list + [layer]
+                elif 'mlp.down_proj' in layer and ('DeepseekV2' in str(type(module)) \
+                                                   and AutoTP.moe_experts_reduce_once):
                     gem_list = gem_list + [layer]
                 elif 'attention.dense' in layer and 'GPTNeoX' in str(model):
                     gem_list = gem_list + [layer]
@@ -377,7 +395,8 @@ def _replace(self, child, name, conv_linear_layer):
             arctic_w2_all_reduce_linear = True
         # For MoE MLP model, e.g., deepseek and jamba
         down_proj = False
-        if 'down_proj' in name:
+        # Deepseek handles its different down_proj layers in different ways.
+        if 'down_proj' in name and 'DeepseekV2' not in str(type(self.module)):
             down_proj = True
 
         if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj:
@@ -390,14 +409,48 @@ def _replace(self, child, name, conv_linear_layer):
             return LinearAllreduce(child, self.mp_group, name=name)
         else:
-            setattr(child, "replaced", True)
+            # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
+            # else [weight_shape[0] // mp_size, weight_shape[1]]
             if self.conv_linear_layer:
-                conv_LinearLayer(child, self.mp_group)
-            elif require_tp_fused_qkvw(name, self.mp_size):
-                #Check and handle fused qkv for TP
-                return fused_LinearLayer(child, self.mp_group, fused_module=self.module)
+                child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
 
-            return LinearLayer(child, self.mp_group, name=name)
+            if require_tp_fused_qkvw(name, self.mp_size):
+                # Check and handle fused qkv for TP
+                # The copy is a regular copy; the shapes of dst and src are the same.
+                data_dc = move(
+                    prepare_tp_fused_qkvw(self.module, child.weight.data, self.mp_size, mp_replace.gpu_index),
+                    device_name, return_new_copy)
+
+                bias_data_dc = None if child.bias is None else move(
+                    prepare_tp_fused_qkvw(self.module, child.bias.data, self.mp_size, mp_replace.gpu_index),
+                    device_name, return_new_copy)
+            else:
+                if ('shared_experts.down_proj' not in name and 'mlp.down_proj' not in name and 'down_proj' in name \
+                        and ('DeepseekV2' in str(type(self.module)) or 'qwen2_moe' in str(type(self.module))) \
+                        and AutoTP.moe_experts_reduce_once):
+                    data = child.weight.data.split(get_shard_size_list(weight_shape[1], self.mp_size), dim=1)
+                    data_dc = move(data[mp_replace.gpu_index], get_accelerator().current_device_name()).detach()
+                    del data
+                    bias_data_dc = None if child.bias is None else \
+                        torch.nn.parameter.Parameter(move(child.bias, get_accelerator().current_device_name()))
+                else:
+                    data = child.weight.data.split(get_shard_size_list(weight_shape[0], self.mp_size, name),
+                                                   dim=1 if self.conv_linear_layer else 0)
+                    data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
+                    del data
+
+                    if child.bias is not None:
+                        bias_data = child.bias.data.split(get_shard_size_list(
+                            weight_shape[1] if self.conv_linear_layer else weight_shape[0], self.mp_size, name),
+                                                          dim=0)
+                        bias_data = move(bias_data[mp_replace.gpu_index], device_name, return_new_copy)
+                        bias_data_dc = torch.nn.parameter.Parameter(bias_data, requires_grad=False)
+                        del bias_data
+                    else:
+                        bias_data_dc = None
+
+            setattr(child, "replaced", True)
+            return LinearLayer(weight=torch.nn.parameter.Parameter(data_dc, requires_grad=False), bias=bias_data_dc)
 
     def _slice_embedding(self, child, name, conv_linear_layer):
         if getattr(child, "replaced", False) == True:
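Usage note (not part of the diff): a minimal sketch of how the new reduce-once path would be exercised, assuming a DeepseekV2 or qwen2_moe checkpoint loaded through `transformers`. The environment-variable name `DS_MOE_TP_SINGLE_ALLREDUCE` comes from the diff above and is parsed with `ast.literal_eval`, so it must be a Python literal such as `True` or `1`; the model id, dtype, and world-size handling below are illustrative assumptions.

```python
import os

import torch
import deepspeed
from transformers import AutoModelForCausalLM

# Opt in to the single-allreduce (reduce-once) handling of MoE experts before
# AutoTP.tp_parser runs; the value is read with ast.literal_eval, so it must be
# a Python literal ("True", "1", ...).
os.environ["DS_MOE_TP_SINGLE_ALLREDUCE"] = "True"

# Illustrative checkpoint; any model whose type name contains 'DeepseekV2' or
# 'qwen2_moe' hits the new branches in tp_parser/_replace.
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-V2-Lite",
                                             torch_dtype=torch.bfloat16,
                                             trust_remote_code=True)

# AutoTP path: keep kernel injection off so the sharding logic above is applied.
engine = deepspeed.init_inference(model,
                                  tensor_parallel={"tp_size": int(os.getenv("WORLD_SIZE", "1"))},
                                  dtype=torch.bfloat16,
                                  replace_with_kernel_inject=False)
```

Launched with e.g. `deepspeed --num_gpus 2 run.py`, the intent (as reflected in the tp_parser changes above) is that each rank holds a shard of the expert `down_proj` weights, while only the shared/dense `down_proj` layers keep the per-layer all-reduce.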