Skip to content

Commit

Permalink
Fix Kaldi reader. 1) Use the new MBLayout interface. 2) Modify the …
Browse files Browse the repository at this point in the history
…configParameter to configRecordType to be consistent with HTKMLFReader. 3) Clean up the warning messages. Next: refactor to make it consistent with HTKMLFReader.
  • Loading branch information
yzhang87 committed Dec 14, 2015
1 parent a45365c commit b8eb51d
Show file tree
Hide file tree
Showing 9 changed files with 487 additions and 3,646 deletions.
229 changes: 128 additions & 101 deletions DataReader/Kaldi2Reader/HTKMLFReader.cpp

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions DataReader/Kaldi2Reader/HTKMLFReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ class HTKMLFReader : public IDataReader<ElemType>

std::vector<std::vector<std::vector<ElemType>>>m_labelToTargetMapMultiIO;

void PrepareForTrainingOrTesting(const ConfigParameters& config);
void PrepareForWriting(const ConfigParameters& config);
void PrepareForSequenceTraining(const ConfigParameters& config);
template<class ConfigRecordType> void PrepareForTrainingOrTesting(const ConfigRecordType & config);
template<class ConfigRecordType> void PrepareForWriting(const ConfigRecordType & config);
template<class ConfigRecordType> void PrepareForSequenceTraining(const ConfigRecordType & config);

bool GetMinibatchToTrainOrTest(std::map<std::wstring, Matrix<ElemType>*>& matrices);
bool GetOneMinibatchToTrainOrTestDataBuffer(const std::map<std::wstring, Matrix<ElemType>*>& matrices);
bool GetMinibatchToWrite(std::map<std::wstring, Matrix<ElemType>*>& matrices);
Expand All @@ -136,7 +136,7 @@ class HTKMLFReader : public IDataReader<ElemType>
size_t NumberSlicesInEachRecurrentIter() { return m_numberOfuttsPerMinibatch ;}
void SetNbrSlicesEachRecurrentIter(const size_t) { };

void GetDataNamesFromConfig(const ConfigParameters& readerConfig, std::vector<std::wstring>& features, std::vector<std::wstring>& labels);
template<class ConfigRecordType> void GetDataNamesFromConfig(const ConfigRecordType & readerConfig, std::vector<std::wstring>& features, std::vector<std::wstring>& labels);


size_t ReadLabelToTargetMappingFile (const std::wstring& labelToTargetMappingFile, const std::wstring& labelListFile, std::vector<std::vector<ElemType>>& labelToTargetMap);
Expand Down
11 changes: 6 additions & 5 deletions DataReader/Kaldi2Reader/HTKMLFWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
//DATAWRITER_API IDataWriter* DataWriterFactory(void)

template<class ElemType>
void HTKMLFWriter<ElemType>::Init(const ConfigParameters& writerConfig)
template<class ConfigRecordType>
void HTKMLFWriter<ElemType>::InitFromConfig(const ConfigRecordType& writerConfig)
{
m_tempArray = nullptr;
m_tempArraySize = 0;
Expand All @@ -40,11 +41,11 @@ namespace Microsoft { namespace MSR { namespace CNTK {
size_t firstfilesonly = SIZE_MAX; // set to a lower value for testing


m_verbosity = writerConfig(L"verbosity", "2");
m_overflowValue = writerConfig(L"overflowValue", "50");
m_maxNumOverflowWarning = writerConfig(L"maxNumOverflowWarning", "10");
m_verbosity = writerConfig(L"verbosity", 2);
m_overflowValue = writerConfig(L"overflowValue", 50);
m_maxNumOverflowWarning = writerConfig(L"maxNumOverflowWarning", 10);

ConfigArray outputNames = writerConfig(L"outputNodeNames","");
vector<wstring> outputNames = writerConfig(L"outputNodeNames", ConfigRecordType::Array(stringargvector()));
if (outputNames.size()<1)
RuntimeError("writer needs at least one outputNodeName specified in config");
int counter = 0;
Expand Down
10 changes: 5 additions & 5 deletions DataReader/Kaldi2Reader/KaldiSequenceTrainingDerivative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
{
RuntimeError("Number of labels in logLikelihood does not match that"
" in the Kaldi model for utterance %S: %d v.s. %d\n",
uttID.c_str(), logLikelihood.GetNumRows(),
m_transModel.NumPdfs());
uttID.c_str(), (int)logLikelihood.GetNumRows(),
(int)m_transModel.NumPdfs());
}

// Reads alignment.
Expand All @@ -82,7 +82,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
{
RuntimeError("Number of frames in logLikelihood does not match that"
" in the alignment for utterance %S: %d v.s. %d\n",
uttID.c_str(), logLikelihood.GetNumCols(), ali.size());
uttID.c_str(), (int)logLikelihood.GetNumCols(), (int)ali.size());
}

// Reads denominator lattice.
Expand Down Expand Up @@ -184,8 +184,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
if (maxTime != logLikelihood.GetNumCols())
{
RuntimeError("Number of frames in the logLikelihood does not match"
" that in the denominator lattice for utterance %S\n",
uttID.c_str(), logLikelihood.GetNumRows(), maxTime);
" that in the denominator lattice for utterance %S: %d vs. %d\n",
uttID.c_str(), (int)logLikelihood.GetNumRows(), (int)maxTime);
}

std::vector<std::vector<kaldi::int32>> timeStateMap(
Expand Down
Loading

0 comments on commit b8eb51d

Please sign in to comment.