Skip to content

Commit

Permalink
InputNodes() now skips inputs that are only reachable through PreComputeNodes that have already been computed, addressing Issue microsoft#65;
Browse files Browse the repository at this point in the history

cleaned up some unnecessary NULL checks before delete
  • Loading branch information
frankseide committed Feb 5, 2016
1 parent 2bbee29 commit 2f9a48c
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 168 deletions.
6 changes: 0 additions & 6 deletions Source/CNTK/NetworkDescriptionLanguage.h
Original file line number Diff line number Diff line change
Expand Up @@ -470,9 +470,7 @@ class NDLScript : public ConfigParser
{
// need to free all the child nodes attached to this script node
for (NDLNode<ElemType>* node : m_children)
{
delete node;
}
m_children.clear();
}

Expand Down Expand Up @@ -576,14 +574,10 @@ class NDLScript : public ConfigParser
{

for (NDLNode<ElemType>* node : m_children)
{
delete node;
}
m_children.clear();
for (NDLNode<ElemType>* node : m_script)
{
delete node;
}
m_script.clear();

m_symbols.clear();
Expand Down
108 changes: 50 additions & 58 deletions Source/ComputationNetworkLib/ComputationNetwork.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,82 +383,74 @@ bool ComputationNetwork::IsTypicalCriterionNode(ComputationNodeBasePtr nodePtr)
return false;
}

template <class N>
void ComputationNetwork::GetNodesRequiringX(list<ComputationNodeBasePtr>& nodesRequiringX, const ComputationNodeBasePtr& rootNode, bool checkComputed)
// return list of nodes that require precomputation and not precomputed yet
list<ComputationNodeBasePtr> ComputationNetwork::GetNodesRequiringPreComputation(const ComputationNodeBasePtr& rootNode, bool checkComputed)
{
if (!rootNode) // find nodes from all available nodes
list<ComputationNodeBasePtr> nodes;
for (const auto& node : GetEvalOrder(rootNode))
{
for (const auto& nodep : m_nameToNodeMap)
auto pcnode = dynamic_pointer_cast<IPreComputeNode>(node);
if (pcnode)
{
auto node = dynamic_pointer_cast<N>(nodep.second);
if (node)
{
assert(node->RequiresPreCompute());
if (!checkComputed || !node->HasComputed())
nodesRequiringX.push_back(node);
}
assert(node->RequiresPreCompute());
if (!checkComputed || !pcnode->HasComputed())
nodes.push_back(node);
}
}
else // or for calculating a specific node
{
for (const auto& nodei : GetEvalOrder(rootNode))
{
auto node = dynamic_pointer_cast<N>(nodei);
if (node)
{
assert(node->RequiresPreCompute());
if (!checkComputed || !node->HasComputed())
nodesRequiringX.push_back(node);
}
}
}
nodesRequiringX.unique();
}

// return list of nodes that require precomputation and not precomputed yet
list<ComputationNodeBasePtr> ComputationNetwork::GetNodesRequiringPreComputation(const ComputationNodeBasePtr& rootNode, bool checkComputed)
{
list<ComputationNodeBasePtr> nodesRequiringX;
GetNodesRequiringX<PreComputedNodeBase<float>>(nodesRequiringX, rootNode, checkComputed);
GetNodesRequiringX<PreComputedNodeBase<double>>(nodesRequiringX, rootNode, checkComputed);
return nodesRequiringX;
return nodes;
}

// create the m_inputValues[] and m_learnableParameters[] lists
// This enumerates all leaves reachable from rootNode.
// Leaves are:
// - inputs
// - learnable parameters
// It does not traverse disabled ones, i.e.
// - inputs that are only reachable through PrecomputeNodes that have completed computation
// - learnable parameters that are constants
void ComputationNetwork::CollectInputAndLearnableParameters(const ComputationNodeBasePtr& rootNode)
{
assert(m_inputValues.find(rootNode) == m_inputValues.end()); // this function must only be called once
assert(m_learnableParameters.find(rootNode) == m_learnableParameters.end());

const list<ComputationNodeBasePtr>& nodes = GetEvalOrder(rootNode);
// gather the lists
set<ComputationNodeBasePtr> visited;
list<ComputationNodeBasePtr> inputs, learnableParameters;
if (rootNode)
CollectInputAndLearnableParametersRec(rootNode, visited, inputs, learnableParameters);
else
for (const auto& root : m_allRoots)
CollectInputAndLearnableParametersRec(root, visited, inputs, learnableParameters);

// collect input values for given root
// Note: This will not return nodes that are reached through a PrecomputeNode that has already been computed.
list<ComputationNodeBasePtr> inputs;
for (const auto& node : nodes)
// sort learnable parameters by name so that we get consistent order when load it from saved file
learnableParameters.sort([](const ComputationNodeBasePtr& a, const ComputationNodeBasePtr& b)
{
if (node->OperationName() == OperationNameOf(InputValue) || node->OperationName() == OperationNameOf(SparseInputValue))
inputs.push_back(node);
}
m_inputValues[rootNode] = inputs;
return a->NodeName() < b->NodeName();
});

