forked from PaddlePaddle/Paddle
add sharding strategy in fleet (PaddlePaddle#27900)
* add sharding
1 parent: 4877bd5, commit: 81244fb
Showing 20 changed files with 1,648 additions and 14 deletions.
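For context, the new strategy would typically be switched on through Fleet's DistributedStrategy before wrapping the optimizer. A minimal usage sketch, assuming the meta-optimizer is exposed as DistributedStrategy.sharding; the exact configuration keys added by this commit are not visible in the diff and are omitted:

    # Hedged sketch: strategy.sharding is assumed to be the enabling flag.
    import paddle
    import paddle.distributed.fleet as fleet

    fleet.init(is_collective=True)

    strategy = fleet.DistributedStrategy()
    strategy.sharding = True  # assumed flag for the sharding meta-optimizer

    optimizer = paddle.optimizer.Momentum(learning_rate=0.01, momentum=0.9)
    # Fleet applies its meta-optimizer passes when minimize() runs.
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)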
13 changes: 13 additions & 0 deletions
python/paddle/distributed/fleet/meta_optimizers/sharding/__init__.py

@@ -0,0 +1,13 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
154 changes: 154 additions & 0 deletions
python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py

@@ -0,0 +1,154 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op, OP_ROLE_KEY, OpRole
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *

from paddle.fluid import core

class FP16Utils(object):
    def __init__(self):
        pass

    @staticmethod
    def is_fp16_cast_op(block, op, params):
        # A forward-pass cast of a parameter from FP32 to FP16.
        if op.type != "cast":
            return False
        if is_optimizer_op(op):
            return False
        assert (len(op.desc.input_arg_names()) == 1)
        assert (len(op.desc.output_arg_names()) == 1)
        input_name, output_name = op.desc.input_arg_names()[
            0], op.desc.output_arg_names()[0]
        if input_name not in params:
            return False
        input_var = block.var(input_name)
        output_var = block.var(output_name)
        if input_var.dtype != core.VarDesc.VarType.FP32 or \
                output_var.dtype != core.VarDesc.VarType.FP16:
            return False
        return True

    @staticmethod
    def is_fp32_cast_op(block, op):
        # An optimizer-stage cast of a gradient from FP16 back to FP32.
        if op.type != "cast":
            return False
        if not is_optimizer_op(op):
            return False
        assert (len(op.desc.input_arg_names()) == 1)
        assert (len(op.desc.output_arg_names()) == 1)
        input_name, output_name = op.desc.input_arg_names()[
            0], op.desc.output_arg_names()[0]
        input_var = block.var(input_name)
        output_var = block.var(output_name)
        if input_var.dtype != core.VarDesc.VarType.FP16 or \
                output_var.dtype != core.VarDesc.VarType.FP32:
            return False
        return True

    @staticmethod
    def remove_cast_op(block, params, segment, offset):
        inserted_op_num = 0
        for op_idx in reversed(
                range(offset + segment._start_idx, offset + segment._end_idx)):
            op = block.ops[op_idx]
            if FP16Utils.is_fp16_cast_op(block, op, params):
                block._remove_op(op_idx, sync=False)
                inserted_op_num -= 1
        block._sync_with_cpp()
        return inserted_op_num

    @staticmethod
    def prune_fp16(block, shard, reduced_grads_to_param, nrings):
        # remove the optimizer-stage casts of gradients that this shard
        # does not own
        for idx, op in reversed(list(enumerate(block.ops))):
            if not FP16Utils.is_fp32_cast_op(block, op):
                continue
            output_name = op.desc.output_arg_names()[0]
            param_name = output_name.strip("@GRAD")
            if param_name not in shard.global_params:
                raise ValueError("The cast output {} is expected to be the "
                                 "gradient of a global parameter".format(
                                     output_name))
            if output_name in reduced_grads_to_param:
                continue
            if shard.has_param(param_name):
                continue
            block._remove_op(idx, sync=False)
            block._remove_var(output_name, sync=False)

        block._sync_with_cpp()
        # prune the inputs of check_finite_and_unscale / update_loss_scaling
        # down to the gradients owned by this shard
        update_loss_scaling_op_idx = -1
        inf_var_name = ''
        for idx, op in reversed(list(enumerate(block.ops))):
            if op.type == "update_loss_scaling":
                update_loss_scaling_op_idx = idx
                inf_var_name = op.desc.input('FoundInfinite')[0]
                op._rename_input(inf_var_name, inf_var_name + "@sharding")
            if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
                reversed_x = []
                for input_name in op.desc.input('X'):
                    param_name = input_name.strip("@GRAD")
                    if param_name not in shard.global_params:
                        raise ValueError(
                            "Input 'X' of check_finite_and_unscale must "
                            "be grads, but {} is not a grad".format(input_name))
                    if shard.has_param(param_name):
                        reversed_x.append(input_name)
                op.desc.set_input('X', reversed_x)
                op.desc.set_output('Out', reversed_x)
        if update_loss_scaling_op_idx == -1:
            return
        # the local FoundInfinite flag now only covers this shard's gradients:
        # cast it to int32, allreduce it with max across the ranks, and feed
        # the aggregated flag back to update_loss_scaling
        inf_var = block.var(inf_var_name)
        inf_var_fp32 = block.create_var(
            name=inf_var_name + "@cast_int32",
            shape=inf_var.shape,
            dtype=core.VarDesc.VarType.INT32)
        inf_var_sharding = block.create_var(
            name=inf_var_name + "@sharding",
            shape=inf_var.shape,
            dtype=inf_var.dtype)
        block._insert_op_without_sync(
            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var},
            outputs={'Out': inf_var_fp32},
            attrs={
                "in_dtype": inf_var.dtype,
                "out_dtype": inf_var_fp32.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
                            [inf_var_fp32])
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 2,
            type='c_allreduce_max',
            inputs={'X': inf_var_fp32},
            outputs={'Out': inf_var_fp32},
            attrs={'ring_id': 0,
                   OP_ROLE_KEY: OpRole.Optimize})
        comm_op_num = insert_sync_comm_ops(
            block, update_loss_scaling_op_idx + 3, nrings, [inf_var_fp32])
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 3 + comm_op_num,
            type='cast',
            inputs={'X': inf_var_fp32},
            outputs={'Out': inf_var_sharding},
            attrs={
                "in_dtype": inf_var_fp32.dtype,
                "out_dtype": inf_var_sharding.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        block._sync_with_cpp()
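Both remove_cast_op and prune_fp16 walk block.ops with reversed(list(enumerate(...))) before deleting ops by index. A small self-contained sketch of why the reverse walk keeps the pending indices valid during deletion (plain Python, with a hypothetical ops list standing in for block.ops):

    # Hypothetical stand-in for block.ops; not a Paddle API.
    ops = ["cast", "matmul", "cast", "sum", "cast"]

    # Deleting while iterating forward would shift the indices of the ops
    # not yet visited; iterating the indices in reverse means every index
    # still to be visited stays correct after each deletion.
    for idx, op in reversed(list(enumerate(ops))):
        if op == "cast":
            del ops[idx]

    print(ops)  # ['matmul', 'sum']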
90 changes: 90 additions & 0 deletions
python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py

@@ -0,0 +1,90 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole

class GradientClipHelper(object):
    def __init__(self):
        pass

    def _is_gradient_clip_op(self, op):
        return op.desc.has_attr("op_namescope") \
            and op.desc.attr("op_namescope").startswith("/gradient_clip")

    def prune_gradient_clip(self, block, shard):
        deperated_vars = set()
        deperate_op_idx = set()
        # mark clip ops (and anything fed by them) that act on gradients
        # this shard does not own
        for idx, op in enumerate(block.ops):
            if not self._is_gradient_clip_op(op):
                continue
            if op.type == "sum":
                continue
            deperate_op = False
            for input_name in op.desc.input_arg_names():
                if input_name in deperated_vars:
                    deperate_op = True
                param_name = input_name.strip("@GRAD")
                if shard.is_param(param_name) and \
                        not shard.has_param(param_name):
                    deperate_op = True

            if deperate_op:
                deperate_op_idx.add(idx)
                for output_name in op.desc.output_arg_names():
                    deperated_vars.add(output_name)

        if not deperated_vars:
            # got no gradient_clip op
            return

        for idx, op in reversed(list(enumerate(block.ops))):
            if not self._is_gradient_clip_op(op):
                continue
            if idx in deperate_op_idx:
                block._remove_op(idx, sync=False)
                continue
            reversed_inputs = []
            if op.type == "sum":
                # keep only the locally owned inputs in the global-norm sum,
                # then allreduce the partial sum so every rank recovers the
                # same global norm
                for input_name in op.desc.input_arg_names():
                    if input_name not in deperated_vars:
                        reversed_inputs.append(input_name)
                op.desc.set_input("X", reversed_inputs)
                assert (len(op.desc.output_arg_names()) == 1)
                sum_res = op.desc.output_arg_names()[0]
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_sync_comm_stream',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={'ring_id': 0,
                           OP_ROLE_KEY: OpRole.Optimize})
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_allreduce_sum',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={'ring_id': 0,
                           OP_ROLE_KEY: OpRole.Optimize})
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_sync_calc_stream',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={OP_ROLE_KEY: OpRole.Optimize})

        for var_name in deperated_vars:
            block._remove_var(var_name, sync=False)
        block._sync_with_cpp()
        return
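After pruning, each rank's sum op only adds up the per-gradient squared-norm terms it owns, which is why a c_allreduce_sum (between the calc-stream and comm-stream syncs) is inserted right after it: the partial sums must be combined before the global-norm clipping scale is computed. A plain-Python sketch of that arithmetic with hypothetical values:

    import math

    # Squared L2 norms of the gradients owned by two sharding ranks
    # (hypothetical numbers).
    rank0_partial = sum([4.0, 9.0])   # what rank 0's pruned sum op produces
    rank1_partial = sum([16.0])       # what rank 1's pruned sum op produces

    # c_allreduce_sum combines the partial sums, so both ranks recover the
    # same global norm and apply the same clipping scale.
    global_norm = math.sqrt(rank0_partial + rank1_partial)
    print(global_norm)  # sqrt(29) ~= 5.385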