Skip to content

Commit

Permalink
Integrate AddPrintMetadataFlag into master
Browse files Browse the repository at this point in the history
  • Loading branch information
Project Philly committed Feb 14, 2016
2 parents e1093b9 + 68620bb commit f89e544
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 61 deletions.
7 changes: 6 additions & 1 deletion Source/CNTK/CNTK.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,15 @@ void DumpNodeInfo(const ConfigParameters& config)
wstring defOutFilePath = modelPath + L"." + nodeName + L".txt";
wstring outputFile = config(L"outputFile", defOutFilePath);
bool printValues = config(L"printValues", true);
bool printMetadata = config(L"printMetadata", true);
if (!printValues && !printMetadata)
{
InvalidArgument("printValues and printMetadata: Since both are set to false, there will be nothing to dump");
}

ComputationNetwork net(-1); // always use CPU
net.Load<ElemType>(modelPath);
net.DumpNodeInfoToFile(nodeName, printValues, outputFile, nodeNameRegexStr);
net.DumpNodeInfoToFile(nodeName, printValues, printMetadata, outputFile, nodeNameRegexStr);
}

size_t GetMaxEpochs(const ConfigParameters& configParams)
Expand Down
4 changes: 2 additions & 2 deletions Source/CNTK/ModelEditLanguage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ void MELScript<ElemType>::CallFunction(const std::string& p_name, const ConfigPa
{
NetNdl<ElemType>* netNdl = &found->second;
ProcessNDLScript(netNdl, ndlPassAll, true);
found->second.cn->DumpAllNodesToFile(includeData, fileName);
found->second.cn->DumpAllNodesToFile(includeData, true, fileName);
}
}
else if (EqualInsensitive(name, "DumpNode"))
Expand All @@ -281,7 +281,7 @@ void MELScript<ElemType>::CallFunction(const std::string& p_name, const ConfigPa
NetNdl<ElemType>* netNdl;
vector<ComputationNodeBasePtr> nodes = FindSymbols(params[0], netNdl);
ProcessNDLScript(netNdl, ndlPassAll);
netNdl->cn->DumpNodeInfoToFile(nodes, includeData, fileName);
netNdl->cn->DumpNodeInfoToFile(nodes, includeData, true, fileName);
}
else if (EqualInsensitive(name, "CopyNode", "Copy"))
{
Expand Down
2 changes: 1 addition & 1 deletion Source/CNTK/NDLUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class NDLUtil
// if requested then dump the nodes
// Note: This happens on the invalidated network.
if (dumpFileName != L"")
m_net->DumpAllNodesToFile(false, dumpFileName);
m_net->DumpAllNodesToFile(false, true, dumpFileName);
}
SynchronousNodeEvaluator<ElemType> ndlEvaluator(m_net);
NDLNode<ElemType>* lastNode = script->Evaluate(ndlEvaluator, L"", ndlPass, skipThrough);
Expand Down
14 changes: 8 additions & 6 deletions Source/ComputationNetworkLib/ComputationNetwork.h
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb
// if node name is not found, dump all nodes
// otherwise dump just that node
// This function is called from MEL, i.e. must be prepared to operate on an uncompiled network (only m_nameToNodeMap is valid).
void DumpNodeInfoToFile(const std::wstring& nodeName, const bool printValues, const std::wstring outputFile, const std::wstring& nodeNameInRegEx = L"")
void DumpNodeInfoToFile(const std::wstring& nodeName, const bool printValues, const bool printMetadata, const std::wstring outputFile, const std::wstring& nodeNameInRegEx = L"")
{
if (nodeNameInRegEx.empty())
{
Expand All @@ -619,13 +619,13 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb
FileOptions::fileOptionsText | FileOptions::fileOptionsWrite);

const ComputationNodeBasePtr& nodePtr = GetNodeFromName(nodeName);
nodePtr->DumpNodeInfo(printValues, fstream);
nodePtr->DumpNodeInfo(printValues, printMetadata, fstream);
}
else // node name is not found, dump all nodes
{
fprintf(stderr, "Warning: node name %ls does not exist in the network. dumping all nodes.\n",
nodeName.c_str());
DumpAllNodesToFile(printValues, outputFile);
DumpAllNodesToFile(printValues, printMetadata, outputFile);
}
}
else
Expand All @@ -647,12 +647,13 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb
fprintf(stderr, "\t%ls\n", x.c_str());
}
fprintf(stderr, "DumpNodeInfo: dumping node info (%s printing values) to %ls\n", printValues ? "with" : "without", outputFile.c_str());
DumpNodeInfoToFile(NodeList, printValues, outputFile);
DumpNodeInfoToFile(NodeList, printValues, printMetadata, outputFile);
}
}

