Skip to content

Commit

Permalink
Making reader optional
Browse files Browse the repository at this point in the history
  • Loading branch information
wilrich-msft committed Mar 22, 2016
1 parent 168e7f1 commit 21fbaa2
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions contrib/Python/cntk/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def _generate_eval_config(self, root_node, reader):
return tmpl % tmpl_dict

@abstractmethod
def train(self, reader):
def train(self, reader=None):
'''
Abstract method for the action train.
:param reader: the reader to use for this action. Alternatively, you
Expand All @@ -217,7 +217,7 @@ def train(self, reader):
pass

@abstractmethod
def test(self, reader):
def test(self, reader=None):
'''
Abstract method for the action test.
:param reader: the reader to use for this action. Alternatively, you
Expand All @@ -226,7 +226,7 @@ def test(self, reader):
pass

@abstractmethod
def predict(self, reader):
def predict(self, reader=None):
'''
Abstract method for the action write. It evaluated the trained model on
the data provided by the reader.
Expand Down Expand Up @@ -283,7 +283,7 @@ def _call_cntk(self, config_file_name, config_content):

return output.decode("utf-8")

def train(self, reader, override_existing = True):
def train(self, reader=None, override_existing = True):
'''
Run the train action locally.
:param reader: the reader to use for this action. Alternatively, you
Expand All @@ -293,7 +293,7 @@ def train(self, reader, override_existing = True):
config_content = self._generate_train_config(reader, override_existing)
output = self._call_cntk(CNTK_TRAIN_CONFIG_FILENAME, config_content)

def test(self, reader):
def test(self, reader=None):
'''
Run the test action locally.
:param reader: the reader to use for this action. Alternatively, you
Expand All @@ -302,7 +302,7 @@ def test(self, reader):
config_content = self._generate_test_config(reader)
output = self._call_cntk(CNTK_TEST_CONFIG_FILENAME, config_content)

def predict(self, reader):
def predict(self, reader=None):
'''
Run the write action locally, use the trained model of this context.
:param reader: the reader to use for this action. Alternatively, you
Expand Down Expand Up @@ -374,7 +374,8 @@ def eval(self, node, reader=None):
expected_size = np.multiply.reduce(shapes[node.var_name])
expected_shape = shapes[node.var_name]

if data.size != expected_size:
receieved_all = data.size == expected_size
if not receieved_all:
# For some reason the CNTK write action has issues with multi-row
# output. So we have to CNTK reshape it to one row and do it again,
# but then NumPy reshape using node's expected shape.
Expand Down

0 comments on commit 21fbaa2

Please sign in to comment.