Skip to content

Commit

Permalink
DataEnd() no longer takes a parameter since it was always the same
Browse files Browse the repository at this point in the history
  • Loading branch information
frankseide committed Feb 16, 2016
1 parent 1dce662 commit e4e2081
Show file tree
Hide file tree
Showing 27 changed files with 73 additions and 254 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ ypename LabelIdType, typename LabelType>& labelMapping) = 0;

\begin_layout Plain Layout

virtual bool DataEnd(EndDataType endDataType) = 0;
virtual bool DataEnd() = 0;
\end_layout

\begin_layout Plain Layout
Expand Down
4 changes: 2 additions & 2 deletions Source/Common/DataReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -356,11 +356,11 @@ bool DataReader<ElemType>::GetData(const std::wstring& sectionName, size_t numRe
}

template <class ElemType>
bool DataReader<ElemType>::DataEnd(EndDataType endDataType)
bool DataReader<ElemType>::DataEnd()
{
bool bRet = true;
for (size_t i = 0; i < m_ioNames.size(); i++)
bRet &= m_dataReaders[m_ioNames[i]]->DataEnd(endDataType);
bRet &= m_dataReaders[m_ioNames[i]]->DataEnd();
return bRet;
}

Expand Down
15 changes: 4 additions & 11 deletions Source/Common/Include/DataReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,6 @@ const size_t randomizeNone = 0;
// We use this constant as a stand in for the total number of frames in the dataset.
const size_t requestDataSize = randomizeAuto;

// TODO: Since all but one are unused, we can remove this enum and the parameter to DataEnd().
enum EndDataType
{
//endDataNull, // null values
//endDataEpoch, // end of epoch
//endDataSet, // end of dataset
endDataSentence, // end of sentence
};

// Data Reader interface
// implemented by DataReader and underlying classes
template <class ElemType>
Expand Down Expand Up @@ -118,7 +109,7 @@ class DATAREADER_API IDataReader
{
NOT_IMPLEMENTED;
}
virtual bool DataEnd(EndDataType)
virtual bool DataEnd()
{
NOT_IMPLEMENTED;
}
Expand Down Expand Up @@ -284,7 +275,9 @@ class DataReader : public IDataReader<ElemType>, protected Plugin, public Script
// returns: true if data remains to be read, false if the end of data was reached
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart = 0);

virtual bool DataEnd(EndDataType endDataType);
virtual bool DataEnd();
// TODO: The return value of this is never used except in loops where we do an &=. It is not clear whether that is a bug or intentionally prevents DataEnd() from being called.
// Once this is understood, we can change the return value to void.

// Gets a copy of the minibatch for the forward computation. This can be
// useful if some of the computation has to happen in the reader.
Expand Down
2 changes: 1 addition & 1 deletion Source/EvalDll/EvalReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ class EvalReader : public IDataReader<ElemType>
return false;
}

virtual bool DataEnd(EndDataType /*endDataType*/)
virtual bool DataEnd()
{
return m_currentRecord < m_recordCount;
}
Expand Down
28 changes: 3 additions & 25 deletions Source/Readers/BinaryReader/BinaryReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -416,32 +416,10 @@ bool BinaryReader<ElemType>::GetData(const std::wstring& sectionName, size_t num
}

template <class ElemType>
bool BinaryReader<ElemType>::DataEnd(EndDataType endDataType)
{
bool ret = false;
switch (endDataType)
{
//case endDataNull:
// assert(false);
// break;
//case endDataEpoch:
// ret = (m_mbStartSample / m_epochSize != m_epoch);
// break;
//case endDataSet:
//{
// // actual size is either what requested, or total number of samples read so far
// size_t actualmbsize = min(m_totalSamples, m_mbSize); // it may still return less if at end of sweep
// ret = CheckEndDataset(actualmbsize);
// break;
//}
case endDataSentence: // for fast reader each minibatch is considered a "sentence", so always true
ret = true;
break;
}
return ret;
}
bool BinaryReader<ElemType>::DataEnd() { return true; }