// dump all nodes in the network to file
void DumpAllNodesToFile(const bool printValues,
const bool printMetadata,
const std::wstring outputFile)
{
File fstream(outputFile,
Expand All @@ -661,12 +662,13 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb
for (auto nodeIter = m_nameToNodeMap.begin(); nodeIter != m_nameToNodeMap.end(); nodeIter++)
{
ComputationNodeBasePtr nodePtr = nodeIter->second;
nodePtr->DumpNodeInfo(printValues, fstream);
nodePtr->DumpNodeInfo(printValues, printMetadata, fstream);
}
}

void DumpNodeInfoToFile(const vector<ComputationNodeBasePtr>& nodes,
const bool printValues,
const bool printMetadata,
const std::wstring outputFile)
{
File fstream(outputFile,
Expand All @@ -675,7 +677,7 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb
for (auto nodeIter = nodes.begin(); nodeIter != nodes.end(); nodeIter++)
{
ComputationNodeBasePtr nodePtr = *nodeIter;
nodePtr->DumpNodeInfo(printValues, fstream);
nodePtr->DumpNodeInfo(printValues, printMetadata, fstream);
}
}

Expand Down
23 changes: 13 additions & 10 deletions Source/ComputationNetworkLib/ComputationNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,20 +235,23 @@ TensorShape ComputationNodeBase::GetTensorSliceFor(size_t rank, const FrameRange
// -----------------------------------------------------------------------

template <class ElemType>
/*virtual*/ void ComputationNode<ElemType>::DumpNodeInfo(const bool /*printValues*/, File& fstream) const
/*virtual*/ void ComputationNode<ElemType>::DumpNodeInfo(const bool /*printValues*/, const bool printMetadata, File& fstream) const
{
fstream << L"\n" + NodeName() + L"=" + OperationName();

if (!IsLeaf())
if (printMetadata)
{
fstream << wstring(L"(");
for (size_t i = 0; i < GetNumInputs(); i++)
fstream << L"\n" + NodeName() + L"=" + OperationName();

if (!IsLeaf())
{
if (i > 0)
fstream << wstring(L",");
fstream << (Input(i) ? Input(i)->NodeName() : L"NULL");
fstream << wstring(L"(");
for (size_t i = 0; i < GetNumInputs(); i++)
{
if (i > 0)
fstream << wstring(L",");
fstream << (Input(i) ? Input(i)->NodeName() : L"NULL");
}
fstream << wstring(L")");
}
fstream << wstring(L")");
}
}

Expand Down
20 changes: 13 additions & 7 deletions Source/ComputationNetworkLib/ComputationNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ struct /*interface*/ IComputationNode
// --- optional overrides for more informative logging

virtual void PrintSelfBeforeValidation() const = 0; // called in validation loop right before Validate()
virtual void DumpNodeInfo(const bool /*printValues*/, File& fstream) const = 0;
virtual void DumpNodeInfo(const bool /*printValues*/, const bool /*printMetadata*/, File& fstream) const = 0;

protected:
virtual ~IComputationNode()
Expand Down Expand Up @@ -1439,16 +1439,19 @@ class ComputationNode : public ComputationNodeBase // abstract class that cannot
// miscellaneous
// -----------------------------------------------------------------------

virtual void DumpNodeInfo(const bool /*printValues*/, File& fstream) const;
virtual void DumpNodeInfo(const bool /*printValues*/, const bool /*printMetadata*/, File& fstream) const;

protected:

// print node values
void PrintNodeValuesToFile(const bool printValues, File& fstream) const
void PrintNodeValuesToFile(const bool printValues, const bool printMetadata, File& fstream) const
{
if (printValues)
{
fstream << wstring(L"\n");
{
if (printMetadata)
{
fstream << wstring(L"\n");
}
const Matrix<ElemType>& m = Value();
for (size_t i = 0; i < m.GetNumRows(); i++)
{
Expand All @@ -1458,7 +1461,10 @@ class ComputationNode : public ComputationNodeBase // abstract class that cannot
}
fstream << wstring(L"\n");
}
fstream << wstring(L"####################################################################");
if (printMetadata)
{
fstream << wstring(L"####################################################################");
}
}
}

Expand Down Expand Up @@ -1633,7 +1639,7 @@ class FlowControlNode : public ComputationNodeBase
// these are meant to be called during computation, so provide dummy implementations
virtual bool RequiresPreCompute() const override { return false; } // return true if the node's value should be computed before the normal training. e.g., mean and invStd of input features.
virtual void PrintSelfBeforeValidation() const override { }
virtual void DumpNodeInfo(const bool /*printValues*/, File& fstream) const override { }
virtual void DumpNodeInfo(const bool /*printValues*/, const bool /*printMetadata*/, File& fstream) const override {}

protected: public: // needed in ComputationNetwork::FindInRecurrentLoops(), which really should be part of SEQTraversalFlowControlNode
std::vector<ComputationNodeBasePtr> m_nestedNodes; // nodes tucked away in this node, in evaluation order
Expand Down
33 changes: 18 additions & 15 deletions Source/ComputationNetworkLib/ConvolutionalNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,9 @@ class ConvolutionNode : public ComputationNode<ElemType>, public NumInputs<2>
}
}

void DumpNodeInfo(const bool printValues, File& fstream) const override
void DumpNodeInfo(const bool printValues, const bool printMetadata, File& fstream) const override
{
Base::DumpNodeInfo(printValues, fstream);
Base::DumpNodeInfo(printValues, printMetadata, fstream);

auto inDims = ImageDimensions(GetInputSampleLayout(1), m_imageLayoutKind);
auto outDims = ImageDimensions(m_sampleLayout, m_imageLayoutKind);
Expand Down Expand Up @@ -457,21 +457,24 @@ class PoolingNodeBase : public ComputationNode<ElemType>, public NumInputs<1>
}
}

