Skip to content

Commit

Permalink
CNTK v2 library: Fix a few bugs in UDF backprop state management.
Browse files: browse the repository at this point in the history
  • Loading branch information
amitaga committed May 13, 2017
1 parent fd28b91 commit 7e7e3b8
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
21 changes: 12 additions & 9 deletions bindings/python/cntk/ops/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,10 +1334,14 @@ def _forward(self, arguments, outputs, device=None, outputs_to_retain=None):
else:
state = self.forward(args, outputs, device, outputs_to_retain)

if state is None:
state = self._get_none_state(device)
elif not isinstance(state, cntk_py.BackPropState):
state = cntk_py.UserBackPropState.create(self, device, state)
if isinstance(state, cntk_py.BackPropState):
self._state_wrapped = False
else:
self._state_wrapped = True
if state is None:
state = self._get_none_state(device)
else:
state = cntk_py.UserBackPropState.create(self, device, state)

if self.as_numpy:
for k,v in outputs.items():
Expand Down Expand Up @@ -1379,12 +1383,11 @@ def _backward(self, state, root_gradients, variables):
if v.needs_gradient:
root_gradients[v] = _value_as_sequence_or_array(root_gradients[v], v)

state = cntk_py.UserBackPropState.data(state)
if not isinstance(state, cntk_py.BackPropState):
raise ValueError('state must be of type BackPropState')

else:
if not isinstance(state, cntk_py.BackPropState):
raise ValueError('if as_numpy=False, state must be of '
'type BackPropState')
if self._state_wrapped:
state = cntk_py.UserBackPropState.data(state)

map_if_possible(variables)

Expand Down
10 changes: 7 additions & 3 deletions bindings/python/cntk/ops/tests/userfunction_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,11 +635,15 @@ def infer_outputs(self):

def forward(self, arguments, device=None, outputs_to_retain=None):
result = self.compute_func.eval({self.compute_func.arguments[0] : arguments[0], self.compute_func.arguments[1] : arguments[1]}, as_numpy=False)
return arguments, result
self.backprop_state = arguments
return self.backprop_state, result

def backward(self, state, root_gradients, variables):
assert state == self.backprop_state
variables[self.inputs[0]] = root_gradients

def test_udf_input_values_no_sharing():
i = C.input_variable(1, name='i_var')
i = C.input_variable(1, needs_gradient=True, name='i_var')
m = C.user_function(MyArgumentPreservingPlus(i + 1, i + 2))

w = C.parameter(shape=(1,), init=1)
Expand All @@ -648,5 +652,5 @@ def test_udf_input_values_no_sharing():
m3 = C.splice(m2, m2, axis=0)
m4 = C.splice(m3, m3, axis=0)

grad_value, result = m4.grad({i : np.asarray([2], dtype=np.float32)}, outputs=[m4], wrt=[w])
grad_value, result = m4.grad({i : np.asarray([2], dtype=np.float32)}, outputs=[m4], wrt=[w, i])
assert np.array_equal(result, [[8, 8, 8, 8, 8, 8, 8, 8]])

0 comments on commit 7e7e3b8

Please sign in to comment.