// instantiate all the combinations we expect to be used
template class BinaryReader<double>;
template class BinaryReader<float>;
} } }

}}}
3 changes: 2 additions & 1 deletion Source/Readers/BinaryReader/BinaryReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,8 @@ class BinaryReader : public IDataReader<ElemType>
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<typename BinaryReader<ElemType>::LabelIdType, typename BinaryReader<ElemType>::LabelType>& labelMapping);
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart = 0);

virtual bool DataEnd(EndDataType endDataType);
virtual bool DataEnd();

void SetRandomSeed(int)
{
NOT_IMPLEMENTED;
Expand Down
22 changes: 1 addition & 21 deletions Source/Readers/DSSMReader/DSSMReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,27 +446,7 @@ void DSSMReader<ElemType>::SetLabelMapping(const std::wstring& /*sectionName*/,
}

template <class ElemType>
bool DSSMReader<ElemType>::DataEnd(EndDataType endDataType)
{
bool ret = false;
switch (endDataType)
{
//case endDataNull:
// assert(false);
// break;
//case endDataEpoch:
// // ret = (m_mbStartSample / m_epochSize < m_epoch);
// ret = (m_readNextSample >= m_totalSamples);
// break;
//case endDataSet:
// ret = (m_readNextSample >= m_totalSamples);
// break;
case endDataSentence: // for fast reader each minibatch is considered a "sentence", so always true
ret = true;
break;
}
return ret;
}
bool DSSMReader<ElemType>::DataEnd() { return true; }

template <class ElemType>
DSSM_BinaryInput<ElemType>::DSSM_BinaryInput()
Expand Down
2 changes: 1 addition & 1 deletion Source/Readers/DSSMReader/DSSMReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ class DSSMReader : public IDataReader<ElemType>
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, typename LabelType>& labelMapping);
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart = 0);

virtual bool DataEnd(EndDataType endDataType);
virtual bool DataEnd();

void SetRandomSeed(int)
{
Expand Down
24 changes: 5 additions & 19 deletions Source/Readers/HTKMLFReader/HTKMLFReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1788,28 +1788,14 @@ bool HTKMLFReader<ElemType>::GetData(const std::wstring& /*sectionName*/, size_t
}

template <class ElemType>
bool HTKMLFReader<ElemType>::DataEnd(EndDataType endDataType)
bool HTKMLFReader<ElemType>::DataEnd()
{
// each minibatch is considered a "sentence"
// other datatypes not really supported...
// assert(endDataType == endDataSentence);
// for the truncated BPTT, we need to support checking whether it's the end of data
bool ret = false;
switch (endDataType)
{
//case endDataNull:
//case endDataEpoch:
//case endDataSet:
// LogicError("DataEnd: does not support endDataTypes: endDataNull, endDataEpoch and endDataSet");
// break;
case endDataSentence:
if (m_truncated)
ret = m_sentenceEnd[0];
else
ret = true; // useless in current condition
break;
}
return ret;
if (m_truncated)
return m_sentenceEnd[0];
else
return true; // useless in current condition
}

template <class ElemType>
Expand Down
2 changes: 1 addition & 1 deletion Source/Readers/HTKMLFReader/HTKMLFReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ class HTKMLFReader : public IDataReader<ElemType>
virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap);
virtual bool GetHmmData(msra::asr::simplesenonehmm* hmm);

virtual bool DataEnd(EndDataType endDataType);
virtual bool DataEnd();
void CopyMBLayoutTo(MBLayoutPtr);
void SetSentenceEndInBatch(vector<size_t>& /*sentenceEnd*/);
void SetSentenceEnd(int /*actualMbSize*/){};
Expand Down
27 changes: 6 additions & 21 deletions Source/Readers/Kaldi2Reader/HTKMLFReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1999,29 +1999,14 @@ bool HTKMLFReader<ElemType>::GetData(const std::wstring& /*sectionName*/, size_t
}