void DumpNodeInfo(const bool printValues, File& fstream) const override
void DumpNodeInfo(const bool printValues, const bool printMetadata, File& fstream) const override
{
Base::DumpNodeInfo(printValues, fstream);
Base::DumpNodeInfo(printValues, printMetadata, fstream);

auto inputSampleLayout = GetInputSampleLayout(0);

char str[4096];
sprintf(str, "Input[Width:%lu, Height:%lu, Channels:%lu] \n", inputSampleLayout[1], inputSampleLayout[2], inputSampleLayout[0]);
fstream << string(str);
sprintf(str, "PoolingWindow[Width:%lu, Height:%lu] SubSampling[Horizontal:%lu, Vertical:%lu]\n", m_windowWidth, m_windowHeight, m_horizontalSubsample, m_verticalSubsample);
fstream << string(str);
sprintf(str, "Output[Width:%lu, Height:%lu, Channels:%lu] \n", m_sampleLayout[1], m_sampleLayout[2], m_sampleLayout[0]);
fstream << string(str);
sprintf(str, "TotalSizePerSample[Input:%lu, Output:%lu] \n", m_inputSizePerSample, m_outputSizePerSample);
fstream << string(str);
if (printMetadata)
{
auto inputSampleLayout = GetInputSampleLayout(0);

char str[4096];
sprintf(str, "Input[Width:%lu, Height:%lu, Channels:%lu] \n", inputSampleLayout[1], inputSampleLayout[2], inputSampleLayout[0]);
fstream << string(str);
sprintf(str, "PoolingWindow[Width:%lu, Height:%lu] SubSampling[Horizontal:%lu, Vertical:%lu]\n", m_windowWidth, m_windowHeight, m_horizontalSubsample, m_verticalSubsample);
fstream << string(str);
sprintf(str, "Output[Width:%lu, Height:%lu, Channels:%lu] \n", m_sampleLayout[1], m_sampleLayout[2], m_sampleLayout[0]);
fstream << string(str);
sprintf(str, "TotalSizePerSample[Input:%lu, Output:%lu] \n", m_inputSizePerSample, m_outputSizePerSample);
fstream << string(str);
}
}

