Enabled high-performance Automatic Tensor Parallelism (auto TP) for the MoE models on multiple GPUs/HPUs #6964

Open: wants to merge 64 commits into base: master

Changes from 1 commit

Commits (64)
c9b12af
Reduced the experts allreduce number per layer to ONCE for the Qwen2-…
gyou2021 Jan 21, 2025
590ea36
Fixed format
gyou2021 Jan 21, 2025
889c275
Removed print
gyou2021 Jan 21, 2025
2ec6c34
Fix a bug about set.
gyou2021 Jan 21, 2025
504d696
Add the missing view operations from sequence parallel(async). (#6750)
inkcherry Jan 21, 2025
c266dc9
Update `torch.norm` to `torch.linalg.norm` and `torch.linalg.vector_n…
loadams Jan 21, 2025
ae12993
Using explicit GPU upcast for ZeRO-Offload (#6962)
xylian86 Jan 21, 2025
deb09a3
Update version.txt after 0.16.3 release (#6965)
loadams Jan 21, 2025
128d436
Precisely track nvme optimizer offload (#6963)
tjruwase Jan 23, 2025
864472b
Update build_win.bat script to exclude GDS op as it lacks Windows supp…
loadams Jan 24, 2025
1ac398c
Add CUDA 12.8 support and comment on CUDA 12.7 (#6975)
loadams Jan 28, 2025
eda53d8
Update torch versions to support 2.6 (#6977)
loadams Jan 29, 2025
112a7c6
generalize deepspeed linear and implement it for non cuda systems (#6…
oelayan7 Jan 29, 2025
7d2c5fe
Update recommended Windows whl building versions (#6983)
loadams Jan 30, 2025
f1d326c
Title: Fix setup_env_ranks to Properly Set Environment Variables Inst…
fabiosanger Jan 30, 2025
46545d7
Specify torchvision in nv-ds-chat workflow (prevents errors with torc…
loadams Jan 30, 2025
af1ba94
Remove assumption that padding only occurs on last rank (#6974)
xylian86 Jan 31, 2025
e235921
Use ds-specific module id to avoid conflicts (#6847)
tjruwase Jan 31, 2025
f5e9796
Update A6000 workflows to use newer docker container - 24.09 vs 24.03…
loadams Jan 31, 2025
07634b9
Allow NVIDIA Blackwell (#6991)
fabiendupont Feb 4, 2025
0e57fa0
Update GH org references (#6998)
tjruwase Feb 5, 2025
e86c0c3
Update CNAME
loadams Feb 5, 2025
0d7f0eb
Update CNAME
loadams Feb 5, 2025
cd8a988
[XPU] max1100 workflow update for docker and softwares (#7003)
Liangliang-Ma Feb 5, 2025
18c712f
autotp training(fix dco) (#7004)
inkcherry Feb 5, 2025
c5bf6f6
import triton files when triton is supported and installed (#6989)
oelayan7 Feb 6, 2025
590de5f
Update A6000 tests transformers version (#7016)
loadams Feb 8, 2025
693c39f
Fix ds-chat CI regression (#7015)
tjruwase Feb 10, 2025
322a05a
[Ulysses tutorial] typos (#7024)
stas00 Feb 11, 2025
8869d78
fix hostname -I for macOS #6497 (#6990)
fitzjalen Feb 12, 2025
e4d03af
Update workflows to cuda 12.4 (#7000)
loadams Feb 12, 2025
8c6251d
[ROCm] Enable fp_quantizer on ROCm (#7027)
rraminen Feb 13, 2025
e3e179c
add gds chinese blog (#7034)
GuanhuaWang Feb 13, 2025
fd2787b
Add chinese blog for deepspeed windows, and fix format (#7035)
hwchen2017 Feb 14, 2025
ba8ef57
AIO on ROCM (#7023)
jomayeri Feb 14, 2025
f4b0f58
Control trace cache warnings (#7039)
tjruwase Feb 18, 2025
3ca3e2f
Update CUDA compute capability to support Blackwell (#7047)
hwchen2017 Feb 18, 2025
5612778
Update setup.py handling of ROCm cupy (#7051)
loadams Feb 19, 2025
af8c190
nv-ds-chat breaks with latest transformers (#7052)
loadams Feb 19, 2025
225471a
Rename aio_thread_count to intra_op_parallelism (#7056)
tjruwase Feb 19, 2025
1df293a
add autoTP training zero2 tests (#7049)
inkcherry Feb 19, 2025
94abf68
Fix, bf16 optimizer remove dup loop (#7054)
wukong1992 Feb 20, 2025
4a4ff9b
Update version.txt after 0.16.4 release (#7063)
loadams Feb 20, 2025
e5eda47
fix an outdated doc wrt CUDA_VISIBLE_DEVICES (#7058)
stas00 Feb 20, 2025
675ec9a
Tecorigin sdaa accelerator (#6903)
siqi654321 Feb 20, 2025
81c1fee
Handle special case of libuv for Windows (#7064)
loadams Feb 20, 2025
17f544c
Update README with info on newest accelerator (#7065)
loadams Feb 21, 2025
20fd872
Bug Fix for offload_states API (#7050)
U-rara Feb 21, 2025
0b289a2
Fix TOCTOU issues, switch to fstat (#7067)
loadams Feb 24, 2025
4a86d02
config torch to avoid graph breaks caused by logger (#6999)
ShellyNR Feb 24, 2025
594b5bb
Fix meta load tensor incompatible issue (#7073)
Yejing-Lai Feb 24, 2025
a843e39
Replace calls to `python setup.py sdist` with `python -m build --sdis…
loadams Feb 24, 2025
4cbc52c
Revert "Handle special case of libuv for Windows (#7064)" (#7076)
loadams Feb 25, 2025
586e436
Add DeepseekV3 AutoTP. (#7045)
Yejing-Lai Feb 26, 2025
5e379ad
Improve inference tutorial docs (#7083)
loadams Feb 26, 2025
13bf866
Added support for the environment variable DS_MOE_EXPERTS_REDUCE_ONCE…
gyou2021 Feb 27, 2025
d5115be
Changed env variable name to 'DS_MOE_TP_SINGLE_ALLREDUCE'
gyou2021 Feb 28, 2025
f0044cb
Pin transformers version on tests that use latest. (#7085)
loadams Feb 27, 2025
16ad5fd
Update README.md with ICS '23 MoE paper link (#7087)
siddharth9820 Feb 27, 2025
47d4420
Update parallelism for nv-torch-latest/nightly tests due to more GPUs…
loadams Feb 27, 2025
b3c64dd
Remove workflows for very old torch versions (#7090)
loadams Feb 28, 2025
9b1fe98
Fixed conflicts
gyou2021 Feb 28, 2025
6b96dd9
Update auto_tp.py
gyou2021 Mar 5, 2025
e7883e7
Merge branch 'master' into autoTP_Qwen2Moe_DeepSeekv2
hwchen2017 Mar 5, 2025
autotp training(fix dco) (#7004)
Same as [this PR](#6922).
[affeb88](affeb88)
I noticed the CI updated the DCO check recently. Using the suggested
rebase method for sign-off would reintroduce many conflicts, so I opted
for a squash merge with sign-off instead. Thanks :)

Signed-off-by: inkcherry <[email protected]>
inkcherry authored and gyou2021 committed Feb 28, 2025
commit 18c712fc88ad0ef1b6440d88f9cafea8e71d9a18
33 changes: 32 additions & 1 deletion deepspeed/__init__.py
@@ -37,7 +37,7 @@
from .runtime.config import DeepSpeedConfig, DeepSpeedConfigError
from .runtime.activation_checkpointing import checkpointing
from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .module_inject import replace_transformer_layer, revert_transformer_layer
from .module_inject import replace_transformer_layer, revert_transformer_layer, set_autotp_mode

from .utils import log_dist, OnDevice, logger
from .comm.comm import init_distributed
@@ -364,3 +364,34 @@ def init_inference(model, config=None, **kwargs):
engine = InferenceEngine(model, config=ds_inference_config)

return engine


def tp_model_init(model, tp_size, dtype):
"""
Initialize the model for tensor parallelism.

Args:
model (torch.nn.Module): The model to be initialized.
tp_size (int): The tensor parallelism size.
dtype (torch.dtype): The data type to be used for the model.

Returns:
torch.nn.Module: The initialized model with tensor parallelism.
"""
# avoid re-entry
assert not hasattr(
model, 'ds_autotp_parsed'), "ds_autotp_parsed' attribute already exists in the model, re-entry is not allowed."

set_autotp_mode(training=True)

from deepspeed.runtime.tensor_parallel import TpTrainingManager
# The expected usage here is for it to be invoked by transformers package.

#TODO: We should provide a custom TP mapping solution without using autoTP
#as modifying the autoTP logic may be more difficult for users compared to configuring it

model = TpTrainingManager(model=model, tp_size=tp_size, dtype=dtype).module

setattr(model, 'ds_autotp_parsed', True)

return model
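
For context, a minimal usage sketch of the new `tp_model_init` entry point. The checkpoint name, `tp_size`, and dtype below are illustrative assumptions, not part of this PR, and a distributed environment is assumed to be initialized across the tensor-parallel ranks.

```python
import torch
import deepspeed
from transformers import AutoModelForCausalLM

# Assumes torch.distributed / deepspeed.init_distributed() is already set up
# across the ranks that will form the tensor-parallel group.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B",  # illustrative checkpoint
                                             torch_dtype=torch.bfloat16)

# Shard the model for tensor-parallel training across 4 ranks.
# A second call on the same model trips the 'ds_autotp_parsed' re-entry assertion.
model = deepspeed.tp_model_init(model, tp_size=4, dtype=torch.bfloat16)
```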
6 changes: 6 additions & 0 deletions deepspeed/comm/comm.py
@@ -224,6 +224,12 @@ def broadcast(tensor, src, group=None, async_op=False, prof=False, log_name='bro
return cdb.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)


@timed_op
def broadcast_object_list(object_list, src, group=None, device=None):
global cdb
return cdb.broadcast_object_list(object_list=object_list, src=src, group=group, device=device)


@timed_op
def all_gather(tensor_list,
tensor,
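
The `broadcast_object_list` wrapper added to `deepspeed/comm/comm.py` above mirrors `torch.distributed.broadcast_object_list`: it broadcasts a list of picklable Python objects from `src` to every rank in the group. A minimal usage sketch, assuming the default process group is already initialized (the payload contents are illustrative):

```python
import deepspeed.comm as dist

# Assumes deepspeed.init_distributed() has already set up the default group.
if dist.get_rank() == 0:
    objs = [{"tp_size": 4, "dtype": "bf16"}]  # picklable payload on the source rank
else:
    objs = [None]                             # placeholder, overwritten in place

dist.broadcast_object_list(objs, src=0)       # every rank now holds the payload in objs[0]
config = objs[0]
```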
4 changes: 4 additions & 0 deletions deepspeed/comm/torch.py
@@ -205,6 +205,10 @@ def broadcast(self, tensor, src, group=None, async_op=False):
else:
return torch.distributed.broadcast(tensor=tensor, src=src, group=group, async_op=async_op)

@disable_compiler_collective
def broadcast_object_list(self, object_list, src, group=None, device=None):
return torch.distributed.broadcast_object_list(object_list=object_list, src=src, group=group, device=device)

@disable_compiler_collective
def all_gather(self, tensor_list, tensor, group=None, async_op=False):
if DS_COMM_ALL_GATHER_OFF:
1 change: 0 additions & 1 deletion deepspeed/inference/engine.py
@@ -15,7 +15,6 @@
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.runtime.compiler import is_compile_supported

from ..runtime.state_dict_factory import SDLoaderFactory
from ..runtime.weight_quantizer import WeightQuantization
from ..module_inject import replace_transformer_layer, generic_injection
2 changes: 1 addition & 1 deletion deepspeed/module_inject/__init__.py
@@ -6,5 +6,5 @@
from .replace_module import replace_transformer_layer, revert_transformer_layer, ReplaceWithTensorSlicing, GroupQuantizer, generic_injection
from .module_quantize import quantize_transformer_layer
from .replace_policy import HFBertLayerPolicy
from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize
from .layers import LinearAllreduce, LinearLayer, EmbeddingLayer, Normalize, set_autotp_mode
from .policy import DSPolicy
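
The newly exported `set_autotp_mode` is the switch that `tp_model_init` flips and that `is_autotp_training_mode()` (used in `auto_tp.py` below) reads back. A rough illustration, assuming the process starts in the default inference mode and that the flag is reversible:

```python
from deepspeed.module_inject import set_autotp_mode
from deepspeed.module_inject.layers import is_autotp_training_mode

# AutoTP defaults to inference behaviour; opt into training mode explicitly.
set_autotp_mode(training=True)
assert is_autotp_training_mode()

# Assumed reversible: switch back to inference-style AutoTP when done.
set_autotp_mode(training=False)
```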
62 changes: 32 additions & 30 deletions deepspeed/module_inject/auto_tp.py
@@ -11,10 +11,12 @@
from typing import Optional
import torch
from deepspeed import comm as dist
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce
from .layers import LinearAllreduce, LinearLayer, LmHeadLinearAllreduce, Yuan_LinearAllreduce, Yuan_LinearLayer, GateUpPack_LinearLayer, Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer
from deepspeed.accelerator import get_accelerator
from .fusedqkv_utils import require_tp_fused_qkvw, prepare_tp_fused_qkvw, shard_value_with_share_qk, shard_chunk_mlp
from .fusedqkv_utils import require_tp_fused_qkvw
from deepspeed.module_inject.tp_shard import get_shard_size, get_shard_size_list
from deepspeed.utils import groups
from deepspeed.module_inject.layers import is_autotp_training_mode


def move(tensor, device, copy=True):
@@ -341,10 +343,18 @@ def tp_parser(model):
return policy_list

def set_tensor_parallel_config(self, mp_size, mp_group):

if is_autotp_training_mode():
self.mp_group = groups.get_tensor_model_parallel_group()
self.mp_size = groups.get_tensor_model_parallel_world_size()
return

self.mp_size = mp_size
self.mp_group = mp_group

def _replace(self, child, name, conv_linear_layer):
# This function should clearly define the routing rules for specific layers
# and avoid any complex shard-related logic.
if getattr(child, "replaced", False) == True:
return
device_name = 'cpu' if self.keep_module_on_host else get_accelerator().current_device_name()
@@ -360,14 +370,15 @@ def _replace(self, child, name, conv_linear_layer):
# For Yuan model
if 'Yuan' in str(self.module):
if 'v_proj' in name:
weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
dist.get_world_size(), True)
return LinearLayer(weight=weight, bias=bias)
return Yuan_LinearLayer(child, self.mp_group)

elif 'o_proj' in name:
weight, bias = shard_value_with_share_qk(child.weight.data, child.bias, dist.get_rank(),
dist.get_world_size(), False)
return LinearAllreduce(weight, bias, self.mp_group)
# For Arctic model, bypass to all_reduce replacement for w2 weights
return Yuan_LinearAllreduce(child, self.mp_group)

# For MLP including chunk layer.
if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
return GateUpPack_LinearLayer(child, self.mp_group)
# For Arctic model, bypass to all_reduce replacement for w2 weights
arctic_w2_all_reduce_linear = False
if 'Arctic' in str(self.module) and 'w2' in name:
arctic_w2_all_reduce_linear = True
@@ -376,30 +387,15 @@ def _replace(self, child, name, conv_linear_layer):
#Deepseek processes different down_proj in different ways.
if 'down_proj' in name and 'DeepseekV2' not in str(type(self.module)):
down_proj = True
# For MLP including chunk layer.
if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in str(self.module)):
weight, bias = shard_chunk_mlp(child.weight.data, child.bias, dist.get_rank(), dist.get_world_size())
return LinearLayer(weight=weight, bias=bias)
if name in self.all_reduce_linears or arctic_w2_all_reduce_linear or down_proj:
# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
# else [weight_shape[0], weight_shape[1] // mp_size]

setattr(child, "replaced", True)
if self.conv_linear_layer:
child.weight.data = child.weight.data.transpose(-1, -2).contiguous()
data = child.weight.data.split(get_shard_size_list(
weight_shape[0] if self.conv_linear_layer else weight_shape[1], self.mp_size, name),
dim=1)
data_dc = move(data[mp_replace.gpu_index], device_name, return_new_copy).detach()
del data
return Conv_LinearALlreduce(child, self.mp_group, name=name)
elif name == "lm_head" or name == 'embed_out':
return LmHeadLinearAllreduce(child, self.mp_group)

setattr(child, "replaced", True)
if name == "lm_head" or name == 'embed_out':
return LmHeadLinearAllreduce(
torch.nn.parameter.Parameter(data_dc, requires_grad=False), dist.get_rank(), dist.get_world_size(),
child.bias if child.bias is None else torch.nn.parameter.Parameter(
move(child.bias, device_name, return_new_copy)), self.mp_group)
return LinearAllreduce(torch.nn.parameter.Parameter(data_dc, requires_grad=False), child.bias if child.bias is None else \
torch.nn.parameter.Parameter(move(child.bias, device_name, return_new_copy)), self.mp_group)
return LinearAllreduce(child, self.mp_group, name=name)
else:

# if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
@@ -442,7 +438,13 @@ def _replace(self, child, name, conv_linear_layer):
bias_data_dc = None

setattr(child, "replaced", True)
return LinearLayer(weight=torch.nn.parameter.Parameter(data_dc, requires_grad=False), bias=bias_data_dc)
if self.conv_linear_layer:
conv_LinearLayer(child, self.mp_group)
elif require_tp_fused_qkvw(name, self.mp_size):
#Check and handle fused qkv for TP
return fused_LinearLayer(child, self.mp_group, fused_module=self.module)

return LinearLayer(child, self.mp_group, name=name)

def _slice_embedding(self, child, name, conv_linear_layer):
if getattr(child, "replaced", False) == True:
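
To make the `auto_tp.py` refactor above easier to follow: `_replace` no longer slices weights inline but routes each child module, by name and parent-model type, to a wrapper layer that owns its own sharding. Below is a condensed, hypothetical sketch of that dispatch order; the helper `route_linear` and its flat parameter list are illustrative only, since in the real method these values live on `self`.

```python
from deepspeed.module_inject.layers import (LinearAllreduce, LinearLayer, LmHeadLinearAllreduce,
                                            Yuan_LinearAllreduce, Yuan_LinearLayer, GateUpPack_LinearLayer,
                                            Conv_LinearALlreduce, fused_LinearLayer, conv_LinearLayer)
from deepspeed.module_inject.fusedqkv_utils import require_tp_fused_qkvw


def route_linear(child, name, module, mp_group, mp_size, conv_linear_layer, all_reduce_linears):
    """Illustrative condensation of the routing rules in AutoTP._replace (not the verbatim code)."""
    module_str = str(module)
    # Yuan v_proj/o_proj go to dedicated Yuan wrappers.
    if 'Yuan' in module_str:
        if 'v_proj' in name:
            return Yuan_LinearLayer(child, mp_group)
        if 'o_proj' in name:
            return Yuan_LinearAllreduce(child, mp_group)
    # Packed gate/up MLP projections (e.g. GLM's dense_h_to_4h) get a chunk-aware layer.
    if 'gate_up_proj' in name or ('dense_h_to_4h' in name and 'GLM' in module_str):
        return GateUpPack_LinearLayer(child, mp_group)
    # Layers whose outputs must be all-reduced: configured names, Arctic w2, and most down_proj.
    needs_allreduce = (name in all_reduce_linears
                       or ('Arctic' in module_str and 'w2' in name)
                       or ('down_proj' in name and 'DeepseekV2' not in str(type(module))))
    if needs_allreduce:
        if conv_linear_layer:
            return Conv_LinearALlreduce(child, mp_group, name=name)
        if name in ('lm_head', 'embed_out'):
            return LmHeadLinearAllreduce(child, mp_group)
        return LinearAllreduce(child, mp_group, name=name)
    # Column-parallel side: conv-style weights, fused QKV, or a plain sharded linear.
    if conv_linear_layer:
        return conv_LinearLayer(child, mp_group)
    if require_tp_fused_qkvw(name, mp_size):
        return fused_LinearLayer(child, mp_group, fused_module=module)
    return LinearLayer(child, mp_group, name=name)
```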