m_inputValues[rootNode] = move(inputs);
m_learnableParameters[rootNode] = move(learnableParameters);
}

// instead of collecting the nodes themselves, collect the names (they will be sorted below)
list<wstring> learnableParameterNames;
for (auto nodeIter = nodes.begin(); nodeIter != nodes.end(); nodeIter++)
void ComputationNetwork::CollectInputAndLearnableParametersRec(const ComputationNodeBasePtr& node, set<ComputationNodeBasePtr>& visited, list<ComputationNodeBasePtr>& inputs, list<ComputationNodeBasePtr>& learnableParameters)
{
if (visited.find(node) != visited.end()) // allready got this one
return;
else if (node->OperationName() == OperationNameOf(InputValue) || node->OperationName() == OperationNameOf(SparseInputValue))
inputs.push_back(node);
else if (node->OperationName() == OperationNameOf(LearnableParameter) && node->IsParameterUpdateRequired())
learnableParameters.push_back(node);
else
{
ComputationNodeBasePtr node = *nodeIter;
if (node->OperationName() == OperationNameOf(LearnableParameter) && node->IsParameterUpdateRequired())
learnableParameterNames.push_back(node->NodeName());
// PreComputeNodes that are already done should not be traversed
auto pcnode = dynamic_pointer_cast<IPreComputeNode>(node);
if (pcnode && pcnode->HasComputed())
return;
// recurse
visited.insert(node);
for (const auto & input : node->GetInputs())
CollectInputAndLearnableParametersRec(input, visited, inputs, learnableParameters);
}

// sort names so that we get consistent order when load it from saved file
learnableParameterNames.sort();

// now collect the actual nodes in the sort order of their node names
list<ComputationNodeBasePtr> learnableParameters;
for (const auto& nodeNameIter : learnableParameterNames)
learnableParameters.push_back(GetNodeFromName(nodeNameIter));
m_learnableParameters[rootNode] = move(learnableParameters);
}

template <class ElemType>
Expand Down
6 changes: 2 additions & 4 deletions Source/ComputationNetworkLib/ComputationNetwork.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <regex>
#include <chrono>
#include <unordered_map>
#include <set>

