[auto-parallel] support pipeline parallel for moe (PaddlePaddle#69296)
* [auto-parallel] support pipeline parallel for moe-all2all
zhangting2020 authored Nov 18, 2024
1 parent e9b1291 commit e06da0a
Showing 4 changed files with 121 additions and 17 deletions.
14 changes: 12 additions & 2 deletions paddle/fluid/eager/custom_operator/custom_operator_utils.cc
@@ -438,14 +438,18 @@ static std::vector<std::vector<phi::DataType>> RunInferDtypeFunc(
 paddle::Tensor BuildEmptyDistPaddleTensor(
     const phi::distributed::ProcessMesh& process_mesh,
     const phi::DDim& dims,
-    phi::DataType dtype) {
+    phi::DataType dtype,
+    const std::vector<int64_t>& dims_mapping = {}) {
   paddle::Tensor empty_tensor;
   phi::DenseTensorMeta meta;
   meta.dims = dims;
   meta.dtype = dtype;
 
   auto dist_attr = phi::distributed::TensorDistAttr(common::vectorize(dims));
   dist_attr.set_process_mesh(process_mesh);
+  if (!dims_mapping.empty()) {
+    dist_attr.set_dims_mapping(dims_mapping);
+  }
 
   auto dist_t = std::make_shared<phi::distributed::DistTensor>(
       std::make_shared<phi::DenseTensor>(
@@ -699,8 +703,14 @@ std::
     if (out_dim.size() == 1) {
       output_dims.emplace_back(out_dim[0]);
       if (!rank_is_in_current_mesh) {
+        std::vector<int64_t> dims_mapping = {};
+        if (!spmd_info.second.empty()) {
+          dims_mapping = PADDLE_GET_CONST(phi::distributed::TensorDistAttr,
+                                          spmd_info.second[i])
+                             .dims_mapping();
+        }
         *(ctx.MutableOutputAt(pair.first)) = BuildEmptyDistPaddleTensor(
-            current_process_mesh, out_dim[0], out_dtype[0]);
+            current_process_mesh, out_dim[0], out_dtype[0], dims_mapping);
       }
     } else {
       for (size_t j = pair.first; j < pair.second; j++) {
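Aside: the dims_mapping threaded through above follows TensorDistAttr's convention of one entry per tensor dimension, namely the mesh axis that dimension is sharded along, or -1 if it is replicated. A standalone sketch of that convention; the helper name and the placement encoding below are illustrative assumptions, not Paddle APIs.

# Standalone sketch of the dims_mapping convention: for each tensor dimension,
# record the mesh axis it is sharded along, or -1 if it is replicated.
def to_dims_mapping(ndim, shard_dims_per_mesh_axis):
    dims_mapping = [-1] * ndim
    for mesh_axis, shard_dim in enumerate(shard_dims_per_mesh_axis):
        if shard_dim is not None:  # Shard(shard_dim) along this mesh axis
            dims_mapping[shard_dim] = mesh_axis
    return dims_mapping

# A 2-D tensor sharded on dim 0 along mesh axis 1 -> dims_mapping [1, -1]
print(to_dims_mapping(2, [None, 0]))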
7 changes: 6 additions & 1 deletion paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
@@ -172,7 +172,7 @@ DistTensor::DistTensor(const std::shared_ptr<phi::DenseTensor>& local_value,
   } else {
     value_ = std::make_shared<DenseTensor>(
         std::make_shared<phi::Allocation>(nullptr, 0, local_value->place()),
-        phi::DenseTensorMeta(local_value->dtype(), global_dims_));
+        phi::DenseTensorMeta(local_value->dtype(), phi::make_ddim({0})));
   }
 }

@@ -197,6 +197,11 @@ DistTensor::DistTensor(const std::shared_ptr<phi::DenseTensor>& global_value,
     : global_dims_(global_value->dims()) {
   process_mesh_ = process_mesh;
   placements_ = placements;
+  // If dims.size() == -1, dims defaults to [0], which is inconsistent
+  // and will cause an error in ToTensorDistAttr.
+  if (global_dims_ == DDim()) {
+    global_dims_ = phi::make_ddim({});
+  }
   dist_attr_ = ToTensorDistAttr(process_mesh_, placements_, global_dims_);
 
   // If the current rank doesn't in process_mesh, we should create an
30 changes: 23 additions & 7 deletions python/paddle/distributed/auto_parallel/api.py
@@ -66,6 +66,7 @@
 from paddle.optimizer import Optimizer
 
 from .moe_utils import (
+    _cal_local_shape,
     _dist_reshape,
     _NdMeshAlltoAll,
     _reshard_mesh_shape,
@@ -459,14 +460,24 @@ def moe_global_mesh_tensor(
     )
     process_ids = np.array(mesh.process_ids).reshape(mesh.shape)
     local_coord = np.where(process_ids == dist.get_rank())
-    local_tensor_idx = local_coord[local_mesh_dim][0]
-    # local_tensor_idx = mesh.process_ids.index(dist.get_rank())
+    # When the rank is not in the current mesh, local_coord is empty, so we
+    # should calculate the local tensor's shape instead.
+    if local_coord[0].size == 0:
+        local_tensor_idx = 0
+    else:
+        local_tensor_idx = local_coord[local_mesh_dim][0]
     local_tensor = local_tensor_list[local_tensor_idx]
 
     if paddle.in_dynamic_mode():
-        global_dims = _cal_global_shape(
-            local_tensor._local_value().shape, mesh, placements
-        )
+        if local_coord[0].size == 0:
+            local_tensor_shape = _cal_local_shape(
+                local_tensor_list[0].shape, local_mesh_list[0], local_placements
+            )
+        else:
+            local_tensor_shape = (
+                local_tensor_list[local_tensor_idx]._local_value().shape
+            )
+        global_dims = _cal_global_shape(local_tensor_shape, mesh, placements)
         resharded_local_tensor_list = []
         for i, tensor in enumerate(local_tensor_list):
             tensor.get_tensor()._unsafe_set_skip_check_mesh(True)
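Aside: a minimal NumPy sketch of why the empty-coordinate fallback above is needed. The 2x2 mesh, the rank values, and local_mesh_dim below are illustrative assumptions, not values from this PR.

import numpy as np

# Hypothetical 2x2 mesh holding ranks 0-3; rank 7 stands in for a rank that is
# not part of this (sub-)mesh, e.g. a rank on another pipeline stage.
process_ids = np.array([0, 1, 2, 3]).reshape(2, 2)
local_mesh_dim = 0  # assumed mesh axis of the local (expert) mesh

for rank in (2, 7):
    local_coord = np.where(process_ids == rank)
    if local_coord[0].size == 0:
        # Rank outside the mesh: np.where returns empty arrays, so indexing
        # local_coord[local_mesh_dim][0] would raise IndexError; fall back to 0
        # and derive the local shape from the mesh and placements instead.
        local_tensor_idx = 0
    else:
        local_tensor_idx = local_coord[local_mesh_dim][0]
    print(rank, local_coord, local_tensor_idx)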
@@ -547,7 +558,9 @@ def forward(
         assert check_placements_equal(
             global_placements, dist_tensor.placements
         ), f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})."
-        local_shape = dist_tensor._local_value().shape
+        local_shape = _cal_local_shape(
+            dist_tensor.shape, global_mesh, global_placements
+        )
         for idx, placement in enumerate(local_placements):
             if placement.is_shard():
                 shard_dim = placement.get_dim()
@@ -579,7 +592,10 @@ def backward(ctx, *grad_tensor):
         mesh = ctx.global_mesh
         process_ids = np.array(mesh.process_ids).reshape(mesh.shape)
         local_coord = np.where(process_ids == dist.get_rank())
-        local_tensor_idx = local_coord[ctx.local_mesh_dim][0]
+        if local_coord[0].size == 0:
+            local_tensor_idx = 0
+        else:
+            local_tensor_idx = local_coord[ctx.local_mesh_dim][0]
         local_grad = grad_tensor[local_tensor_idx]
         global_tensor = paddle.Tensor(
             local_grad._local_value(),
87 changes: 80 additions & 7 deletions python/paddle/distributed/auto_parallel/moe_utils.py
@@ -71,6 +71,62 @@ def _specific_alltoall_dim(
     return mesh_dim
 
 
+def _dtensor_from_local(
+    local_tensor, mesh, placements, local_tensor_shape=None
+):
+    # Assume each rank has the same tensor shape for now; just use the local shape to calculate the global shape.
+    global_dims = list(local_tensor.shape)
+    if local_tensor_shape is not None:
+        global_dims = local_tensor_shape
+    for idx, placement in enumerate(placements):
+        if placement.is_shard():
+            shard_dim = placement.get_dim()
+            local_dim_size = global_dims[shard_dim]
+            global_dims[shard_dim] = local_dim_size * mesh.shape[idx]
+
+    if paddle.in_dynamic_mode():
+        place = paddle.framework._current_expected_place()
+        place = paddle.framework._get_paddle_place(place)
+
+        return paddle.Tensor(
+            local_tensor,
+            dims=global_dims,
+            process_mesh=mesh,
+            placements=placements,
+            place=place,
+        )
+
+    # TODO: Adopt the Mix2Dist pass so that the program can actually be executed.
+    elif paddle.framework.in_pir_mode():
+        assert isinstance(
+            local_tensor, (type(None), paddle.pir.Value)
+        ), "input tensor is not a pir value."
+        assert (
+            local_tensor.is_dense_tensor_type()
+        ), "dtensor_from_local() only supports dense tensor type right now."
+        sharding_specs = (
+            paddle.distributed.auto_parallel.placement_type.get_shard_spec(
+                mesh, placements, local_tensor.ndim
+            )
+        )
+        dims_mapping = paddle.distributed.auto_parallel.static.utils.convert_to_dims_mapping(
+            sharding_specs, mesh
+        )
+        local_shape = local_tensor.shape
+        global_tensor_type = paddle.pir.create_shaped_type(
+            local_tensor.type(), global_dims
+        )
+        dist_dense_tensor_type = paddle.base.libpaddle.pir.create_dist_dense_tensor_type_by_dense_tensor(
+            global_tensor_type, local_shape, mesh, dims_mapping
+        )
+        local_tensor.set_type(dist_dense_tensor_type)
+        return local_tensor
+    else:
+        raise RuntimeError(
+            "dtensor_from_local() is only supported in dynamic or pir mode."
+        )
+
+
 class _NdMeshAlltoAll(PyLayer):
     @staticmethod
     def forward(
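Aside: the global shape computed at the top of _dtensor_from_local above multiplies every sharded dimension by the size of the mesh axis it is sharded along. A standalone sketch of that rule; the helper name, placement encoding, and shapes below are illustrative assumptions, not Paddle APIs.

# Standalone sketch of the global-dims loop in _dtensor_from_local. Placements
# are reduced to "which tensor dim (if any) is sharded along each mesh axis".
def global_dims_from_local(local_dims, mesh_shape, shard_dims_per_mesh_axis):
    global_dims = list(local_dims)
    for mesh_axis, shard_dim in enumerate(shard_dims_per_mesh_axis):
        if shard_dim is not None:  # Shard(shard_dim) along this mesh axis
            global_dims[shard_dim] *= mesh_shape[mesh_axis]
    return global_dims

# A [4, 16] local shard on a [2, 4] mesh, sharded on dim 0 along mesh axis 0
# and replicated along mesh axis 1 -> global shape [8, 16].
print(global_dims_from_local([4, 16], [2, 4], [0, None]))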
@@ -87,12 +143,19 @@ def forward(
         ctx.out_mesh = copy.deepcopy(mesh)
         ctx.out_placements = copy.deepcopy(placements)
 
-        out = dist.auto_parallel.api.dtensor_from_local(
-            dist_tensor._local_value(), sub_mesh, [dist_tensor.placements[dim]]
+        local_shape = _cal_local_shape(
+            dist_tensor.shape, mesh, dist_tensor.placements
         )
+        out = _dtensor_from_local(
+            dist_tensor._local_value(),
+            sub_mesh,
+            [dist_tensor.placements[dim]],
+            local_shape,
+        )
         out = dist.reshard(out, sub_mesh, [placements[dim]])
-        out = dist.auto_parallel.api.dtensor_from_local(
-            out._local_value(), mesh, placements
+        local_shape = _cal_local_shape(out.shape, mesh, out.placements)
+        out = _dtensor_from_local(
+            out._local_value(), mesh, placements, local_shape
         )
         out.stop_gradient = dist_tensor.stop_gradient
         return out
@@ -148,7 +211,10 @@ def forward(
         place = paddle.framework._current_expected_place()
         place = paddle.framework._get_paddle_place(place)
 
-        local_tensor = dist_tensor._local_value().clone()
+        if dist_tensor._local_value()._is_initialized():
+            local_tensor = dist_tensor._local_value().clone()
+        else:
+            local_tensor = dist_tensor._local_value()
         ctx.x_global_shape = copy.deepcopy(dist_tensor.shape)
         ctx.x_local_shape = copy.deepcopy(local_tensor.shape)
         ctx.x_mesh = copy.deepcopy(dist_tensor.process_mesh)
@@ -170,8 +236,13 @@ def backward(ctx, out_grad):
         place = paddle.framework._current_expected_place()
         place = paddle.framework._get_paddle_place(place)
 
-        local_grad = out_grad._local_value().clone()
-        local_grad = local_grad.reshape(ctx.x_local_shape)
+        if out_grad._local_value()._is_initialized():
+            local_grad = out_grad._local_value().clone()
+            x_local_shape = ctx.x_local_shape
+        else:
+            local_grad = out_grad._local_value()
+            x_local_shape = [0]
+        local_grad = local_grad.reshape(x_local_shape)
         ret = paddle.Tensor(
             local_grad,
             dims=ctx.x_global_shape,
@@ -195,6 +266,8 @@ def _dist_reshape(
     tgt_global_shape = infer_positive_shape(dist_tensor.shape, global_shape)
     tgt_local_shape = _cal_local_shape(tgt_global_shape, mesh, placements)
     src_local_shape = dist_tensor._local_value().shape
+    if not dist_tensor._local_value()._is_initialized():
+        tgt_local_shape = dist_tensor._local_value().shape
     assert np.prod(tgt_local_shape) == np.prod(
         src_local_shape
     ), f"The local shapes {src_local_shape} and {tgt_local_shape} are mismatched."
