Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into linux-gcc
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhang87 committed Jul 18, 2015
2 parents 82c7d8c + fa3fb05 commit 8830b5a
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 10 deletions.
20 changes: 16 additions & 4 deletions MachineLearning/CNTK/MinibatchFetcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,35 @@ template<class ElemType>
class MinibatchFetcher
{
public:
MinibatchFetcher(IDataReader<ElemType>* trainSetDataReader, const std::map<std::wstring, Matrix<ElemType>*>* inputMatrices) :
MinibatchFetcher(IDataReader<ElemType>* trainSetDataReader,
std::map<std::wstring, Matrix<ElemType>*>* inputMatrices,
Matrix<ElemType>* sentenceBegin,
vector<MinibatchPackingFlag>* sentenceExistsBeginOrNoLabels)
:
m_reader(trainSetDataReader),
m_inputMatrices(inputMatrices)
m_inputMatrices(inputMatrices),
m_sentenceBegin(sentenceBegin),
m_sentenceExistsBeginOrNoLabels(sentenceExistsBeginOrNoLabels)
{
assert((m_sentenceBegin != nullptr) && (m_sentenceExistsBeginOrNoLabels != nullptr));
}

// This virtual dtor is necessary to allow invocation of derived dtors, which have some required synchronization points
virtual ~MinibatchFetcher() {}

virtual bool GetMinibatch()
{
return m_reader->GetMinibatch(*const_cast<std::map<std::wstring, Matrix<ElemType>*>*>(m_inputMatrices));
bool retVal = m_reader->GetMinibatch(*m_inputMatrices);
m_reader->SetSentenceSegBatch(*m_sentenceBegin, *m_sentenceExistsBeginOrNoLabels);

return retVal;
}

protected:
IDataReader<ElemType>* m_reader;
const std::map<std::wstring, Matrix<ElemType>*>* m_inputMatrices;
std::map<std::wstring, Matrix<ElemType>*>* m_inputMatrices;
Matrix<ElemType>* m_sentenceBegin;
vector<MinibatchPackingFlag>* m_sentenceExistsBeginOrNoLabels;
};

}}}
46 changes: 43 additions & 3 deletions MachineLearning/CNTK/MinibatchPrefetcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,16 @@ template<class ElemType>
class MinibatchPrefetcher : public MinibatchFetcher<ElemType>
{
public:
MinibatchPrefetcher(IDataReader<ElemType>* trainSetDataReader, const std::map<std::wstring, Matrix<ElemType>*>* inputMatrices) :
MinibatchFetcher<ElemType>(trainSetDataReader, inputMatrices),
using MinibatchFetcher<ElemType>::m_sentenceBegin;
using MinibatchFetcher<ElemType>::m_sentenceExistsBeginOrNoLabels;

MinibatchPrefetcher(IDataReader<ElemType>* trainSetDataReader,
std::map<std::wstring, Matrix<ElemType>*>* inputMatrices,
Matrix<ElemType>* sentenceBegin,
vector<MinibatchPackingFlag>* sentenceExistsBeginOrNoLabels) :
MinibatchFetcher<ElemType>(trainSetDataReader, inputMatrices, sentenceBegin, sentenceExistsBeginOrNoLabels),
m_prefetchSentenceBegin(nullptr),
m_prefetchSentenceExistsBeginOrNoLabels(nullptr),
m_isEpochReadingDone(false),
m_minibatchReady(false),
m_isTerminating(false)
Expand All @@ -42,6 +50,20 @@ class MinibatchPrefetcher : public MinibatchFetcher<ElemType>
iter->second->GetFormat());
}

if (sentenceBegin != nullptr)
{
m_prefetchSentenceBegin = new Matrix<ElemType>(sentenceBegin->GetNumRows(),
sentenceBegin->GetNumCols(),
sentenceBegin->GetDeviceId(),
sentenceBegin->GetMatrixType(),
sentenceBegin->GetFormat());
}

if (sentenceExistsBeginOrNoLabels != nullptr)
{
m_prefetchSentenceExistsBeginOrNoLabels = new vector<MinibatchPackingFlag>();
}

