Skip to content

Commit

Permalink
Better parameter checks
Browse files Browse the repository at this point in the history
  • Loading branch information
wilrich-msft committed Mar 15, 2016
1 parent 1660699 commit 8d850d3
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 18 deletions.
14 changes: 7 additions & 7 deletions LanguageBindings/Python/cntk/examples/LogReg/test1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from cntk import *

if (__name__ == "__main__"):
x = Input(2)
y = Input(3)
x = Input(2, var_name='x')
y = Input(3, var_name='y')
w = LearnableParameter(3, 2)
b = LearnableParameter(3, 1)
t = Times(w, x)
Expand All @@ -15,17 +15,17 @@
ec.tag = 'criterion'

reader = UCIFastReader(
"Train-3Classes.txt", "2", "0", "1", "2", "3", "SimpleMapping-3Classes.txt")
#reader.add_input(x, 0, 2)
#reader.add_input(y, 2, 1)
"Train-3Classes.txt", "y", "1", "2", "3", "SimpleMapping-3Classes.txt")
reader.add_input('x', 0, 2)
reader.add_input('y', 2, 1)

my_sgd = SGD(
epoch_size=0, minibatch_size=25, learning_ratesPerMB=0.1, max_epochs=3)

with Context('demo', optimizer=my_sgd, root_node= ec, clean_up=False) as ctx:
input_map = {x: (reader, (0, 2)), y: (reader, (2, 1))}
ctx.train(input_map)
ctx.train(reader)

#import ipdb;ipdb.set_trace()
result = ctx.eval(out, input_map)
print(result)
print(result[:3])
35 changes: 24 additions & 11 deletions LanguageBindings/Python/cntk/graph.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@

class ComputationNode(object):
'''
Base class for all nodes and operators. Provides a NumPy-like interface
with operators that are converted to CNTK operators.
'''

def __init__(self, name, params=None, var_name=None):
if not isinstance(name, str):
raise ValueError("Parameter 'name' has to be a string and not '%s'"%type(name))
if var_name is not None and not isinstance(var_name, str):
raise ValueError("Parameter 'var_name' has to be a string and not '%s'"%type(var_name))
self.name = name
self.params = params
self.var_name = var_name
Expand Down Expand Up @@ -100,7 +103,7 @@ def _param_to_brainscript(self, p_name, p_value):
p_value = ":".join(v for v in p_value)
else:
raise ValueError('Sequence initialization is only allowed for' +
' parameter "dims" and not "%s"' % p_name)
' parameters dims and not "%s"' % p_name)
else:
p_value = str(p_value)

Expand All @@ -116,16 +119,26 @@ def _to_description_unroll(self, desc, unrolled_nodes, inputs, node_counter=0):
if self.params:
for p_name in self.params:
p_value = self.__dict__[p_name]
if hasattr(p_value, '_to_description') and p_name:
if p_value in unrolled_nodes:
# we have seen this node already, so just retrieve its
# name
child_var = unrolled_nodes[p_value]
if hasattr(p_value, '_to_description') and p_name or \
p_name == 'inputs':
# TODO this is under the assumption that RowStack's
# inputs parameter gets a tuple of inputs

if p_name == 'inputs':
inputs = p_value
else:
child_var, node_counter, child_desc = p_value._to_description_unroll(
desc, unrolled_nodes, inputs, node_counter)
unrolled_nodes[p_value] = child_var
param_variable_names.append(child_var)
inputs = [p_value]

for p_value in inputs:
if p_value in unrolled_nodes:
# we have seen this node already, so just retrieve its
# name
child_var = unrolled_nodes[p_value]
else:
child_var, node_counter, child_desc = p_value._to_description_unroll(
desc, unrolled_nodes, inputs, node_counter)
unrolled_nodes[p_value] = child_var
param_variable_names.append(child_var)
else:
param_variable_names.append(
self._param_to_brainscript(p_name, p_value))
Expand Down

0 comments on commit 8d850d3

Please sign in to comment.