forked from microsoft/CNTK
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ComputationNetwork now has a second layout, pMBNoLayout, which matche…
…s pMBLayout in #sequences but is otherwise empty, and used for nodes that do not require sequential processing; m_samplesInRecurrentStep now gone from ComputationNode, if needed, the value is determined from pMBLayout--yay! One more down; Matrix::SetValue() now happily accepts empty matrices (no reason why it should not); (made gc happy again)
- Loading branch information
1 parent
7e10b57
commit 93d93e0
Showing
5 changed files
with
38 additions
and
31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -75,7 +75,7 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb | |
// ----------------------------------------------------------------------- | ||
|
||
ComputationNetwork(DEVICEID_TYPE deviceId = AUTOPLACEMATRIX) : | ||
m_deviceId(deviceId), m_pMBLayout(make_shared<MBLayout>()) | ||
m_deviceId(deviceId), m_pMBLayout(make_shared<MBLayout>()), m_pMBNoLayout(make_shared<MBLayout>()) | ||
{ | ||
m_randomSeedOffset = 0; | ||
m_actMiniBSize = 0; | ||
|
@@ -595,12 +595,10 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb | |
// TODO: rename to ForwardProp()? To make it very clear? | ||
void Evaluate(const ComputationNodeBasePtr rootNode) | ||
{ | ||
// checks that will disappear once we complete the refactoring. If this passes for a while, we will eliminate one | ||
// If this fails, comment this out (it is safe) and tell [email protected]. | ||
if (GetNumParallelSequences() != m_pMBLayout->GetNumParallelSequences()) | ||
LogicError("Evaluate: detected that m_nbrSlicesInEachRecurrentIteration != m_pMBLayout->GetNumParallelSequences()"); | ||
if (m_pMBLayout->GetNumTimeSteps() != m_pMBLayout->GetSize()) | ||
LogicError("Evaluate: detected that m_pMBLayout->GetNumTimeSteps() != m_pMBLayout->GetSize()"); | ||
// We have a matching layout structure that matches pMBLayout in number of sequences while not having any flags set. | ||
// This is used for nodes that do not need recurrent processing, but can be done in batch. | ||
// TODO: Does it harm if we have flags, for those that can be done in batch? I.e. why don't we just always provide flags? | ||
m_pMBNoLayout->Resize(m_pMBLayout->GetNumParallelSequences(), 0); // TODO: this is not nice, but we currently have no trigger to detect changes in layout | ||
|
||
// prepare to compute with the subnetwork that this rootNode depends on, including | ||
// - auto-detecting recurrent loops | ||
|
@@ -623,10 +621,11 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb | |
// TODO: in the future, these will be different on different nodes | ||
for (auto nodeIter = allNodes.begin(); nodeIter != allNodes.end(); nodeIter++) | ||
{ | ||
// TODO: nbrSlices set once to the same value for all nodes each evaluation--is it ever changed later? | ||
(*nodeIter)->SetNumParallelSequences(GetNumParallelSequences()); | ||
if ((*nodeIter)->ReqMultiSeqHandling()) | ||
(*nodeIter)->ResetBound(m_pMBLayout); | ||
else | ||
(*nodeIter)->ResetBound(m_pMBNoLayout); | ||
(*nodeIter)->VerifyNumParallelSequences(GetNumParallelSequences()); | ||
} | ||
|
||
for (auto nodeIter = allNodes.begin(); nodeIter != allNodes.end(); nodeIter++) | ||
|
@@ -723,7 +722,7 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb | |
{ | ||
for (auto nodeIter = recurrentNodes.rbegin(); nodeIter != recurrentNodes.rend(); ++nodeIter) | ||
{ | ||
(*nodeIter)->SetNumParallelSequences(GetNumParallelSequences()); // TODO: move to FrameRange object | ||
(*nodeIter)->VerifyNumParallelSequences(GetNumParallelSequences()); // TODO: move to FrameRange object | ||
(*nodeIter)->ComputeGradientForChildren(timeIndex); | ||
} | ||
} | ||
|
@@ -734,7 +733,7 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb | |
{ | ||
for (auto nodeIter = recurrentNodes.rbegin(); nodeIter != recurrentNodes.rend(); ++nodeIter) | ||
{ | ||
(*nodeIter)->SetNumParallelSequences(GetNumParallelSequences()); | ||
(*nodeIter)->VerifyNumParallelSequences(GetNumParallelSequences()); | ||
(*nodeIter)->ComputeGradientForChildren(timeIndex); | ||
} | ||
} | ||
|
@@ -1580,6 +1579,7 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb | |
// used for sentence boundary information passed from reader to reset RNN state | ||
// specify how the minibatch is packed for each sample | ||
MBLayoutPtr m_pMBLayout; | ||
MBLayoutPtr m_pMBNoLayout; // this one is a dummy, passed when no layout is available/should be used | ||
|
||
int m_actMiniBSize; | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters