
Commit

add sharding strategy in fleet (PaddlePaddle#27900)
* add sharding
mapingshuo authored Oct 26, 2020
1 parent 4877bd5 commit 81244fb
Showing 20 changed files with 1,648 additions and 14 deletions.
6 changes: 6 additions & 0 deletions paddle/fluid/framework/distributed_strategy.proto
@@ -24,6 +24,10 @@ enum Mode {

message RecomputeConfig { repeated string checkpoints = 1; }

message ShardingConfig {
  optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
}

message AMPConfig {
  optional float init_loss_scaling = 1 [ default = 32768.0 ];
  optional int32 incr_every_n_steps = 2 [ default = 1000 ];
@@ -130,6 +134,7 @@ message DistributedStrategy {
  optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ];
  optional bool adaptive_localsgd = 24 [ default = false ];
  optional bool fp16_allreduce = 25 [ default = false ];
  optional bool sharding = 26 [ default = false ];

  optional RecomputeConfig recompute_configs = 101;
  optional AMPConfig amp_configs = 102;
@@ -141,6 +146,7 @@ message DistributedStrategy {
  optional LarsConfig lars_configs = 108;
  optional LambConfig lamb_configs = 109;
  optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110;
  optional ShardingConfig sharding_configs = 111;
  optional BuildStrategy build_strategy = 201;
  optional ExecutionStrategy execution_strategy = 202;
}
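For reference, the two new proto fields map directly onto the Python-side DistributedStrategy attributes added later in this diff; a minimal sketch of how user code is expected to set them (nothing here beyond what the diff itself exposes):

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.sharding = True                                # -> `sharding = 26` above
    strategy.sharding_configs = {"fuse_broadcast_MB": 32}   # -> ShardingConfig.fuse_broadcast_MB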
49 changes: 49 additions & 0 deletions python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -611,6 +611,55 @@ def recompute_configs(self, configs):
                          "checkpoint_configs")
        assign_configs_value(self.strategy.recompute_configs, configs)

    @property
    def sharding(self):
        """
        Indicate whether to use the sharding optimizer for memory
        optimization, which partitions model parameters and optimizer
        states across the ranks of a sharding group.
        Default value: False

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.sharding = True
        """
        return self.strategy.sharding

    @sharding.setter
    @is_strict_auto
    def sharding(self, flag):
        if isinstance(flag, bool):
            self.strategy.sharding = flag
        else:
            print("WARNING: sharding should have value of bool type")

    @property
    def sharding_configs(self):
        """
        Set sharding configurations.

        **Note**:
            fuse_broadcast_MB(float): size threshold (in MB) of a fused group
            of broadcast parameters.

        Examples:
          .. code-block:: python

            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.sharding = True
            strategy.sharding_configs = {"fuse_broadcast_MB": 32}
        """
        return get_msg_dict(self.strategy.sharding_configs)

    @sharding_configs.setter
    @is_strict_auto
    def sharding_configs(self, configs):
        check_configs_key(self.strategy.sharding_configs, configs,
                          "sharding_configs")
        assign_configs_value(self.strategy.sharding_configs, configs)

    @property
    def pipeline(self):
        """
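A hedged end-to-end sketch of how these strategy fields are meant to be consumed; the network and loss are hypothetical, the fleet calls follow the usual static-graph collective flow, and the script would be launched with paddle.distributed.launch (or fleetrun):

    import paddle
    import paddle.distributed.fleet as fleet

    paddle.enable_static()
    fleet.init(is_collective=True)

    strategy = fleet.DistributedStrategy()
    strategy.sharding = True
    strategy.sharding_configs = {"fuse_broadcast_MB": 32}

    optimizer = paddle.optimizer.Momentum(learning_rate=0.01)
    # fleet.distributed_optimizer wraps the base optimizer with the meta
    # optimizers selected by the strategy (here, the new ShardingOptimizer).
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    # optimizer.minimize(loss)  # `loss` comes from a user-defined static network (not shown)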
@@ -24,3 +24,4 @@
from .dgc_optimizer import DGCOptimizer
from .lamb_optimizer import LambOptimizer
from .fp16_allreduce_optimizer import FP16AllReduceOptimizer
from .sharding_optimizer import ShardingOptimizer
6 changes: 6 additions & 0 deletions python/paddle/distributed/fleet/meta_optimizers/common.py
@@ -99,6 +99,12 @@ def _init_communicator(self, program, current_endpoint, endpoints, rank,
                    OP_ROLE_KEY: OpRole.Forward
                })

    def _wait(self, current_endpoint, endpoints):
        assert (self.wait_port)
        # block until every other endpoint in the group is ready to
        # accept connections
        other_endpoints = endpoints[:]
        other_endpoints.remove(current_endpoint)
        wait_server_ready(other_endpoints)

    def _broadcast_params(self):
        block = self.startup_program.global_block()
        ring_id = -1
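A minimal, Paddle-free sketch of the behaviour `_wait` adds (endpoint values and the helper name are hypothetical): each trainer drops its own endpoint from the list and polls the remaining ones until they accept TCP connections, so no rank proceeds before its peers are reachable.

    import socket
    import time

    def wait_server_ready_sketch(endpoints, timeout=1.0):
        # Hypothetical stand-in for the wait_server_ready call used above:
        # retry each "ip:port" until a TCP connect succeeds.
        pending = list(endpoints)
        while pending:
            ip, port = pending[0].split(":")
            try:
                with socket.create_connection((ip, int(port)), timeout=timeout):
                    pending.pop(0)
            except OSError:
                time.sleep(timeout)

    endpoints = ["127.0.0.1:6170", "127.0.0.1:6171"]   # hypothetical
    current_endpoint = endpoints[0]
    others = [ep for ep in endpoints if ep != current_endpoint]
    # wait_server_ready_sketch(others)  # blocks until the peers are listening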
@@ -30,6 +30,10 @@ def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
        super(DGCOptimizer, self)._set_basic_info(
            loss, role_maker, user_defined_optimizer, user_defined_strategy)

    def _init_dgc_opt(self):
        if self.dgc_opt is not None:
            return

        opt = self.inner_opt

        if not self.role_maker._is_collective:
@@ -86,13 +90,16 @@ def backward(self,
                 parameter_list=None,
                 no_grad_set=None,
                 callbacks=None):
        self._init_dgc_opt()
        return self.dgc_opt.backward(loss, startup_program, parameter_list,
                                     no_grad_set, callbacks)

    def apply_gradients(self, params_grads):
        self._init_dgc_opt()
        return self.dgc_opt.apply_gradients(params_grads=params_grads)

    def apply_optimize(self, loss, startup_program, params_grads):
        self._init_dgc_opt()
        return self.dgc_opt.apply_optimize(
            loss, startup_program=startup_program, params_grads=params_grads)

@@ -101,6 +108,7 @@ def minimize_impl(self,
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        self._init_dgc_opt()
        optimize_ops, params_grads = \
            self.dgc_opt.minimize(loss, startup_program,
                                  parameter_list, no_grad_set)
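The DGCOptimizer change is a lazy-initialization pattern: every public entry point calls `_init_dgc_opt()` first, so the wrapped DGC optimizer is only constructed the first time it is actually needed, after `_set_basic_info` has supplied the role maker and strategy. A minimal, Paddle-free sketch of the same pattern (class and method names are illustrative, not part of the diff):

    class LazyWrapper:
        def __init__(self, inner_factory):
            self._inner_factory = inner_factory
            self._inner = None              # built on first use, not in __init__

        def _init_inner(self):
            if self._inner is not None:
                return
            self._inner = self._inner_factory()

        def step(self, *args, **kwargs):
            self._init_inner()              # every public method initializes first
            return self._inner.step(*args, **kwargs)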
@@ -0,0 +1,13 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,154 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distributed.fleet.meta_optimizers.common import is_optimizer_op, OP_ROLE_KEY, OpRole
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *

from paddle.fluid import core


class FP16Utils(object):
    def __init__(self):
        pass

    @staticmethod
    def is_fp16_cast_op(block, op, params):
        if op.type != "cast":
            return False
        if is_optimizer_op(op):
            return False
        assert (len(op.desc.input_arg_names()) == 1)
        assert (len(op.desc.output_arg_names()) == 1)
        input_name, output_name = op.desc.input_arg_names()[
            0], op.desc.output_arg_names()[0]
        if input_name not in params:
            return False
        input_var = block.var(input_name)
        output_var = block.var(output_name)
        if input_var.dtype != core.VarDesc.VarType.FP32 or \
                output_var.dtype != core.VarDesc.VarType.FP16:
            return False
        return True

    @staticmethod
    def is_fp32_cast_op(block, op):
        if op.type != "cast":
            return False
        if not is_optimizer_op(op):
            return False
        assert (len(op.desc.input_arg_names()) == 1)
        assert (len(op.desc.output_arg_names()) == 1)
        input_name, output_name = op.desc.input_arg_names()[
            0], op.desc.output_arg_names()[0]
        input_var = block.var(input_name)
        output_var = block.var(output_name)
        if input_var.dtype != core.VarDesc.VarType.FP16 or \
                output_var.dtype != core.VarDesc.VarType.FP32:
            return False
        return True

    @staticmethod
    def remove_cast_op(block, params, segment, offset):
        inserted_op_num = 0
        for op_idx in reversed(
                range(offset + segment._start_idx, offset + segment._end_idx)):
            op = block.ops[op_idx]
            if FP16Utils.is_fp16_cast_op(block, op, params):
                block._remove_op(op_idx, sync=False)
                inserted_op_num -= 1
        block._sync_with_cpp()
        return inserted_op_num

    @staticmethod
    def prune_fp16(block, shard, reduced_grads_to_param, nrings):
        # remove cast ops whose fp32 output is a grad not owned by this rank
        for idx, op in reversed(list(enumerate(block.ops))):
            if not FP16Utils.is_fp32_cast_op(block, op):
                continue
            output_name = op.desc.output_arg_names()[0]
            param_name = output_name.strip("@GRAD")
            if param_name not in shard.global_params:
                raise ValueError("Input 'X' of check_finite_and_unscale must "
                                 "be grads, but {} is not a grad".format(
                                     output_name))
            if output_name in reduced_grads_to_param:
                continue
            if shard.has_param(param_name):
                continue
            block._remove_op(idx, sync=False)
            block._remove_var(output_name, sync=False)

        block._sync_with_cpp()
        update_loss_scaling_op_idx = -1
        inf_var_name = ''
        for idx, op in reversed(list(enumerate(block.ops))):
            if op.type == "update_loss_scaling":
                update_loss_scaling_op_idx = idx
                inf_var_name = op.desc.input('FoundInfinite')[0]
                op._rename_input(inf_var_name, inf_var_name + "@sharding")
            if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
                reversed_x = []
                for input_name in op.desc.input('X'):
                    param_name = input_name.strip("@GRAD")
                    if param_name not in shard.global_params:
                        raise ValueError(
                            "Input 'X' of check_finite_and_unscale must "
                            "be grads, but {} is not a grad".format(input_name))
                    if shard.has_param(param_name):
                        reversed_x.append(input_name)
                op.desc.set_input('X', reversed_x)
                op.desc.set_output('Out', reversed_x)
        if update_loss_scaling_op_idx == -1:
            return
        inf_var = block.var(inf_var_name)
        inf_var_fp32 = block.create_var(
            name=inf_var_name + "@cast_int32",
            shape=inf_var.shape,
            dtype=core.VarDesc.VarType.INT32)
        inf_var_sharding = block.create_var(
            name=inf_var_name + "@sharding",
            shape=inf_var.shape,
            dtype=inf_var.dtype)
        block._insert_op_without_sync(
            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var},
            outputs={'Out': inf_var_fp32},
            attrs={
                "in_dtype": inf_var.dtype,
                "out_dtype": inf_var_fp32.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
                            [inf_var_fp32])
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 2,
            type='c_allreduce_max',
            inputs={'X': inf_var_fp32},
            outputs={'Out': inf_var_fp32},
            attrs={'ring_id': 0,
                   OP_ROLE_KEY: OpRole.Optimize})
        comm_op_num = insert_sync_comm_ops(
            block, update_loss_scaling_op_idx + 3, nrings, [inf_var_fp32])
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 3 + comm_op_num,
            type='cast',
            inputs={'X': inf_var_fp32},
            outputs={'Out': inf_var_sharding},
            attrs={
                "in_dtype": inf_var_fp32.dtype,
                "out_dtype": inf_var_sharding.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        block._sync_with_cpp()
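The tail of `prune_fp16` rewires the AMP overflow check for sharding: each rank now inspects only the gradients it owns, so its local `FoundInfinite` flag is combined with a `c_allreduce_max` across ranks before `update_loss_scaling` reads the renamed `@sharding` variable. A small NumPy sketch of why a max-reduction of per-rank flags gives the right global decision (the flag values are made up):

    import numpy as np

    # One 0/1 overflow flag per rank, each computed from that rank's own
    # shard of gradients (hypothetical values).
    local_found_inf = np.array([0, 0, 1, 0], dtype=np.int32)

    # c_allreduce_max across ranks: a single local overflow forces every
    # rank to skip the update and shrink the loss scale together.
    global_found_inf = local_found_inf.max()
    assert global_found_inf == 1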
@@ -0,0 +1,90 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole


class GradientClipHelper(object):
    def __init__(self):
        pass

    def _is_gradient_clip_op(self, op):
        return op.desc.has_attr("op_namescope") \
            and op.desc.attr("op_namescope").startswith("/gradient_clip")

    def prune_gradient_clip(self, block, shard):
        deperated_vars = set()
        deperate_op_idx = set()
        for idx, op in enumerate(block.ops):
            if not self._is_gradient_clip_op(op):
                continue
            if op.type == "sum":
                continue
            deperate_op = False
            for input_name in op.desc.input_arg_names():
                if input_name in deperated_vars:
                    deperate_op = True
                param_name = input_name.strip("@GRAD")
                if shard.is_param(param_name) and \
                        not shard.has_param(param_name):
                    deperate_op = True

            if deperate_op:
                deperate_op_idx.add(idx)
                for output_name in op.desc.output_arg_names():
                    deperated_vars.add(output_name)

        if not deperated_vars:
            # got no gradient_clip op
            return

        for idx, op in reversed(list(enumerate(block.ops))):
            if not self._is_gradient_clip_op(op):
                continue
            if idx in deperate_op_idx:
                block._remove_op(idx, sync=False)
                continue
            reversed_inputs = []
            if op.type == "sum":
                for input_name in op.desc.input_arg_names():
                    if input_name not in deperated_vars:
                        reversed_inputs.append(input_name)
                op.desc.set_input("X", reversed_inputs)
                assert (len(op.desc.output_arg_names()) == 1)
                sum_res = op.desc.output_arg_names()[0]
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_sync_comm_stream',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={'ring_id': 0,
                           OP_ROLE_KEY: OpRole.Optimize})
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_allreduce_sum',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={'ring_id': 0,
                           OP_ROLE_KEY: OpRole.Optimize})
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_sync_calc_stream',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={OP_ROLE_KEY: OpRole.Optimize})

        for var_name in deperated_vars:
            block._remove_var(var_name, sync=False)
        block._sync_with_cpp()
        return
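With sharding, each rank's `sum` op for global-norm clipping only accumulates the squared norms of the gradients that rank still holds, so `prune_gradient_clip` inserts a `c_allreduce_sum` (bracketed by calc/comm stream syncs) right after it to recover the full global norm. A NumPy sketch of the arithmetic, with made-up per-rank values:

    import numpy as np

    # Per-rank partial sums of squared gradient norms (hypothetical values).
    local_sq_sums = [4.0, 9.0, 1.0, 2.0]          # one entry per rank

    # c_allreduce_sum combines the partial sums; every rank then derives the
    # same global norm and hence the same clipping ratio.
    global_sq_sum = float(np.sum(local_sq_sums))  # what the allreduce produces
    global_norm = np.sqrt(global_sq_sum)
    assert abs(global_norm - 4.0) < 1e-6          # sqrt(16) = 4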