template <class ElemType>
bool HTKMLFReader<ElemType>::DataEnd(EndDataType endDataType)
bool HTKMLFReader<ElemType>::DataEnd()
{
// each minibatch is considered a "sentence"
// other datatypes not really supported...
// assert(endDataType == endDataSentence);
// for the truncated BPTT, we need to support check wether it's the end of data
bool ret = false;
switch (endDataType)
{
//case endDataNull:
//case endDataEpoch:
//case endDataSet:
// throw std::logic_error("DataEnd: does not support endDataTypes: endDataNull, endDataEpoch and endDataSet");
// break;
case endDataSentence:
if (m_truncated)
ret = m_sentenceEnd[0];
else
ret = true; // useless in current condition
break;
}
return ret;
}
// for the truncated BPTT, we need to support checking whether it's the end of data
if (m_truncated)
return m_sentenceEnd[0];
else
return true; // useless in current condition

template <class ElemType>
void HTKMLFReader<ElemType>::SetSentenceEndInBatch(vector<size_t>& sentenceEnd)
Expand Down
2 changes: 1 addition & 1 deletion Source/Readers/Kaldi2Reader/HTKMLFReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ class HTKMLFReader : public IDataReader<ElemType>
const Matrix<ElemType>& outputs,
const MBLayoutPtr pMBLayout);

virtual bool DataEnd(EndDataType endDataType);
virtual bool DataEnd();
void SetSentenceEndInBatch(vector<size_t>& /*sentenceEnd*/);
void SetSentenceEnd(int /*actualMbSize*/){};

Expand Down
45 changes: 7 additions & 38 deletions Source/Readers/LMSequenceReader/SequenceReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -899,25 +899,9 @@ void SequenceReader<ElemType>::StartMinibatchLoop(size_t mbSize, size_t epoch, s
}

template <class ElemType>
bool SequenceReader<ElemType>::DataEnd(EndDataType endDataType)
bool SequenceReader<ElemType>::DataEnd()
{
bool ret = false;
switch (endDataType)
{
//case endDataNull:
// assert(false);
// break;
//case endDataEpoch:
// ret = m_sequence.size() > 0 && m_mbStartSample > m_sequence[m_sequence.size() - 1];
// break;
//case endDataSet:
// ret = !EnsureDataAvailable(m_mbStartSample);
// break;
case endDataSentence: // for fast reader each minibatch is considered a "sentence", so always true
ret = SentenceEnd();
break;
}
return ret;
return SentenceEnd();
}

template <class ElemType>
Expand Down Expand Up @@ -1998,27 +1982,12 @@ void BatchSequenceReader<ElemType>::SetSentenceSegBatch(vector<size_t>& sentence

// note: DataEnd() must be called for each minibatch in order to propagate mSentenceEnd to mProcessed[]
template <class ElemType>
bool BatchSequenceReader<ElemType>::DataEnd(EndDataType endDataType)
bool BatchSequenceReader<ElemType>::DataEnd()
{
//size_t firstPosInSentence;
bool ret = false;
switch (endDataType)
{
//case endDataNull:
// assert(false);
// break;
//case endDataEpoch:
//case endDataSet:
// ret = !GetMinibatchData(firstPosInSentence); // TODO: What does this do? Check whether there is more data?
// break;
case endDataSentence: // for fast reader each minibatch is considered a "sentence", so always true
if (mSentenceEnd)
for (auto seq : mToProcess)
mProcessed[seq] = true;
ret = mSentenceEnd;
break;
}
return ret;
if (mSentenceEnd)
for (auto seq : mToProcess)
mProcessed[seq] = true;
return mSentenceEnd;
}

// fill the labels (from m_labelIdData)
Expand Down
4 changes: 2 additions & 2 deletions Source/Readers/LMSequenceReader/SequenceReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ class SequenceReader : public IDataReader<ElemType>
virtual void SetLabelMapping(const std::wstring& sectionName, const std::map<LabelIdType, LabelType>& labelMapping);
virtual bool GetData(const std::wstring& sectionName, size_t numRecords, void* data, size_t& dataBufferSize, size_t recordStart = 0);

virtual bool DataEnd(EndDataType endDataType);
virtual bool DataEnd();

//int GetSentenceEndIdFromOutputLabel() { return -1; };
};
Expand Down Expand Up @@ -411,7 +411,7 @@ class BatchSequenceReader : public SequenceReader<ElemType>
public:
void StartMinibatchLoop(size_t mbSize, size_t epoch, size_t requestedEpochSamples = requestDataSize) override;
bool GetMinibatch(std::map<std::wstring, Matrix<ElemType>*>& matrices) override;
bool DataEnd(EndDataType endDataType) override;
bool DataEnd() override;

