Skip to content

Commit

Permalink
ComputationNetwork now has a second layout, pMBNoLayout, which matche…
Browse files Browse the repository at this point in the history
…s pMBLayout in #sequences but is otherwise empty, and used for nodes that do not require sequential processing;

m_samplesInRecurrentStep now gone from ComputationNode, if needed, the value is determined from pMBLayout--yay! One more down;
Matrix::SetValue() now happily accepts empty matrices (no reason why it should not);
(made gc happy again)
  • Loading branch information
frankseide committed Sep 19, 2015
1 parent 7e10b57 commit 93d93e0
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 31 deletions.
22 changes: 11 additions & 11 deletions MachineLearning/CNTKComputationNetworkLib/ComputationNetwork.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb
// -----------------------------------------------------------------------

ComputationNetwork(DEVICEID_TYPE deviceId = AUTOPLACEMATRIX) :
m_deviceId(deviceId), m_pMBLayout(make_shared<MBLayout>())
m_deviceId(deviceId), m_pMBLayout(make_shared<MBLayout>()), m_pMBNoLayout(make_shared<MBLayout>())
{
m_randomSeedOffset = 0;
m_actMiniBSize = 0;
Expand Down Expand Up @@ -595,12 +595,10 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb
// TODO: rename to ForwardProp()? To make it very clear?
void Evaluate(const ComputationNodeBasePtr rootNode)
{
// checks that will disappear once we complete the refactoring. If this passes for a while, we will eliminate one
// If this fails, comment this out (it is safe) and tell [email protected].
if (GetNumParallelSequences() != m_pMBLayout->GetNumParallelSequences())
LogicError("Evaluate: detected that m_nbrSlicesInEachRecurrentIteration != m_pMBLayout->GetNumParallelSequences()");
if (m_pMBLayout->GetNumTimeSteps() != m_pMBLayout->GetSize())
LogicError("Evaluate: detected that m_pMBLayout->GetNumTimeSteps() != m_pMBLayout->GetSize()");
// We have a matching layout structure that matches pMBLayout in number of sequences while not having any flags set.
// This is used for nodes that do not need recurrent processing, but can be done in batch.
// TODO: Does it harm if we have flags, for those that can be done in batch? I.e. why don't we just always provide flags?
m_pMBNoLayout->Resize(m_pMBLayout->GetNumParallelSequences(), 0); // TODO: this is not nice, but we currently have no trigger to detect changes in layout

// prepare to compute with the subnetwork that this rootNode depends on, including
// - auto-detecting recurrent loops
Expand All @@ -623,10 +621,11 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb
// TODO: in the future, these will be different on different nodes
for (auto nodeIter = allNodes.begin(); nodeIter != allNodes.end(); nodeIter++)
{
// TODO: nbrSlices set once to the same value for all nodes each evaluation--is it ever changed later?
(*nodeIter)->SetNumParallelSequences(GetNumParallelSequences());
if ((*nodeIter)->ReqMultiSeqHandling())
(*nodeIter)->ResetBound(m_pMBLayout);
else
(*nodeIter)->ResetBound(m_pMBNoLayout);
(*nodeIter)->VerifyNumParallelSequences(GetNumParallelSequences());
}

for (auto nodeIter = allNodes.begin(); nodeIter != allNodes.end(); nodeIter++)
Expand Down Expand Up @@ -723,7 +722,7 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb
{
for (auto nodeIter = recurrentNodes.rbegin(); nodeIter != recurrentNodes.rend(); ++nodeIter)
{
(*nodeIter)->SetNumParallelSequences(GetNumParallelSequences()); // TODO: move to FrameRange object
(*nodeIter)->VerifyNumParallelSequences(GetNumParallelSequences()); // TODO: move to FrameRange object
(*nodeIter)->ComputeGradientForChildren(timeIndex);
}
}
Expand All @@ -734,7 +733,7 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb
{
for (auto nodeIter = recurrentNodes.rbegin(); nodeIter != recurrentNodes.rend(); ++nodeIter)
{
(*nodeIter)->SetNumParallelSequences(GetNumParallelSequences());
(*nodeIter)->VerifyNumParallelSequences(GetNumParallelSequences());
(*nodeIter)->ComputeGradientForChildren(timeIndex);
}
}
Expand Down Expand Up @@ -1580,6 +1579,7 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb
// used for sentence boundary information passed from reader to reset RNN state
// specify how the minibatch is packed for each sample
MBLayoutPtr m_pMBLayout;
MBLayoutPtr m_pMBNoLayout; // this one is a dummy, passed when no layout is available/should be used

int m_actMiniBSize;

Expand Down
23 changes: 10 additions & 13 deletions MachineLearning/CNTKComputationNetworkLib/ComputationNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,23 +246,20 @@ namespace Microsoft { namespace MSR { namespace CNTK {
return m_loopId;
}

// TODO: these two will disappear once the information is correctly held in a FrameRange record
// This is called at 3 places; two are directly before ComputeGradientForChildren().
void SetNumParallelSequences(size_t bsz)
// temporary function that is called to verify stuff is called as I think it is. Delete if this does not fire for a while.
void VerifyNumParallelSequences(size_t bsz)
{
m_samplesInRecurrentStep = bsz;
//m_samplesInRecurrentStep = bsz;
if (bsz != m_pMBLayout->GetNumParallelSequences())
LogicError("VerifyNumParallelSequences: value inconsistent with MB layout");
}

// Note: only used in one place, SimpleEvaluator.h PreComputeActivityAtTime().
// The member is, however, read out at 284 places inside nodes,
// most of the time as
// This is used at 284 places inside nodes, most of the time as
// FrameSlice(frameRange/*TODO: delete the next two parameters*/, frameRange.t() * GetNumParallelSequences(), GetNumParallelSequences())
// This expression will be turned into a function call to right here, so that we compute this only at one place
// and can also handle the full-minibatch case.
// Let us try to get this member out of this class altogether; it belongs elsewhere.
size_t GetNumParallelSequences() const
{
return m_samplesInRecurrentStep;
//return m_samplesInRecurrentStep;
return m_pMBLayout->GetNumParallelSequences();
}

int64_t UpdateEvalTimeStamp()
Expand Down Expand Up @@ -682,7 +679,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
/// the order in reverse graph.
int m_visitedOrder;
int m_index;
int m_lowlink;
int m_lowlink; // TODO: comment this, as it is not obvious
bool m_visited;
bool m_inStack;
int m_indexInLoop;
Expand Down Expand Up @@ -1281,7 +1278,7 @@ protected: \
using Base::m_visitedOrder; using Base::m_index; using Base::m_lowlink; using Base::m_visited; using Base::m_inStack; \
using Base::m_indexInLoop; \
using Base::m_pMBLayout; \
using Base::m_reqMultiSeqHandling; using Base::UseCustomizedMultiSeqHandling; \
using Base::m_reqMultiSeqHandling; using Base::UseCustomizedMultiSeqHandling; using Base::GetNumParallelSequences; \
using Base::m_children; using Base::m_deviceId; using Base::m_evalTimeStamp; using Base::m_functionValues; using Base::m_gradientValues; \
using Base::m_inputChannels; using Base::m_inputHeight; using Base::m_inputWidth; using Base::m_needGradient; using Base::m_nodeName; \
using Base::m_outputChannels; using Base::m_outputHeight; using Base::m_outputWidth; using Base::s_constOnes; using Base::s_timeStampCounter; \
Expand Down
10 changes: 6 additions & 4 deletions Math/Math/Matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1005,8 +1005,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
template<class ElemType>
void Matrix<ElemType>::SetValue(const ElemType v)
{
if (IsEmpty())
throw std::logic_error("SetValue: Matrix is empty.");
if (IsEmpty()) // if empty then we are done
return;
//throw std::logic_error("SetValue: Matrix is empty.");

DISPATCH_MATRIX_ON_FLAG(this,
this,
Expand All @@ -1020,8 +1021,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
template<class ElemType>
void Matrix<ElemType>::SetValue(const DeviceBoundNumber<ElemType>& db_number)
{
if (IsEmpty())
throw std::logic_error("SetValue: Matrix is empty.");
if (IsEmpty()) // if empty then we are done
return;
//throw std::logic_error("SetValue: Matrix is empty.");

DISPATCH_MATRIX_ON_FLAG(this,
this,
Expand Down
7 changes: 5 additions & 2 deletions Math/Math/Matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -602,10 +602,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {

// these accessors were for now just collected from actual usage; need to be cleaned up once this compiles again
size_t GetNumTimeSteps() const { validate(); return m_sentenceBoundaryFlags.GetNumCols(); }
size_t GetNumParallelSequences() const { return IsAllNone() ? 1 : m_sentenceBoundaryFlags.GetNumRows(); } // 1 stream if no matrix
size_t GetNumParallelSequences() const { return (m_sentenceBoundaryFlags.GetNumRows() == 0) ? 1 : m_sentenceBoundaryFlags.GetNumRows(); } // 1 stream if no matrix
size_t GetSize() const { validate(); return m_minibatchPackingFlags.size(); }
// ^^ TODO: add a check whether Size() == GetNumTimeSteps(); it really should, unless I misunderstood

// if we have no matrix/vector, this means no frame has any flag set
// We still can have a number of rows in this case.
bool IsAllNone() const { validate(); return m_minibatchPackingFlags.empty(); }

#if 0 // we have this pattern often:
// TODO: mbSize and #slices must also move into MBLayout
evalnet->SetActualMiniBatchSize(mbSize);
Expand Down
7 changes: 6 additions & 1 deletion Tests/Speech/README.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,13 @@ bin/cntk configFile=Tests/Speech/QuickE2E/cntk.config RunDir=Tests/Speech/RunDir
WORKING DIR: $(SolutionDir)Tests\Speech\Data
COMMAND: configFile=$(SolutionDir)Tests\Speech\LSTM\cntk.config stderr=$(SolutionDir)Tests\Speech\RunDir\LSTM\models\cntkSpeech.dnn.log RunDir=$(SolutionDir)Tests\Speech\RunDir\LSTM NdlDir=$(SolutionDir)Tests\Speech\LSTM DataDir=$(SolutionDir)Tests\Speech\Data DeviceId=Auto

--- MNIST:

WORKING DIR: $(SolutionDir)ExampleSetups\Image\MNIST
COMMAND: configFile=02_Conv.config configName=02_Conv


Simple test
-----------

../build/debug/bin/cntk configFile=/home/cbasoglu/src/cntk/.run-linux/Simple.conf
COMMAND: configFile=$(SolutionDir)Demos\Simple\Simple.config stderr=$(SolutionDir)Demos\Simple\RunDir\Simple.config.log RootDir=$(SolutionDir) DeviceNumber=-1

0 comments on commit 93d93e0

Please sign in to comment.