From 3ca13aba3807be6f7c2b4bc168b104a1bfede745 Mon Sep 17 00:00:00 2001 From: jeanfad Date: Tue, 15 Mar 2016 17:42:37 +0100 Subject: [PATCH] reader supports multiple inputs --- LanguageBindings/Python/cntk/context.py | 6 +- LanguageBindings/Python/cntk/reader.py | 81 ++++++++++++++----------- 2 files changed, 49 insertions(+), 38 deletions(-) diff --git a/LanguageBindings/Python/cntk/context.py b/LanguageBindings/Python/cntk/context.py index 7df896a92a6e..626266f1a128 100644 --- a/LanguageBindings/Python/cntk/context.py +++ b/LanguageBindings/Python/cntk/context.py @@ -119,7 +119,7 @@ def add_macro(self, path): ''' self.macros.append(path) - def _generate_train_config(self, input_map): + def _generate_train_config(self, reader): ''' Generates the configuration file for the train action. ''' @@ -130,7 +130,7 @@ def _generate_train_config(self, input_map): 'DevideId': self.device_id, 'ModelDescription': self.to_description(), 'ModelPath': model_filename, - 'Reader': self._generate_reader_config(input_map), + 'Reader': reader_config, 'SGD': self.optimizer.generate_config(), } return tmpl % tmpl_dict @@ -283,7 +283,7 @@ def _call_cntk(self, config_file_name, config_content): subprocess.check_call( [CNTK_EXECUTABLE_PATH, "configFile=%s" % filename]) - def train(self, input_map): + def train(self, reader): ''' Run the train action locally. :param input_map: mapping of input node to (reader, (start_dim, num_dim)) diff --git a/LanguageBindings/Python/cntk/reader.py b/LanguageBindings/Python/cntk/reader.py index 47081c199ac6..c113f612a848 100644 --- a/LanguageBindings/Python/cntk/reader.py +++ b/LanguageBindings/Python/cntk/reader.py @@ -23,71 +23,82 @@ def __ne__(self, x): return x is not self class UCIFastReader(AbstractReader): """This is the reader class - :param filename: data file path - :param features_dim: number of label columns - :param features_start: the index of the first label column + :param label_node_name: the name of the labels node in the network :param labels_dim: number of label columns :param labels_start: the index of the first label column :param num_of_classes: the number of classes :param label_mapping_file: the mapping file path, it can be simply with all the possible classes, one per line - :param custom_delimiter: the default is space and tab, you can specify other delimiters to be used + :param custom_delimiter: the default is space and tab, you can specify other delimiters to be used + :param inputs_def: is a list of tuples (input_name, input_start, input_dim) + input_name: the name of the input in the network definition + input_start: the start column + input_dim: the number of columns """ def __init__(self, filename, - features_dim=None, - features_start=None, + label_node_name="labels", labels_dim=None, labels_start=None, - num_of_classes=None, - label_mapping_file=None, - custom_delimiter=None): + num_of_classes=None, + label_mapping_file=None, + custom_delimiter=None, + inputs_def=None): """ Reader constructor """ self["ReaderType"] = self.__class__.__name__ self["FileName"] = filename self["LabelsDim"] = labels_dim - self["LabelsStart"] = labels_start - self["FeaturesDim"] = features_dim - self["FeaturesStart"] = features_start + self["LabelsStart"] = labels_start self["NumOfClasses"] = num_of_classes self["LabelMappingFile"] = label_mapping_file self["CustomDelimiter"] = custom_delimiter + self.inputs_def = inputs_def or [] + def add_input(self, input_name, input_start, input_dim): + """Add an input to the reader + :param input_name: the name of the input in the network definition + :param input_start: the start column + :param input_dim: the number of columns + """ + if (not (input_name and input_start and input_dim)): + raise ValueError("one of the parameters of add_input is None or empty string") + + self.inputs_def.append((input_name, input_start, input_dim)) + def generate_config(self): """Generate the reader configuration block """ template = ''' - readerType = "%(ReaderType)s" - file = "%(FileName)s" - randomize = "none" - verbosity = 1 - ''' + readerType = "%(ReaderType)s" + file = "%(FileName)s" + randomize = "none" + verbosity = 1 + ''' if self['CustomDelimiter'] is not None: template += ''' - customDelimiter=%(CustomDelimiter)s - ''' - # TODO: generalize the reader to take n features sequences and m label - # sequences - if self['FeaturesStart'] is not None: - template += ''' - - v2=[ - start = "%(FeaturesStart)s" - dim = "%(FeaturesDim)s" - ]''' - + customDelimiter=%(CustomDelimiter)s + ''' + if self['LabelsStart'] is not None: template += ''' + labels=[ + start = "%(LabelsStart)s" + dim = "%(LabelsDim)s" + labelDim="%(NumOfClasses)s" + labelMappingFile="%(LabelMappingFile)s" + ]''' + + if self.inputs_def is not None: + for (name, start, dim) in self.inputs_def: + template += ''' + {0}=[ + start = {1} + dim = {2} + ]'''.format(name, start, dim) - v0=[ - start = "%(LabelsStart)s" - dim = "%(LabelsDim)s" - labelDim="%(NumOfClasses)s" - labelMappingFile="%(LabelMappingFile)s" - ]''' return template % self