Skip to content

Commit

Permalink
text reader support
Browse files Browse the repository at this point in the history
  • Loading branch information
jeanfad committed Mar 17, 2016
1 parent 2ba6fa5 commit 075cd75
Showing 1 changed file with 67 additions and 4 deletions.
71 changes: 67 additions & 4 deletions LanguageBindings/Python/cntk/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ def __eq__(self, x): return x is self

def __ne__(self, x): return x is not self


class UCIFastReader(AbstractReader):

"""This is the reader class
"""This is the reader class the maps to UCIFastReader of CNTK
:param filename: data file path
:param labels_node_name: the name of the labels node in the network
:param labels_dim: number of label columns
Expand Down Expand Up @@ -65,7 +64,7 @@ def add_input(self, name_or_node, input_start, input_dim):
:param input_dim: the number of columns
"""
if not name_or_node or input_start is None or input_dim is None:
raise ValueError("one of the parameters of add_input is None or empty string")
raise ValueError("one of the parameters of add_input is None")

self.inputs_def.append((name_or_node, input_start, input_dim))

Expand Down Expand Up @@ -108,10 +107,74 @@ def generate_config(self):

return template % self

class CNTKTextFormatReader(AbstractReader):

"""This is the reader class the maps to CNTKTextFormatReader of CNTK
:param filename: data file path
:param inputs_def: is a list of tuples (name_or_node, alias, input_dim, format)
name_or_node: the name of the input in the network definition or the node itself
alias: a short name for the input, it is how inputs are referenced in the data files
input_dim: the lenght of the input vector
format: dense or sparse
"""

def __init__(self, filename,
inputs_def=None):
""" Reader constructor
"""
self["ReaderType"] = self.__class__.__name__
self["FileName"] = filename
self.inputs_def = inputs_def or []

def add_input(self, name_or_node, alias, input_dim, format="Dense"):
"""Add an input to the reader
name_or_node: the name of the input in the network definition or the node itself
alias: a short name for the input, it is how inputs are referenced in the data files
input_dim: the lenght of the input vector
format: dense or sparse
"""
if not name_or_node or input_dim is None or format is None:
raise ValueError("one of the parameters of add_input is None")

alias = alias or name_or_node

self.inputs_def.append((name_or_node, alias, input_dim, format))

def generate_config(self):
"""Generate the reader configuration block
"""
template = '''
readerType = "%(ReaderType)s"
file = "%(FileName)s"
'''

if self.inputs_def is not None:
template += '''
input = [
'''

for (name_or_node, alias, dim, format) in self.inputs_def:
if (isinstance(name_or_node, ComputationNode)):
name = name_or_node.var_name
a = name_or_node.var_name
else:
name = name_or_node
a = alias
template += '''
{0}=[
alias = "{1}"
dim = {2}
format = "{3}"
]'''.format(name, a, dim, format)

template += '''
]
'''
return template % self

def NumPyReader(data, filename):
"""
This is a convenience function that wraps Python arrays.
This is a factory that wraps Python arrays with a UCIFastReader.
"""

import numpy as np
Expand Down

0 comments on commit 075cd75

Please sign in to comment.