Skip to content

Commit

Permalink
fix pserver weight decay multi inputs test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
typhoonzero committed Nov 8, 2018
1 parent 5b7a9dd commit f3eafec
Showing 1 changed file with 36 additions and 16 deletions.
52 changes: 36 additions & 16 deletions python/paddle/fluid/transpiler/distribute_transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1706,13 +1706,27 @@ def _get_param_block(opt_op):
outputs=outputs,
attrs=opt_op.all_attrs())

def _is_splited_grad_var(self, var, var_dict):
def _get_pserver_grad_param_var(self, var, var_dict):
"""
Return pserver side grad/param variable, return None
if the variable is not grad/param, e.g.
a@GRAD -> [email protected]
a@GRAD -> a@GRAD (a is not splited)
fc_0.w_0 -> fc_0.w_0.block_0
fc_0.w_0 -> fc_0.w_0 (weight is not splited)
_generated_var_123 -> None
"""
grad_block = None
for _, g in six.iteritems(var_dict):
if self._orig_varname(g.name) == self._orig_varname(var.name):
# skip per trainer vars
if g.name.find(".trainer_") == -1:
grad_block = g
break
# only param or grads have splited blocks
if self._orig_varname(g.name) in self.grad_name_to_param_name or\
self._orig_varname(g.name) in self.param_name_to_grad_name:
grad_block = g
break
return grad_block

def _clone_lr_op(self, program, block, op):
Expand Down Expand Up @@ -1745,32 +1759,38 @@ def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
for key, varlist in six.iteritems(inputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
# for ops like clipping and weight decay, get the splited var
for i in range(len(varlist)):
var = varlist[i]
# for ops like clipping and weight decay, get the splited var (xxx.block0)
# for inputs/outputs
grad_block = self._is_splited_grad_var(
grad_block = self._get_pserver_grad_param_var(
var, program.global_block().vars)
if grad_block:
inputs[key] = grad_block
varlist[i] = grad_block
elif var.name not in program.global_block().vars:
program.global_block().create_var(
name=var.name,
persistable=var.persistable,
dtype=var.dtype,
shape=var.shape)
tmpvar = program.global_block()._clone_variable(var)
varlist[i] = tmpvar
else:
varlist[i] = program.global_block().vars[var.name]
inputs[key] = varlist

outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, opt_op)
for key, varlist in six.iteritems(outputs):
if not isinstance(varlist, list):
varlist = [varlist]
for var in varlist:
grad_block = self._is_splited_grad_var(
for i in range(len(varlist)):
var = varlist[i]
grad_block = self._get_pserver_grad_param_var(
var, program.global_block().vars)
if grad_block:
outputs[key] = grad_block
varlist[i] = grad_block
elif var.name not in program.global_block().vars:
program.global_block()._clone_variable(var)
tmpvar = program.global_block()._clone_variable(var)
varlist[i] = tmpvar
else:
varlist[i] = program.global_block().vars[var.name]
outputs[key] = varlist

return optimize_block.append_op(
type=opt_op.type,
Expand Down

0 comments on commit f3eafec

Please sign in to comment.