namespace Microsoft { namespace MSR { namespace CNTK {

Expand Down Expand Up @@ -164,6 +165,7 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb
private:
void DetermineSetOfAllRoots();
void CollectInputAndLearnableParameters(const ComputationNodeBasePtr& rootNode);
void CollectInputAndLearnableParametersRec(const ComputationNodeBasePtr& node, set<ComputationNodeBasePtr>& visited, list<ComputationNodeBasePtr>& inputs, list<ComputationNodeBasePtr>& learnableParameters);
bool IsCompiled() const { return m_isCompiled; }
void VerifyIsCompiled(const char* where) const;
public:
Expand Down Expand Up @@ -506,10 +508,6 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb
return nodesWithType;
}

private:
template <class N>
void GetNodesRequiringX(std::list<ComputationNodeBasePtr>& nodesRequirePreComputation, const ComputationNodeBasePtr& rootNode, bool checkComputed);

public:
// return list of nodes that require precomputation and not precomputed yet
std::list<ComputationNodeBasePtr> GetNodesRequiringPreComputation(const ComputationNodeBasePtr& rootNode = nullptr, bool checkComputed = true);
Expand Down
16 changes: 15 additions & 1 deletion Source/ComputationNetworkLib/ComputationNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -1673,11 +1673,25 @@ class LateAttachingNode : public N, public ILateAttachingNode
};

// =======================================================================
// IRecurrentNode -- helper wrapper class for ComputationNodes that can be recurrent
// IRecurrentNode -- interface implemented by ComputationNodes that can be recurrent
// =======================================================================

// Implementors report their recurrence stepping direction as a signed int.
// NOTE(review): presumably a positive value means forward iteration over time and a
// negative value means backward — confirm against call sites, which are not visible here.
struct IRecurrentNode { virtual int GetRecurrenceSteppingDirection() const = 0; };

// =======================================================================
// IPreComputeNode -- interface implemented by ComputationNodes that precompute,
// i.e. that run a computation once up front and are then marked as done
// (implemented by PreComputedNodeBase).
// TODO: We can use this interface in more places.
// =======================================================================

struct IPreComputeNode
{
// check whether node has already undergone precomputation
virtual bool HasComputed() const = 0;
// call this with 'false' at start and with 'true' at end
// This is used for resetting and updating from accumulators.
virtual void MarkComputed(const bool hasComputed) = 0;
};

// =======================================================================
// helper macro to ease access to base members in presence of C++ two-phase name lookup
// =======================================================================
Expand Down
6 changes: 3 additions & 3 deletions Source/ComputationNetworkLib/PreComputeNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// -----------------------------------------------------------------------

template <class ElemType>
class PreComputedNodeBase : public ComputationNodeNonLooping /*ComputationNode*/<ElemType>
class PreComputedNodeBase : public ComputationNodeNonLooping /*ComputationNode*/<ElemType>, public IPreComputeNode
{
typedef ComputationNodeNonLooping<ElemType> Base;
UsingComputationNodeMembers;
Expand All @@ -40,14 +40,14 @@ class PreComputedNodeBase : public ComputationNodeNonLooping /*ComputationNode*/
// interface through which this node is operated on are these two functions

// check whether node has already undergone precomputation
virtual bool HasComputed() const
virtual bool /*IPreComputeNode::*/ HasComputed() const override
{
return m_hasComputed;
}

// call this with 'false' at start and with 'true' at end
// This is used for resetting and updating from accumulators.
virtual void MarkComputed(const bool hasComputed)
virtual void /*IPreComputeNode::*/ MarkComputed(const bool hasComputed) override
{
m_hasComputed = hasComputed;
CreateMatrixIfNull(m_value);
Expand Down
84 changes: 13 additions & 71 deletions Source/Readers/Kaldi2Reader/HTKMLFReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -673,87 +673,29 @@ void HTKMLFReader<ElemType>::PrepareForWriting(const ConfigRecordType& readerCon
template <class ElemType>
HTKMLFReader<ElemType>::~HTKMLFReader()
{
if (m_mbiter != NULL)
{
delete m_mbiter;
m_mbiter = NULL;
}
if (m_frameSource != NULL)
{
delete m_frameSource;
m_frameSource = NULL;
}
if (m_lattices != NULL)
{
delete m_lattices;
m_lattices = NULL;
}
if (m_seqTrainDeriv != NULL)
{
delete m_seqTrainDeriv;
m_seqTrainDeriv = NULL;
}
if (m_uttDerivBuffer != NULL)
{
delete m_uttDerivBuffer;
m_uttDerivBuffer = NULL;
}
delete m_mbiter;
delete m_frameSource;
delete m_lattices;
delete m_seqTrainDeriv;
delete m_uttDerivBuffer;

if (!m_featuresBufferMultiIO.empty())
{
foreach_index (i, m_featuresBufferMultiIO)
{
if (m_featuresBufferMultiIO[i] != NULL)
{
delete[] m_featuresBufferMultiIO[i];
m_featuresBufferMultiIO[i] = NULL;
}
}
}
foreach_index(i, m_featuresBufferMultiIO)
delete[] m_featuresBufferMultiIO[i];

if (!m_labelsBufferMultiIO.empty())
{
foreach_index (i, m_labelsBufferMultiIO)
{
if (m_labelsBufferMultiIO[i] != NULL)
{
delete[] m_labelsBufferMultiIO[i];
m_labelsBufferMultiIO[i] = NULL;
}
}
}
foreach_index(i, m_labelsBufferMultiIO)
delete[] m_labelsBufferMultiIO[i];

for (size_t i = 0; i < m_numberOfuttsPerMinibatch; i++)
{
if (m_featuresBufferMultiUtt[i] != NULL)
{
delete[] m_featuresBufferMultiUtt[i];
m_featuresBufferMultiUtt[i] = NULL;
}
if (m_labelsBufferMultiUtt[i] != NULL)
{
delete[] m_labelsBufferMultiUtt[i];
m_labelsBufferMultiUtt[i] = NULL;
}
delete[] m_featuresBufferMultiUtt[i];
delete[] m_labelsBufferMultiUtt[i];
}

foreach_index (i, m_trainingOrTestingFeatureSections)
{
if (m_trainingOrTestingFeatureSections[i] != NULL)
{
delete m_trainingOrTestingFeatureSections[i];
m_trainingOrTestingFeatureSections[i] = NULL;
}
}
delete m_trainingOrTestingFeatureSections[i];

foreach_index (i, m_writingFeatureSections)
{
if (m_writingFeatureSections[i] != NULL)
{
delete m_writingFeatureSections[i];
m_writingFeatureSections[i] = NULL;
}
}
delete m_writingFeatureSections[i];
}

// StartMinibatchLoop - Startup a minibatch loop
Expand Down
12 changes: 2 additions & 10 deletions Source/Readers/Kaldi2Reader/KaldiSequenceTrainingDerivative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,8 @@ KaldiSequenceTrainingDerivative<ElemType>::KaldiSequenceTrainingDerivative(
template <class ElemType>
KaldiSequenceTrainingDerivative<ElemType>::~KaldiSequenceTrainingDerivative()
{
if (m_denlatReader != NULL)
{
delete m_denlatReader;
m_denlatReader = NULL;
}
if (m_aliReader != NULL)
{
delete m_aliReader;
m_aliReader = NULL;
}
delete m_denlatReader;
delete m_aliReader;
}

template <class ElemType>
Expand Down
Loading

0 comments on commit 2f9a48c

Please sign in to comment.