protected:
Expand Down
28 changes: 17 additions & 11 deletions Source/ComputationNetworkLib/InputAndParamNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,17 +196,20 @@ class LearnableParameter : public ComputationNode<ElemType>, public NumInputs<0>
m_pMBLayout = nullptr; // this node does not hold mini-batch data
}

virtual void DumpNodeInfo(const bool printValues, File& fstream) const override
virtual void DumpNodeInfo(const bool printValues, const bool printMetadata, File& fstream) const override
{
Base::DumpNodeInfo(printValues, fstream);
if (printMetadata)
{
Base::DumpNodeInfo(printValues, printMetadata, fstream);

char str[4096];
sprintf(str, "[%lu,%lu] ", GetAsMatrixNumRows(), GetAsMatrixNumCols());
fstream << string(str);
sprintf(str, "NeedGradient=%s", m_parameterUpdateRequired ? "true" : "false"); // TODO: update NDL to accept a better matching name as well
fstream << string(str);
char str[4096];
sprintf(str, "[%lu,%lu] ", GetAsMatrixNumRows(), GetAsMatrixNumCols());
fstream << string(str);
sprintf(str, "NeedGradient=%s", m_parameterUpdateRequired ? "true" : "false"); // TODO: update NDL to accept a better matching name as well
fstream << string(str);
}

PrintNodeValuesToFile(printValues, fstream);
PrintNodeValuesToFile(printValues, printMetadata, fstream);
}
};

Expand Down Expand Up @@ -306,10 +309,13 @@ class InputValueBase : public ComputationNode<ElemType>, public NumInputs<0>
LogicError("InputValueBase::BackpropTo() should never be called.");
}

virtual void DumpNodeInfo(const bool printValues, File& fstream) const override
virtual void DumpNodeInfo(const bool printValues, const bool printMetadata, File& fstream) const override
{
Base::DumpNodeInfo(printValues, fstream);
fstream << "[" << string(GetSampleLayout()) << "]";
Base::DumpNodeInfo(printValues, printMetadata, fstream);
if (printMetadata)
{
fstream << "[" << string(GetSampleLayout()) << "]";
}
}

private:
Expand Down
19 changes: 11 additions & 8 deletions Source/ComputationNetworkLib/PreComputeNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,20 @@ class PreComputedNodeBase : public ComputationNodeNonLooping /*ComputationNode*/
// Note: This loses the sample layout, but that is recovered by Validate().
}

virtual void DumpNodeInfo(const bool printValues, File& fstream) const override
virtual void DumpNodeInfo(const bool printValues, const bool printMetadata, File& fstream) const override
{
Base::DumpNodeInfo(printValues, fstream);
Base::DumpNodeInfo(printValues, printMetadata, fstream);

char str[4096];
sprintf(str, "[%s] ", string(GetSampleLayout()).c_str());
fstream << string(str);
sprintf(str, "HasComputed=%ls", HasComputed() ? L"true" : L"false");
fstream << string(str);
if (printMetadata)
{
char str[4096];
sprintf(str, "[%s] ", string(GetSampleLayout()).c_str());
fstream << string(str);
sprintf(str, "HasComputed=%ls", HasComputed() ? L"true" : L"false");
fstream << string(str);
}

PrintNodeValuesToFile(printValues, fstream);
PrintNodeValuesToFile(printValues, printMetadata, fstream);
}

virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
Expand Down

0 comments on commit f89e544

Please sign in to comment.