diff --git a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc
index 2f082aad035252..965cdaa358c100 100644
--- a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc
+++ b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc
@@ -438,7 +438,8 @@ static std::vector<std::vector<phi::DataType>> RunInferDtypeFunc(
 paddle::Tensor BuildEmptyDistPaddleTensor(
     const phi::distributed::ProcessMesh& process_mesh,
     const phi::DDim& dims,
-    phi::DataType dtype) {
+    phi::DataType dtype,
+    const std::vector<int64_t>& dims_mapping = {}) {
   paddle::Tensor empty_tensor;
   phi::DenseTensorMeta meta;
   meta.dims = dims;
@@ -446,6 +447,9 @@ paddle::Tensor BuildEmptyDistPaddleTensor(
 
   auto dist_attr = phi::distributed::TensorDistAttr(common::vectorize(dims));
   dist_attr.set_process_mesh(process_mesh);
+  if (!dims_mapping.empty()) {
+    dist_attr.set_dims_mapping(dims_mapping);
+  }
 
   auto dist_t = std::make_shared<phi::distributed::DistTensor>(
       std::make_shared<phi::DenseTensor>(
@@ -699,8 +703,14 @@ std::
     if (out_dim.size() == 1) {
       output_dims.emplace_back(out_dim[0]);
       if (!rank_is_in_current_mesh) {
+        std::vector<int64_t> dims_mapping = {};
+        if (!spmd_info.second.empty()) {
+          dims_mapping = PADDLE_GET_CONST(phi::distributed::TensorDistAttr,
+                                          spmd_info.second[i])
+                             .dims_mapping();
+        }
         *(ctx.MutableOutputAt(pair.first)) = BuildEmptyDistPaddleTensor(
-            current_process_mesh, out_dim[0], out_dtype[0]);
+            current_process_mesh, out_dim[0], out_dtype[0], dims_mapping);
       }
     } else {
       for (size_t j = pair.first; j < pair.second; j++) {
diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
index b534ba504803b2..e6644985176bd3 100644
--- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
+++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
@@ -172,7 +172,7 @@ DistTensor::DistTensor(const std::shared_ptr<DenseTensor>& local_value,
   } else {
     value_ = std::make_shared<DenseTensor>(
         std::make_shared<phi::Allocation>(nullptr, 0, local_value->place()),
-        phi::DenseTensorMeta(local_value->dtype(), global_dims_));
+        phi::DenseTensorMeta(local_value->dtype(), phi::make_ddim({0})));
   }
 }
 
@@ -197,6 +197,11 @@ DistTensor::DistTensor(const std::shared_ptr<DenseTensor>& global_value,
     : global_dims_(global_value->dims()) {
   process_mesh_ = process_mesh;
   placements_ = placements;
+  // If dims.size() == -1, dims defaults to [0], which is inconsistent and
+  // will cause an error in ToTensorDistAttr.
+  if (global_dims_ == DDim()) {
+    global_dims_ = phi::make_ddim({});
+  }
   dist_attr_ = ToTensorDistAttr(process_mesh_, placements_, global_dims_);
 
   // If the current rank doesn't in process_mesh, we should create an
diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py
index b2da271f93efbb..ce0a860fd1a34c 100644
--- a/python/paddle/distributed/auto_parallel/api.py
+++ b/python/paddle/distributed/auto_parallel/api.py
@@ -66,6 +66,7 @@ from paddle.optimizer import Optimizer
 
 from .moe_utils import (
+    _cal_local_shape,
     _dist_reshape,
     _NdMeshAlltoAll,
     _reshard_mesh_shape,
@@ -459,14 +460,24 @@ def moe_global_mesh_tensor(
         )
     process_ids = np.array(mesh.process_ids).reshape(mesh.shape)
     local_coord = np.where(process_ids == dist.get_rank())
-    local_tensor_idx = local_coord[local_mesh_dim][0]
-    # local_tensor_idx = mesh.process_ids.index(dist.get_rank())
+    # When the current rank is not in the mesh, local_coord is empty; fall back
+    # to index 0 and compute the local tensor's shape below.
+    if local_coord[0].size == 0:
+        local_tensor_idx = 0
+    else:
+        local_tensor_idx = local_coord[local_mesh_dim][0]
     local_tensor = local_tensor_list[local_tensor_idx]
 
     if paddle.in_dynamic_mode():
-        global_dims = _cal_global_shape(
-            local_tensor._local_value().shape, mesh, placements
-        )
+        if local_coord[0].size == 0:
+            local_tensor_shape = _cal_local_shape(
+                local_tensor_list[0].shape, local_mesh_list[0], local_placements
+            )
+        else:
+            local_tensor_shape = (
+                local_tensor_list[local_tensor_idx]._local_value().shape
+            )
+        global_dims = _cal_global_shape(local_tensor_shape, mesh, placements)
         resharded_local_tensor_list = []
         for i, tensor in enumerate(local_tensor_list):
             tensor.get_tensor()._unsafe_set_skip_check_mesh(True)
@@ -547,7 +558,9 @@ def forward(
         assert check_placements_equal(
             global_placements, dist_tensor.placements
         ), f"the global_placements ({global_placements}) is not equal to dist_tensor's placements ({dist_tensor.placements})."
-        local_shape = dist_tensor._local_value().shape
+        local_shape = _cal_local_shape(
+            dist_tensor.shape, global_mesh, global_placements
+        )
         for idx, placement in enumerate(local_placements):
             if placement.is_shard():
                 shard_dim = placement.get_dim()
@@ -579,7 +592,10 @@ def backward(ctx, *grad_tensor):
         mesh = ctx.global_mesh
         process_ids = np.array(mesh.process_ids).reshape(mesh.shape)
         local_coord = np.where(process_ids == dist.get_rank())
-        local_tensor_idx = local_coord[ctx.local_mesh_dim][0]
+        if local_coord[0].size == 0:
+            local_tensor_idx = 0
+        else:
+            local_tensor_idx = local_coord[ctx.local_mesh_dim][0]
         local_grad = grad_tensor[local_tensor_idx]
         global_tensor = paddle.Tensor(
             local_grad._local_value(),
diff --git a/python/paddle/distributed/auto_parallel/moe_utils.py b/python/paddle/distributed/auto_parallel/moe_utils.py
index 3a2b0021610a28..b6ac80231b4d0a 100644
--- a/python/paddle/distributed/auto_parallel/moe_utils.py
+++ b/python/paddle/distributed/auto_parallel/moe_utils.py
@@ -71,6 +71,62 @@ def _specific_alltoall_dim(
     return mesh_dim
 
 
+def _dtensor_from_local(
+    local_tensor, mesh, placements, local_tensor_shape=None
+):
+    # Assume each rank has the same tensor shape for now; use the local shape to calculate the global shape.
+    global_dims = list(local_tensor.shape)
+    if local_tensor_shape is not None:
+        global_dims = local_tensor_shape
+    for idx, placement in enumerate(placements):
+        if placement.is_shard():
+            shard_dim = placement.get_dim()
+            local_dim_size = global_dims[shard_dim]
+            global_dims[shard_dim] = local_dim_size * mesh.shape[idx]
+
+    if paddle.in_dynamic_mode():
+        place = paddle.framework._current_expected_place()
+        place = paddle.framework._get_paddle_place(place)
+
+        return paddle.Tensor(
+            local_tensor,
+            dims=global_dims,
+            process_mesh=mesh,
+            placements=placements,
+            place=place,
+        )
+
+    # TODO: Adopt the Mix2Dist pass so that the program can actually be executed.
+    elif paddle.framework.in_pir_mode():
+        assert isinstance(
+            local_tensor, (type(None), paddle.pir.Value)
+        ), "input tensor is not pir value."
+        assert (
+            local_tensor.is_dense_tensor_type()
+        ), "dtensor_from_local() only supports dense tensor type right now."
+        sharding_specs = (
+            paddle.distributed.auto_parallel.placement_type.get_shard_spec(
+                mesh, placements, local_tensor.ndim
+            )
+        )
+        dims_mapping = paddle.distributed.auto_parallel.static.utils.convert_to_dims_mapping(
+            sharding_specs, mesh
+        )
+        local_shape = local_tensor.shape
+        global_tensor_type = paddle.pir.create_shaped_type(
+            local_tensor.type(), global_dims
+        )
+        dist_dense_tensor_type = paddle.base.libpaddle.pir.create_dist_dense_tensor_type_by_dense_tensor(
+            global_tensor_type, local_shape, mesh, dims_mapping
+        )
+        local_tensor.set_type(dist_dense_tensor_type)
+        return local_tensor
+    else:
+        raise RuntimeError(
+            "dtensor_from_local() is only supported in dynamic or pir mode."
+        )
+
+
 class _NdMeshAlltoAll(PyLayer):
     @staticmethod
     def forward(
@@ -87,12 +143,19 @@ def forward(
         ctx.out_mesh = copy.deepcopy(mesh)
         ctx.out_placements = copy.deepcopy(placements)
 
-        out = dist.auto_parallel.api.dtensor_from_local(
-            dist_tensor._local_value(), sub_mesh, [dist_tensor.placements[dim]]
+        local_shape = _cal_local_shape(
+            dist_tensor.shape, mesh, dist_tensor.placements
+        )
+        out = _dtensor_from_local(
+            dist_tensor._local_value(),
+            sub_mesh,
+            [dist_tensor.placements[dim]],
+            local_shape,
         )
         out = dist.reshard(out, sub_mesh, [placements[dim]])
-        out = dist.auto_parallel.api.dtensor_from_local(
-            out._local_value(), mesh, placements
+        local_shape = _cal_local_shape(out.shape, mesh, out.placements)
+        out = _dtensor_from_local(
+            out._local_value(), mesh, placements, local_shape
         )
         out.stop_gradient = dist_tensor.stop_gradient
         return out
@@ -148,7 +211,10 @@ def forward(
         place = paddle.framework._current_expected_place()
        place = paddle.framework._get_paddle_place(place)
 
-        local_tensor = dist_tensor._local_value().clone()
+        if dist_tensor._local_value()._is_initialized():
+            local_tensor = dist_tensor._local_value().clone()
+        else:
+            local_tensor = dist_tensor._local_value()
         ctx.x_global_shape = copy.deepcopy(dist_tensor.shape)
         ctx.x_local_shape = copy.deepcopy(local_tensor.shape)
         ctx.x_mesh = copy.deepcopy(dist_tensor.process_mesh)
@@ -170,8 +236,13 @@ def backward(ctx, out_grad):
         place = paddle.framework._current_expected_place()
         place = paddle.framework._get_paddle_place(place)
 
-        local_grad = out_grad._local_value().clone()
-        local_grad = local_grad.reshape(ctx.x_local_shape)
+        if out_grad._local_value()._is_initialized():
+            local_grad = out_grad._local_value().clone()
+            x_local_shape = ctx.x_local_shape
+        else:
+            local_grad = out_grad._local_value()
+            x_local_shape = [0]
+        local_grad = local_grad.reshape(x_local_shape)
         ret = paddle.Tensor(
             local_grad,
             dims=ctx.x_global_shape,
@@ -195,6 +266,8 @@ def _dist_reshape(
     tgt_global_shape = infer_positive_shape(dist_tensor.shape, global_shape)
     tgt_local_shape = _cal_local_shape(tgt_global_shape, mesh, placements)
     src_local_shape = dist_tensor._local_value().shape
+    if not dist_tensor._local_value()._is_initialized():
+        tgt_local_shape = dist_tensor._local_value().shape
     assert np.prod(tgt_local_shape) == np.prod(
         src_local_shape
     ), f"The local shapes {src_local_shape} and {tgt_local_shape} are mismatched."
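
For reference, the shape rule the new `_dtensor_from_local` helper applies is simple: start from the local shape (or the explicitly passed `local_tensor_shape`) and multiply every sharded axis by the size of the mesh dimension it is sharded over. Below is a minimal standalone sketch of that rule; `infer_global_dims`, the plain-list mesh shape, and the "tensor axis or None" placement encoding are illustrative stand-ins, not Paddle's ProcessMesh/Placement API.

# Sketch only: mirrors the global-dims loop in _dtensor_from_local.
# placements[i] is the tensor axis sharded over mesh dim i, or None if replicated.
def infer_global_dims(local_dims, mesh_shape, placements):
    global_dims = list(local_dims)
    for mesh_dim, shard_axis in enumerate(placements):
        if shard_axis is not None:
            global_dims[shard_axis] *= mesh_shape[mesh_dim]
    return global_dims

# A [4, 2] mesh with the tensor sharded along axis 0 over mesh dim 0:
# a local shape of [2, 8] corresponds to a global shape of [8, 8].
assert infer_global_dims([2, 8], [4, 2], [0, None]) == [8, 8]

The optional `local_tensor_shape` argument exists so the same rule can be applied on ranks whose local tensor is uninitialized, using a shape computed by `_cal_local_shape` instead of the empty local value.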
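
Similarly, the repeated `local_coord[0].size == 0` checks rely on the fact that np.where returns a tuple of empty index arrays when the current rank does not appear in the mesh's process ids, so indexing `local_coord[dim][0]` would raise. A small illustration with made-up mesh and rank values:

import numpy as np

# Hypothetical 2x2 mesh holding ranks 0-3; rank 7 is not part of the mesh.
process_ids = np.array([0, 1, 2, 3]).reshape(2, 2)
local_coord = np.where(process_ids == 7)

# Every element of the tuple is an empty array, so the patched code falls
# back to local_tensor_idx = 0 instead of indexing into it.
assert local_coord[0].size == 0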