Skip to content
This repository has been archived by the owner on Jan 27, 2021. It is now read-only.

Commit

Permalink
Allow unnamed arguments for 1-input nodes and disallow otherwise
Browse files Browse the repository at this point in the history
  • Loading branch information
ivrodr-msft committed Oct 24, 2016
1 parent d65c1cd commit d1d02cb
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 20 deletions.
3 changes: 2 additions & 1 deletion bindings/python/cntk/ops/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ def eval(self, arguments=None, device=None):
Returns:
`bool`: `True` if updates have been performed
'''
_, output_map = self.forward(arguments or {}, self.outputs, device=device)

_, output_map = self.forward(arguments, self.outputs, device=device)

if len(output_map) > 1:
return output_map
Expand Down
14 changes: 13 additions & 1 deletion bindings/python/cntk/ops/tests/function_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pytest
from ..functions import *
from ...trainer import *
from .. import constant, parameter, input_variable, placeholder_variable, times
from .. import constant, parameter, input_variable, placeholder_variable, times, plus


def test_variable_forwarding():
Expand Down Expand Up @@ -75,3 +75,15 @@ def test_replace_placeholder_s():
op = times(p, right_val)
op.replace_placeholder(c)
assert op.eval() == 26

def test_exception_for_unnamed_arguments():
i1 = input_variable((1,2), name='i1')
i2 = input_variable((2,1), name='i2')
root_node = plus(i1, i2)
input1 = [[[1,2]]]
input2 = [[[[1],[2]]]]

with pytest.raises(Exception):
# not allowed, since plus has more than 1 input
result = root_node.eval([input1, input2])

19 changes: 16 additions & 3 deletions bindings/python/cntk/tests/persist_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,20 @@ def test_load_save_inputs(tmpdir):
loaded_result = loaded_node.eval({'i1': input1, 'i2': input2})
assert np.allclose(loaded_result, expected)

# Test spefying the input node names by order
loaded_result = loaded_node.eval([input1, input2])
def test_load_save_unique_input(tmpdir):
i1 = input_variable((1,2), name='i1')
root_node = softmax(i1)

input1 = [[[1,2]]]
result = root_node.eval(input1)
expected = [[[[ 0.268941, 0.731059]]]]
assert np.allclose(result, expected)

filename = str(tmpdir / 'i_plus_0.mod')
save_model(root_node, filename)

loaded_node = load_model('float', filename)

# Test specifying the only value for an unique input
loaded_result = loaded_node.eval(input1)
assert np.allclose(loaded_result, expected)

33 changes: 18 additions & 15 deletions bindings/python/cntk/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def get_temp_filename(directory=None):
Args:
directory (str): optional directory, in which the temporary file will
be created
be created
Returns:
Filename of the temporary file
Expand Down Expand Up @@ -282,7 +282,7 @@ def get_data_type(*args):
dtypes.add(np.float64)
else:
dtypes.add(np.float32)
else:
else:
# We don't know anything so we convert everything to float32. If it
# works, we know the type.
# TODO figure out a better/faster way.
Expand Down Expand Up @@ -338,7 +338,7 @@ def sanitize_batch(var, batch, seq_starts=None, data_type=None, device=None):
mask.
Args:
var (`:class:cntk.ops.variables.Variable`): variable node for which the ``batch`` is
var (:class:`cntk.ops.variables.Variable`): variable node for which the ``batch`` is
meant
batch (`list` of NumPy arrays): input
seq_starts (`list` of `bool` or `None`): if `None`, every sequence is
Expand All @@ -347,7 +347,7 @@ def sanitize_batch(var, batch, seq_starts=None, data_type=None, device=None):
continuation of the previous one (`False`)
Returns:
`:class:cntk.cntk_py.Value`: converted batch
:class:`cntk.cntk_py.Value`: converted batch
'''
from ..cntk_py import Value

Expand Down Expand Up @@ -477,18 +477,19 @@ def sanitize_var_map(op_arguments, arguments, precision=None,
'''
Sanitizes a dictionary of `Variable`s to input data such that it can be
handed off to the :meth:`cntk.ops.functions.Function.forward` method.
handed off to the evaluation methods (:meth:`cntk.ops.functions.Function.forward`, :meth:`cntk.ops.functions.Function.backward`, :meth:`cntk.Trainer.train_minibatch` and
:meth:`cntk.Trainer.test_minibatch`).
Args:
op_arguments (:class:`cntk.ops.functions.Function`): arguments of the root function. In
forward pass it is typically `op.arguments`, in backward mode it is
`op.outputs`
arguments (`dict` or `list` or `tuple`): maps variables to their
input data. The interpretation depends on the input type
* `dict`: keys are input variable or names and values are the input data.
* `list`: elements are input data in the order their respective variables have been defined in the network.
In both cases, every every sample in the data will be interpreted
arguments: maps variables to their
input data. The interpretation depends on the input type:
* `dict`: keys are input variable or names and values are the input data.
* any other type: if node has an unique input, argument is mapped to this input.
For nodes with more than one input, only `dict` is allowed.
In both cases, every sample in the data will be interpreted
as a new sequence. To mark samples as continuations of the
previous sequence, specify ``arguments`` as `tuple`: the
first element will be used as ``arguments``, and the second one will
Expand Down Expand Up @@ -521,16 +522,18 @@ def sanitize_var_map(op_arguments, arguments, precision=None,
raise ValueError('your graph has %i inputs, but you specified %i' %
(len(op_arguments), len(arguments)))

if isinstance(arguments, list):
arguments = dict(zip(op_arguments, arguments))

if isinstance(arguments, dict):
arg_names = [var.name for var in op_arguments]
name_counter = collections.Counter(arg_names)

var_name_map = dict((var.name, var) for var in op_arguments)
else:
raise ValueError('type "%s" is not supported' % type(arguments))
if len(op_arguments) == 1:
name_counter = collections.Counter([op_arguments[0].name])
var_name_map = dict([(op_arguments[0].name, op_arguments[0])])
arguments = dict([(op_arguments[0], arguments)])
else:
raise ValueError('non-dict argument (%s) is not supported for nodes with more than one input' % type(arguments).__name__)

sample_sizes = [len(v) for v in arguments.values()]
if len(set(sample_sizes)) != 1:
Expand Down

0 comments on commit d1d02cb

Please sign in to comment.