Skip to content

Commit

Permalink
reader supports multiple inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
jeanfad committed Mar 15, 2016
1 parent c72d791 commit 3ca13ab
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 38 deletions.
6 changes: 3 additions & 3 deletions LanguageBindings/Python/cntk/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def add_macro(self, path):
'''
self.macros.append(path)

def _generate_train_config(self, input_map):
def _generate_train_config(self, reader):
'''
Generates the configuration file for the train action.
'''
Expand All @@ -130,7 +130,7 @@ def _generate_train_config(self, input_map):
'DevideId': self.device_id,
'ModelDescription': self.to_description(),
'ModelPath': model_filename,
'Reader': self._generate_reader_config(input_map),
'Reader': reader_config,
'SGD': self.optimizer.generate_config(),
}
return tmpl % tmpl_dict
Expand Down Expand Up @@ -283,7 +283,7 @@ def _call_cntk(self, config_file_name, config_content):
subprocess.check_call(
[CNTK_EXECUTABLE_PATH, "configFile=%s" % filename])

def train(self, input_map):
def train(self, reader):
'''
Run the train action locally.
:param input_map: mapping of input node to (reader, (start_dim, num_dim))
Expand Down
81 changes: 46 additions & 35 deletions LanguageBindings/Python/cntk/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,71 +23,82 @@ def __ne__(self, x): return x is not self
class UCIFastReader(AbstractReader):

"""This is the reader class
:param filename: data file path
:param features_dim: number of label columns
:param features_start: the index of the first label column
:param label_node_name: the name of the labels node in the network
:param labels_dim: number of label columns
:param labels_start: the index of the first label column
:param num_of_classes: the number of classes
:param label_mapping_file:
the mapping file path, it can be simply with all the possible classes, one per line
:param custom_delimiter: the default is space and tab, you can specify other delimiters to be used
:param custom_delimiter: the default is space and tab, you can specify other delimiters to be used
:param inputs_def: is a list of tuples (input_name, input_start, input_dim)
input_name: the name of the input in the network definition
input_start: the start column
input_dim: the number of columns
"""

def __init__(self, filename,
features_dim=None,
features_start=None,
label_node_name="labels",
labels_dim=None,
labels_start=None,
num_of_classes=None,
label_mapping_file=None,
custom_delimiter=None):
num_of_classes=None,
label_mapping_file=None,
custom_delimiter=None,
inputs_def=None):
""" Reader constructor
"""
self["ReaderType"] = self.__class__.__name__
self["FileName"] = filename
self["LabelsDim"] = labels_dim
self["LabelsStart"] = labels_start
self["FeaturesDim"] = features_dim
self["FeaturesStart"] = features_start
self["LabelsStart"] = labels_start
self["NumOfClasses"] = num_of_classes
self["LabelMappingFile"] = label_mapping_file
self["CustomDelimiter"] = custom_delimiter
self.inputs_def = inputs_def or []

def add_input(self, input_name, input_start, input_dim):
"""Add an input to the reader
:param input_name: the name of the input in the network definition
:param input_start: the start column
:param input_dim: the number of columns
"""
if (not (input_name and input_start and input_dim)):
raise ValueError("one of the parameters of add_input is None or empty string")

self.inputs_def.append((input_name, input_start, input_dim))

def generate_config(self):
"""Generate the reader configuration block
"""
template = '''
readerType = "%(ReaderType)s"
file = "%(FileName)s"
randomize = "none"
verbosity = 1
'''
readerType = "%(ReaderType)s"
file = "%(FileName)s"
randomize = "none"
verbosity = 1
'''

if self['CustomDelimiter'] is not None:
template += '''
customDelimiter=%(CustomDelimiter)s
'''
# TODO: generalize the reader to take n features sequences and m label
# sequences
if self['FeaturesStart'] is not None:
template += '''
v2=[
start = "%(FeaturesStart)s"
dim = "%(FeaturesDim)s"
]'''

customDelimiter=%(CustomDelimiter)s
'''

if self['LabelsStart'] is not None:
template += '''
labels=[
start = "%(LabelsStart)s"
dim = "%(LabelsDim)s"
labelDim="%(NumOfClasses)s"
labelMappingFile="%(LabelMappingFile)s"
]'''

if self.inputs_def is not None:
for (name, start, dim) in self.inputs_def:
template += '''
{0}=[
start = {1}
dim = {2}
]'''.format(name, start, dim)

v0=[
start = "%(LabelsStart)s"
dim = "%(LabelsDim)s"
labelDim="%(NumOfClasses)s"
labelMappingFile="%(LabelMappingFile)s"
]'''

return template % self

Expand Down

0 comments on commit 3ca13ab

Please sign in to comment.