void CopyMBLayoutTo(MBLayoutPtr pMBLayout) { assert(mToProcess.size() == m_pMBLayout->GetNumParallelSequences()); pMBLayout->CopyFrom(m_pMBLayout); }
size_t GetNumParallelSequences() override { return mToProcess.size(); } // TODO: or get it from MBLayout? Can this ever be called before GetMinibatch()?
Expand Down
45 changes: 16 additions & 29 deletions Source/Readers/LUSequenceReader/LUSequenceReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1003,32 +1003,18 @@ void BatchLUSequenceReader<ElemType>::SetSentenceBegin(int wrd, int pos, int /*a
}

template <class ElemType>
bool BatchLUSequenceReader<ElemType>::DataEnd(EndDataType endDataType)
bool BatchLUSequenceReader<ElemType>::DataEnd()
{
bool ret = false;
switch (endDataType)
if (mSentenceEndAt.size() != mToProcess.size())
LogicError("DataEnd: Sentence ending vector size %d and the toprocess vector size %d should be the same.", (int)mSentenceEndAt.size(), (int)mToProcess.size());
for (size_t i = 0; i < mToProcess.size(); i++)
{
//case endDataNull:
// assert(false);
// break;
//case endDataEpoch:
//case endDataSet:
// ret = !EnsureDataAvailable(m_mbStartSample);
// break;
case endDataSentence: // for fast reader each minibatch is considered a "sentence", so always true
if (mSentenceEndAt.size() != mToProcess.size())
LogicError("DataEnd: Sentence ending vector size %d and the toprocess vector size %d should be the same.", (int) mSentenceEndAt.size(), (int) mToProcess.size());
ret = true;
for (size_t i = 0; i < mToProcess.size(); i++)
{
if (mSentenceEndAt[i] == NO_INPUT)
LogicError("BatchLUSequenceReader: Minibatch should be large enough to accomodate the longest sentence.");
size_t k = mToProcess[i];
mProcessed[k] = true;
}
break;
if (mSentenceEndAt[i] == NO_INPUT)
LogicError("BatchLUSequenceReader: Minibatch should be large enough to accomodate the longest sentence.");
size_t k = mToProcess[i];
mProcessed[k] = true;
}
return ret;
return true;
}

template <class ElemType>
Expand Down Expand Up @@ -1317,17 +1303,18 @@ int MultiIOBatchLUSequenceReader<ElemType>::GetSentenceEndIdFromOutputLabel()
#endif

template <class ElemType>
bool MultiIOBatchLUSequenceReader<ElemType>::DataEnd(EndDataType endDataType)
bool MultiIOBatchLUSequenceReader<ElemType>::DataEnd()
{
bool ret = true;
for (typename map<wstring, BatchLUSequenceReader<ElemType>*>::iterator p = mReader.begin(); p != mReader.end(); p++)
{
ret |= (p->second)->DataEnd(endDataType);
}
for (auto& iter : mReader)
ret &= iter.second->DataEnd();
// ###### BREAKING ######
// The above was an |= which did not make sense. I follow the other examples where we have an &= here. Hope that is correct.
// ###### BREAKING ######
return ret;
}

/// history is shared
// history is shared
template <class ElemType>
bool MultiIOBatchLUSequenceReader<ElemType>::GetProposalObs(std::map<std::wstring, Matrix<ElemType>*>& matrices, const size_t tidx, vector<size_t>& history)
{
Expand Down
Loading

0 comments on commit e4e2081

Please sign in to comment.