// Launch a worker thread
m_prefetchThread = std::thread([this]() { this->PrefetchWorker(); });
}
Expand All @@ -66,6 +88,9 @@ class MinibatchPrefetcher : public MinibatchFetcher<ElemType>
{
delete iter->second;
}

delete m_prefetchSentenceBegin;
delete m_prefetchSentenceExistsBeginOrNoLabels;
}

virtual bool GetMinibatch()
Expand Down Expand Up @@ -97,6 +122,17 @@ class MinibatchPrefetcher : public MinibatchFetcher<ElemType>
std::swap(*(iter->second), *m_prefetchInput[iter->first]);
}

if (m_sentenceBegin != nullptr)
{
assert(m_sentenceBegin->GetDeviceId() == m_prefetchSentenceBegin->GetDeviceId());
std::swap(*m_sentenceBegin, *m_prefetchSentenceBegin);
}

if (m_sentenceExistsBeginOrNoLabels != nullptr)
{
std::swap(*m_sentenceExistsBeginOrNoLabels, *m_prefetchSentenceExistsBeginOrNoLabels);
}

hasMoreEpochReading = true;
}

Expand Down Expand Up @@ -160,14 +196,18 @@ class MinibatchPrefetcher : public MinibatchFetcher<ElemType>
Matrix<ElemType>::SyncComputeBeforeRead(m_deviceId);

// Get the next minibatch and wait for it to be available on the device
bool isDone = !this->m_reader->GetMinibatch(const_cast<std::map<std::wstring, Matrix<ElemType>*>&>(m_prefetchInput));
bool isDone = !this->m_reader->GetMinibatch(m_prefetchInput);
this->m_reader->SetSentenceSegBatch(*m_prefetchSentenceBegin, *m_prefetchSentenceExistsBeginOrNoLabels);

Matrix<ElemType>::SyncPendingRead(m_deviceId);

return isDone;
}

// @TODO: We need to add support for a larger number of prefetch buffers, larger than 1
std::map<std::wstring, Matrix<ElemType>*> m_prefetchInput;
Matrix<ElemType>* m_prefetchSentenceBegin;
vector<MinibatchPackingFlag>* m_prefetchSentenceExistsBeginOrNoLabels;
std::thread m_prefetchThread;
std::mutex m_mutex;
std::condition_variable m_cv;
Expand Down
5 changes: 2 additions & 3 deletions MachineLearning/CNTK/SGD.h
Original file line number Diff line number Diff line change
Expand Up @@ -1710,8 +1710,8 @@ class SGD : ComputationNetworkHelper<ElemType>
AttemptUtteranceDerivativeFeatures(net, trainSetDataReader, FeatureNodes, inputMatrices);
std::unique_ptr<MinibatchFetcher<ElemType>> mbFetcher(
m_doPrefetchTrainingData ?
new MinibatchPrefetcher<ElemType>(trainSetDataReader, inputMatrices) :
new MinibatchFetcher<ElemType>(trainSetDataReader, inputMatrices));
new MinibatchPrefetcher<ElemType>(trainSetDataReader, inputMatrices, &(net.SentenceBoundary()), &(net.MinibatchPackingFlags())) :
new MinibatchFetcher<ElemType>(trainSetDataReader, inputMatrices, &(net.SentenceBoundary()), &(net.MinibatchPackingFlags())));

fprintf(stderr, "\nStarting minibatch loop, prefetching is: %s\n", m_doPrefetchTrainingData ? "ENABLED" : "DISABLED");

Expand All @@ -1735,7 +1735,6 @@ class SGD : ComputationNetworkHelper<ElemType>

net.SetActualMiniBatchSize(actualMBSize);
net.SetActualNbrSlicesInEachRecIter(trainSetDataReader->NumberSlicesInEachRecurrentIter());
trainSetDataReader->SetSentenceSegBatch(net.SentenceBoundary(), net.MinibatchPackingFlags());

#ifndef EVALDLL
if (m_doGradientCheck && GradientCheck(net, criterionNodes, learnableNodes, 0) == false)
Expand Down

0 comments on commit 8830b5a

Please sign in to comment.