[AutoParallel] Add paddle.distributed.reshard python API. (PaddlePaddle#57293)

* Add paddle.distributed.reshard API. It supports resharding a DistTensor.

* Polish code with review comments.

* Fix problem with in_dynamic_mode.

* Fix some problems according to review comments.

* Mark test_reshard_api as a multi-card test case and set its timeout.
GhostScreaming authored Sep 19, 2023
1 parent ad32cca commit 4c7fc29
Showing 6 changed files with 221 additions and 0 deletions.
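In short, the commit wires a new reshard binding through pybind and exposes it as paddle.distributed.reshard. The intended usage, condensed from the docstring added in python/paddle/distributed/auto_parallel/api.py below (it requires a multi-process launch covering the mesh ranks):

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])        # sharded on both dims
out_dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])  # fully replicated

d_tensor = dist.shard_tensor(paddle.to_tensor([[1, 2, 3], [5, 6, 7]]), dist_attr=dist_attr)
out_d_tensor = dist.reshard(d_tensor, out_dist_attr)  # redistribute to the new attributes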
38 changes: 38 additions & 0 deletions paddle/fluid/pybind/auto_parallel_py.cc
@@ -35,12 +35,16 @@

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#include "paddle/phi/core/distributed/auto_parallel/p_to_r_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"
#include "paddle/phi/core/distributed/auto_parallel/s_to_s_reshard_function.h"
#include "paddle/phi/core/enforce.h"

#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/infermeta/spmd_rules/rules.h"
@@ -544,6 +548,40 @@ void BindAutoParallel(py::module *m) {
},
py::return_value_policy::reference);

  m->def(
      "reshard",
      [](py::handle py_tensor, const TensorDistAttr &dist_attr) {
        auto tensor = CastPyArg2Tensor(py_tensor.ptr(), 0);
        auto dev_ctx = phi::DeviceContextPool::Instance().Get(tensor.place());
        std::shared_ptr<phi::distributed::DistTensor> dist_out_ptr = nullptr;
        if (phi::distributed::DistTensor::classof(tensor.impl().get())) {
          auto tensor_in = tensor.impl();
          if (tensor_in) {
            phi::distributed::DistTensor *dist_tensor =
                static_cast<phi::distributed::DistTensor *>(tensor_in.get());
            if (dist_tensor->dist_attr() != dist_attr) {
              VLOG(6) << "reshard func, reshard tensor from "
                      << dist_tensor->dist_attr() << " to " << dist_attr;
              auto *func = phi::distributed::ChooseProperReshardFunction(
                  *dist_tensor, dist_attr);
              dist_out_ptr = func->Eval(dev_ctx, *dist_tensor, dist_attr);
            } else {
              dist_out_ptr =
                  std::static_pointer_cast<phi::distributed::DistTensor>(
                      tensor_in);
            }
          }
          return paddle::Tensor(dist_out_ptr);
        } else {
          PADDLE_THROW(phi::errors::InvalidArgument(
              "The input tensor of the reshard function should be "
              "``phi::distributed::DistTensor``. "
              "However, it's %s",
              typeid(tensor.impl().get()).name()));
        }
      },
      py::return_value_policy::reference);

// TODO(liuzhenhai): DistributedMapper is not used for now, but
// dist_mapper_test need the symbols forch DistributedMapper to be linked,
// remove it latter
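The binding above dispatches to ChooseProperReshardFunction when the requested attributes differ from the tensor's current ones, reuses the existing DistTensor when they match, and rejects non-DistTensor inputs. A small sketch of how it is reached from Python, based on the wrapper added in api.py below (mesh ranks are illustrative, and the commented-out error case is illustrative too):

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
attr = dist.DistAttr(mesh=mesh, sharding_specs=[None])

d_tensor = dist.shard_tensor(paddle.ones([8]), dist_attr=attr)

# Low-level entry point used by dist.reshard in dynamic mode.
same = paddle.base.core.reshard(d_tensor, attr)  # attrs match: the underlying
                                                 # DistTensor is reused, no communication

# Passing a plain dense tensor would hit the PADDLE_THROW branch and raise an
# InvalidArgument error (how it surfaces in Python depends on Paddle's error
# translation, so treat this as illustrative):
# paddle.base.core.reshard(paddle.ones([8]), attr)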
2 changes: 2 additions & 0 deletions python/paddle/distributed/__init__.py
@@ -67,6 +67,7 @@
from .auto_parallel import shard_op # noqa: F401
from .auto_parallel.api import shard_tensor # noqa: F401
from .auto_parallel.api import dtensor_from_fn # noqa: F401
from .auto_parallel.api import reshard # noqa: F401

from .fleet import BoxPSDataset # noqa: F401

@@ -128,4 +129,5 @@
"DistAttr",
"shard_tensor",
"dtensor_from_fn",
"reshard",
]
45 changes: 45 additions & 0 deletions python/paddle/distributed/auto_parallel/api.py
@@ -178,3 +178,48 @@ def dtensor_from_fn(fn, dist_attr, *args, **kwargs):
"""
tensor = fn(*args, **kwargs)
return shard_tensor(tensor, dist_attr=dist_attr)


def reshard(dist_tensor, dist_attr):
    """
    Reshard a distributed ``paddle.Tensor`` with the given distributed attributes.

    Args:
        dist_tensor(Tensor): the distributed tensor to be resharded.
        dist_attr(paddle.distributed.DistAttr): specifies how the tensor is distributed or sliced on its ProcessMesh.

    Returns:
        Tensor: a distributed tensor resharded according to the given distributed attributes.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.distributed as dist

            mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
            dist_attr = dist.DistAttr(mesh=mesh, sharding_specs=['x', 'y'])
            out_mesh = dist.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
            out_dist_attr = dist.DistAttr(mesh=out_mesh, sharding_specs=[None, None])

            # dense tensor
            a = paddle.to_tensor([[1, 2, 3],
                                  [5, 6, 7]])
            # distributed tensor
            d_tensor = dist.shard_tensor(a, dist_attr=dist_attr)

            out_d_tensor = dist.reshard(d_tensor, out_dist_attr)

            print(d_tensor)
            print(out_d_tensor)
    """

    if paddle.framework.in_dynamic_mode():
        return paddle.base.core.reshard(dist_tensor, dist_attr)
    else:
        # TODO(GhostScreaming): Support static-graph DistTensor later.
        raise RuntimeError(
            "paddle.distributed.reshard only supports dynamic graph mode now. "
            "Static graph mode will be supported later."
        )
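As a complement to the 2-D example in the docstring, a minimal 1-D sketch of resharding from row-sharded to fully replicated. It assumes a two-process launch (the same two-rank mesh the tests below use) and relies on the _local_value() helper that the new test file also uses:

import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
sharded = dist.DistAttr(mesh=mesh, sharding_specs=["x", None])     # rows split across ranks
replicated = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])

d_tensor = dist.shard_tensor(paddle.ones([10, 20]), dist_attr=sharded)
out = dist.reshard(d_tensor, replicated)  # s_to_r: gather the row shards on every rank

print(out.shape)                 # global shape is unchanged: [10, 20]
print(out._local_value().shape)  # each rank now holds a full replica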
4 changes: 4 additions & 0 deletions test/auto_parallel/CMakeLists.txt
@@ -85,6 +85,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
  py_test_modules(test_pass_quantization MODULES test_pass_quantization)
  set_tests_properties(test_pass_quantization
                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 60)
  py_test_modules(test_reshard_api MODULES test_reshard_api)
  set_tests_properties(test_reshard_api PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
                                                   TIMEOUT 150)
  py_test_modules(test_reshard_s_to_r MODULES test_reshard_s_to_r)
  set_tests_properties(test_reshard_s_to_r
                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
@@ -186,6 +189,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
  py_test_modules(test_engine_save_load MODULES test_engine_save_load)
  py_test_modules(test_rule_based_tuner MODULES test_rule_based_tuner)
  py_test_modules(test_dist_tensor MODULES test_dist_tensor)
  py_test_modules(test_api_dist_branch MODULES test_api_dist_branch)
  py_test_modules(test_shard_tensor_api MODULES test_shard_tensor_api)
  py_test_modules(test_cost_interface MODULES test_cost_interface)
  # End of unittests WITH single card WITHOUT timeout
87 changes: 87 additions & 0 deletions test/auto_parallel/reshard_api.py
@@ -0,0 +1,87 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np

import paddle
import paddle.distributed as dist


class TestReshardAPI:
    def __init__(self):
        self._shape = eval(os.getenv("shape"))
        self._dtype = os.getenv("dtype")
        self._seeds = eval(os.getenv("seeds"))
        self._backend = os.getenv("backend")
        self._shard = eval(os.getenv("shard"))
        self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

    def run_test_cases(self):
        if self._backend == "cpu":
            paddle.set_device("cpu")
        self.test_case_p_to_r()

    def test_case_p_to_r(self):
        a = paddle.ones(self._shape)
        in_shard_specs = [None for i in range(len(self._shape))]
        out_shard_specs = [None for i in range(len(self._shape))]
        dist_attr = dist.DistAttr(
            mesh=self._mesh, sharding_specs=in_shard_specs
        )
        dist_attr._set_partial_dims([0])
        out_dist_attr = dist.DistAttr(
            mesh=self._mesh, sharding_specs=out_shard_specs
        )

        input_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
        output_tensor = dist.reshard(input_tensor, dist_attr=out_dist_attr)

        input_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
        assert np.equal(output_tensor.shape, input_tensor.shape).all()
        np.testing.assert_equal(output_tensor._local_value().numpy(), a.numpy())

    def test_case_r_to_s(self):
        a = paddle.ones(self._shape)
        in_shard_specs = [None for i in range(len(self._shape))]
        out_shard_specs = [None for i in range(len(self._shape))]
        out_shard_specs[self._shard] = "x"
        dist_attr = dist.DistAttr(
            mesh=self._mesh, sharding_specs=in_shard_specs
        )
        out_dist_attr = dist.DistAttr(
            mesh=self._mesh, sharding_specs=out_shard_specs
        )

        input_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
        output_tensor = dist.reshard(input_tensor, dist_attr=out_dist_attr)

        out_shape = list(self._shape)
        if out_shape[self._shard] % 2 == 0:
            out_shape[self._shard] = out_shape[self._shard] // 2
            np.testing.assert_equal(output_tensor.numpy(), input_tensor.numpy())
        else:
            out_shape[self._shard] = (
                out_shape[self._shard] // 2
                if dist.get_rank() == 1
                else out_shape[self._shard] // 2 + 1
            )

        assert np.equal(output_tensor.shape, input_tensor.shape).all()
        assert np.equal(output_tensor._local_shape, out_shape).all()


if __name__ == '__main__':
    TestReshardAPI().run_test_cases()
45 changes: 45 additions & 0 deletions test/auto_parallel/test_reshard_api.py
@@ -0,0 +1,45 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import collective.test_communication_api_base as test_base


class TestReshardAPI(test_base.CommunicationTestDistBase):
    def setUp(self):
        super().setUp(num_of_devices=2, timeout=120)
        self._default_envs = {
            "shape": "(10, 20)",
            "dtype": "float32",
            "seeds": str(self._seeds),
            "shard": "0",
        }
        self._changeable_envs = {
            "backend": ["cpu", "gpu"],
        }

    def test_reshard_api(self):
        envs_list = test_base.gen_product_envs_list(
            self._default_envs, self._changeable_envs
        )
        for envs in envs_list:
            self.run_test_case(
                "reshard_api.py",
                user_defined_envs=envs,
            )


if __name__ == "__main__":
    unittest.main()
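For clarity, a sketch of the environment combinations the driver above produces, assuming gen_product_envs_list takes the product of the fixed and changeable settings (the seed string is whatever CommunicationTestDistBase supplies, shown as a placeholder):

# Illustrative result of test_base.gen_product_envs_list(...) for this test:
envs_list = [
    {"shape": "(10, 20)", "dtype": "float32", "seeds": "<seed>", "shard": "0", "backend": "cpu"},
    {"shape": "(10, 20)", "dtype": "float32", "seeds": "<seed>", "shard": "0", "backend": "gpu"},
]
# Each entry is passed to run_test_case, which launches reshard_api.py on the
# two devices configured in setUp.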
