[XPU] Update Sharding stage2 for XPU (PaddlePaddle#48369)
* support xpu scalar inplace

* sharding for xpu

* update

* update

Co-authored-by: heyanru <[email protected]>
sljlp and heyanru01 authored Nov 25, 2022
1 parent 776aef7 commit db749ee
Showing 3 changed files with 36 additions and 7 deletions.
@@ -26,7 +26,7 @@
 import paddle
 from paddle.fluid import core
-from .group_sharded_utils import Type, device_guard
+from .group_sharded_utils import Type, device_guard, cvt_to_device


 class InternalStorage:
@@ -76,8 +76,8 @@ def to(self, device, dtype=None, keep_alignment=True):

         if self._device != device:
             tmp_buffer = (
-                self.buffer.cuda(self.dev_id)
-                if device == "gpu"
+                cvt_to_device(self.buffer, self.dev_id)
+                if device in ["gpu", "xpu", "npu"]
                 else self.buffer.cpu()
             )
             for param in self._params:
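For context (not part of the diff): the same branch written as a standalone sketch, assuming the cvt_to_device helper added at the end of this commit is in scope; buffer, device, and dev_id are placeholders.

def move_to(buffer, device, dev_id):
    # The old code special-cased "gpu" and called buffer.cuda(dev_id).
    # Routing every accelerator through cvt_to_device lets gpu, xpu and
    # npu builds share one code path.
    if device in ["gpu", "xpu", "npu"]:
        return cvt_to_device(buffer, dev_id)
    # Any other target falls back to a host copy, as before.
    return buffer.cpu()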
@@ -133,7 +133,7 @@ def add_rank_params(self, trainable_params, param2align, convert_gpu=True):

         if convert_gpu:
             # buffer convert from cpu to cuda
-            self.buffer = self.buffer.cuda(self.dev_id)
+            self.buffer = cvt_to_device(self.buffer, self.dev_id)

         self._fill = 0
@@ -130,7 +130,7 @@ def _dygraph_clip(self, params_grads):
        if paddle.device.get_device() == "cpu":
            global_norm_var = global_norm_var.cuda(dev_id)

-        with device_guard(dev_id, "gpu"):
+        with device_guard(dev_id, self._device.split(":")[0]):
            paddle.distributed.all_reduce(global_norm_var, group=self._group)

        global_norm_var = paddle.sqrt(global_norm_var)
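For context (not part of the diff): self._device holds a string such as "xpu:0", so taking the prefix before ":" recovers the bare device type that device_guard now accepts instead of the hard-coded "gpu". A minimal illustration with a hard-coded string:

device_str = "xpu:0"                    # e.g. what paddle.device.get_device() returns on XPU
device_type = device_str.split(":")[0]  # -> "xpu"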
@@ -170,8 +170,8 @@ def device_guard(dev_id=0, device="cpu"):
     origin_device = paddle.device.get_device()
     if device == "cpu":
         paddle.set_device(device)
-    elif device == "gpu":
-        paddle.set_device("gpu:{}".format(dev_id))
+    elif device in ["gpu", "xpu", "npu"]:
+        paddle.set_device("{}:{}".format(device, dev_id))
     try:
         yield
     finally:
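For context (not part of the diff): a hedged sketch of how the widened device_guard can be used. It assumes an XPU build of Paddle, that device_guard from the hunk above is in scope, and uses dev_id 0 as a placeholder.

import paddle

device = paddle.device.get_device().split(":")[0]  # e.g. "xpu" on an XPU build
dev_id = 0                                          # placeholder rank-local device id

with device_guard(dev_id, device):
    # Inside the block Paddle's current device is "xpu:0" (or "gpu:0" / "npu:0");
    # the original device is restored when the block exits.
    x = paddle.ones([2, 2])
print(paddle.device.get_device())                   # back to the original device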
@@ -251,3 +251,20 @@ def unscale_method(self, optimizer):

     scaler._unscale = MethodType(unscale_method, scaler)
     return scaler
+
+
+def cvt_to_device(x, dev_id, blocking=True):
+    """
+    Copy data in x from cpu memory to supported device
+    """
+    if paddle.is_compiled_with_cuda():
+        place = paddle.CUDAPlace(dev_id)
+    elif paddle.is_compiled_with_npu():
+        place = paddle.NPUPlace(dev_id)
+    elif paddle.is_compiled_with_xpu():
+        place = paddle.XPUPlace(dev_id)
+    else:
+        raise EnvironmentError(
+            "Only supported compiled paddle with gpu/rocm, npu and xpu, but current version is compiled with cpu."
+        )
+    return x._copy_to(place, blocking)
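For context (not part of the diff): a hedged sketch of the new helper in use. It assumes cvt_to_device is in scope and that Paddle was compiled with CUDA/ROCm, XPU, or NPU support; on a CPU-only build the helper raises EnvironmentError instead.

import paddle

paddle.set_device("cpu")
buffer = paddle.zeros([1024], dtype="float32")   # starts in host memory

# Move it to device 0 of whatever accelerator this Paddle build supports.
# cvt_to_device replaces the old buffer.cuda(dev_id) call, so the same code
# works for CUDA/ROCm, XPU and NPU builds.
buffer = cvt_to_device(buffer, dev_id=0)
print(buffer.place)                               # e.g. XPUPlace(0) on an XPU build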
12 changes: 12 additions & 0 deletions python/paddle/distributed/sharding/group_sharded.py
@@ -117,6 +117,12 @@ def group_sharded_parallel(
             optimizer.step()
             optimizer.clear_grad()
     """
+
+    device = paddle.get_device().split(":")[0]
+    assert device in [
+        "gpu",
+        "xpu",
+    ], "group_sharded_parallel only support gpu and xpu now"
     # check optition type
     assert isinstance(
         model, paddle.nn.Layer
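For context (not part of the diff): a hedged sketch of what now passes the new check, calling group_sharded_parallel on an XPU device. The tiny model and optimizer are placeholders, and a real run also needs the usual distributed launch setup.

import paddle
from paddle.distributed import init_parallel_env
from paddle.distributed.sharding import group_sharded_parallel

paddle.set_device("xpu")            # "gpu" also passes; anything else trips the assert
init_parallel_env()

model = paddle.nn.Linear(16, 16)    # placeholder model
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-3, parameters=model.parameters()
)

# level="os_g" is stage 2: optimizer state plus gradients are sharded.
model, optimizer, scaler = group_sharded_parallel(model, optimizer, level="os_g")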
@@ -148,6 +154,7 @@ def check_dtype(param):
                 group=group,
                 offload=offload,
                 dp_group=dp_group,
+                device=device,
             )
             model = GroupShardedStage2(
                 model,
@@ -156,20 +163,23 @@ def check_dtype(param):
                 sync_buffers=sync_buffers,
                 buffer_max_size=buffer_max_size,
                 dp_group=dp_group,
+                device=device,
             )
         else:
             optimizer = ShardingOptimizerStage2(
                 params=model.parameters(),
                 optim=optimizer,
                 group=group,
                 offload=offload,
+                device=device,
             )
             model = ShardingStage2(
                 model,
                 optimizer,
                 group=group,
                 sync_buffers=sync_buffers,
                 buffer_max_size=buffer_max_size,
+                device=device,
             )
     elif level == 'p_g_os':
         if in_dygraph_mode():
@@ -181,6 +191,7 @@ def check_dtype(param):
                 segment_size=segment_size,
                 offload=offload,
                 sync_comm=sync_comm,
+                device=device,
             )
         else:
             model = ShardingStage3(
@@ -191,6 +202,7 @@ def check_dtype(param):
                 segment_size=segment_size,
                 offload=offload,
                 sync_comm=sync_comm,
+                device=device,
             )
     else:
         raise ValueError("Please enter the correct level.")
