
Commit

Make SimpleOutputWriter consistent with regard to text output or DataWriter output.
Dong Yu committed Feb 17, 2016
1 parent 20ca9bd commit d3baea7
Showing 7 changed files with 76 additions and 48 deletions.
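At a glance: IDataWriter gains a SupportMultiUtterances() query, each existing writer answers false, and SimpleOutputWriter::WriteOutput now only forces the reader down to a single parallel sequence when the target writer cannot consume multiple utterances. A condensed sketch of that control flow, paraphrased from the SimpleOutputWriter.h hunk below (not a verbatim excerpt):

    // Names follow the diff below; condensed for orientation.
    dataReader.StartMinibatchLoop(mbSize, 0, numOutputSamples);
    if (!dataWriter.SupportMultiUtterances())   // new IDataWriter query added in this commit
        dataReader.SetNumParallelSequences(1);  // previously applied unconditionally
    m_net->StartEvaluateMinibatchLoop(outputNodes);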
5 changes: 5 additions & 0 deletions Source/Common/Include/DataWriter.h
@@ -62,6 +62,7 @@ protected: public: // BUGBUG: This is accessed by a wrapper class.
virtual void GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections) = 0;
virtual bool SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized) = 0;
virtual void SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping) = 0;
virtual bool SupportMultiUtterances() const = 0;
};

// GetWriter - get a writer type from the DLL
@@ -168,6 +169,10 @@ class DataWriter : public IDataWriter<ElemType>, protected Plugin
// saveId - name of the section to save into (section:subsection format)
// labelMapping - map we are saving to the file
virtual void SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping);
virtual bool SupportMultiUtterances() const
{
return false;
};
};

} } }
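
All writers touched by this commit report false; a writer that can consume several parallel utterances per minibatch would override the new virtual to return true. A minimal, hypothetical sketch (MyParallelWriter is an illustration, not part of this commit):

template <class ElemType>
class MyParallelWriter : public IDataWriter<ElemType>
{
public:
    // ... remaining IDataWriter<ElemType> overrides elided ...

    // Declares that SaveData can accept minibatches carrying multiple parallel utterances,
    // so SimpleOutputWriter will not force the reader to a single parallel sequence.
    virtual bool SupportMultiUtterances() const
    {
        return true;
    }
};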
4 changes: 4 additions & 0 deletions Source/EvalDll/EvalWriter.h
@@ -130,5 +130,9 @@ class EvalWriter : public IDataWriter<ElemType>
return (m_currentRecord >= m_recordCount);
}
virtual void SaveMapping(std::wstring saveId, const std::map<typename EvalWriter<ElemType>::LabelIdType, typename EvalWriter<ElemType>::LabelType>& /*labelMapping*/){};
virtual bool SupportMultiUtterances() const
{
return false;
};
};
} } }
4 changes: 4 additions & 0 deletions Source/Readers/BinaryReader/BinaryReader.h
@@ -679,6 +679,10 @@ class BinaryWriter : public IDataWriter<ElemType>
// saveId - name of the section to save into (section:subsection format)
// labelMapping - map we are saving to the file
virtual void SaveMapping(std::wstring saveId, const std::map<typename BinaryWriter<ElemType>::LabelIdType, typename BinaryWriter<ElemType>::LabelType>& labelMapping);
virtual bool SupportMultiUtterances() const
{
return false;
};
};

// utility function to round an integer up to a multiple of size
1 change: 1 addition & 0 deletions Source/Readers/HTKMLFReader/HTKMLFWriter.h
@@ -52,5 +52,6 @@ class HTKMLFWriter : public IDataWriter<ElemType>
virtual void GetSections(std::map<std::wstring, SectionType, nocase_compare>& sections);
virtual bool SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized);
virtual void SaveMapping(std::wstring saveId, const std::map<LabelIdType, LabelType>& labelMapping);
virtual bool SupportMultiUtterances() const { return false; };
};
} } }
4 changes: 4 additions & 0 deletions Source/Readers/LMSequenceReader/SequenceWriter.h
@@ -71,6 +71,10 @@ class LMSequenceWriter : public IDataWriter<ElemType>
}
virtual void Destroy();
virtual bool SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized);
virtual bool SupportMultiUtterances() const
{
return false;
};
};

template <class ElemType>
4 changes: 4 additions & 0 deletions Source/Readers/LUSequenceReader/LUSequenceWriter.h
@@ -59,5 +59,9 @@ class LUSequenceWriter : public IDataWriter<ElemType>
}
virtual void Destroy();
virtual bool SaveData(size_t recordStart, const std::map<std::wstring, void*, nocase_compare>& matrices, size_t numRecords, size_t datasetSize, size_t byteVariableSized);
virtual bool SupportMultiUtterances() const
{
return false;
};
};
} } }
102 changes: 54 additions & 48 deletions Source/SGDLib/SimpleOutputWriter.h
@@ -24,16 +24,11 @@ class SimpleOutputWriter
{
typedef shared_ptr<ComputationNode<ElemType>> ComputationNodePtr;

public:
SimpleOutputWriter(ComputationNetworkPtr net, int verbosity = 0)
: m_net(net), m_verbosity(verbosity)
{
}

void WriteOutput(IDataReader<ElemType>& dataReader, size_t mbSize, IDataWriter<ElemType>& dataWriter, const std::vector<std::wstring>& outputNodeNames, size_t numOutputSamples = requestDataSize, bool doUnitTest = false)
private:
std::vector<ComputationNodeBasePtr> DetermineOutputNodes(const std::vector<std::wstring>& outputNodeNames)
{
// specify output nodes and files
std::vector<ComputationNodeBasePtr> outputNodes;

if (outputNodeNames.size() == 0)
{
if (m_verbosity > 0)
@@ -49,22 +44,56 @@ class SimpleOutputWriter
outputNodes.push_back(m_net->GetNodeFromName(outputNodeNames[i]));
}

// allocate memory for forward computation
m_net->AllocateAllMatrices({}, outputNodes, nullptr);
return outputNodes;
}

// specify feature value nodes
std::map<std::wstring, Matrix<ElemType>*> inputMatrices;
std::vector<ComputationNodeBasePtr> DetermineInputNodes(const std::vector<ComputationNodeBasePtr>& outputNodes)
{
// use map to remove duplicated items
std::map<ComputationNodeBasePtr, int> inputNodesMap;
for (auto& onode : outputNodes)
{
for (auto& inode : m_net->InputNodes(onode))
inputMatrices[inode->NodeName()] = &dynamic_pointer_cast<ComputationNode<ElemType>>(inode)->Value();
inputNodesMap[inode] = 1;
}

std::vector<ComputationNodeBasePtr> inputNodes;
for (auto& inode : inputNodesMap)
inputNodes.push_back(inode.first);

return inputNodes;
}

std::map<std::wstring, Matrix<ElemType>*> RetrieveInputMatrices(const std::vector<ComputationNodeBasePtr>& inputNodes)
{
std::map<std::wstring, Matrix<ElemType>*> inputMatrices;

for (auto& inode : inputNodes)
inputMatrices[inode->NodeName()] = &dynamic_pointer_cast<ComputationNode<ElemType>>(inode)->Value();

// Matrix<ElemType> endOfFile = Matrix<ElemType>((size_t)1,(size_t)1);
// endOfFile(0,0)=0;
return inputMatrices;
}

public:
SimpleOutputWriter(ComputationNetworkPtr net, int verbosity = 0)
: m_net(net), m_verbosity(verbosity)
{
}

void WriteOutput(IDataReader<ElemType>& dataReader, size_t mbSize, IDataWriter<ElemType>& dataWriter, const std::vector<std::wstring>& outputNodeNames, size_t numOutputSamples = requestDataSize, bool doUnitTest = false)
{
std::vector<ComputationNodeBasePtr> outputNodes = DetermineOutputNodes(outputNodeNames);
std::vector<ComputationNodeBasePtr> inputNodes = DetermineInputNodes(outputNodes);

// allocate memory for forward computation
m_net->AllocateAllMatrices({}, outputNodes, nullptr);

std::map<std::wstring, Matrix<ElemType>*> inputMatrices = RetrieveInputMatrices(inputNodes);

// evaluate with minibatches
dataReader.StartMinibatchLoop(mbSize, 0, numOutputSamples);
dataReader.SetNumParallelSequences(1);

if (!dataWriter.SupportMultiUtterances())
dataReader.SetNumParallelSequences(1);
m_net->StartEvaluateMinibatchLoop(outputNodes);

size_t totalEpochSamples = 0;
@@ -73,10 +102,7 @@
size_t actualMBSize;
while (DataReaderHelpers::GetMinibatchIntoNetwork(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
{
// Update timestamp for all input nodes ancestors of the output nodes
for (auto& onode : outputNodes)
for (auto& inode : m_net->InputNodes(onode))
inode->BumpEvalTimeStamp();
ComputationNetwork::BumpEvalTimeStamp(inputNodes);

for (int i = 0; i < outputNodes.size(); i++)
{
@@ -111,24 +137,15 @@
// E.g. create a shared function that takes the actual writing operation as a lambda.
void WriteOutput(IDataReader<ElemType>& dataReader, size_t mbSize, std::wstring outputPath, const std::vector<std::wstring>& outputNodeNames, size_t numOutputSamples = requestDataSize)
{
msra::files::make_intermediate_dirs(outputPath);
std::vector<ComputationNodeBasePtr> outputNodes = DetermineOutputNodes(outputNodeNames);
std::vector<ComputationNodeBasePtr> inputNodes = DetermineInputNodes(outputNodes);

// specify output nodes and files
std::vector<ComputationNodeBasePtr> outputNodes;
if (outputNodeNames.size() == 0)
{
fprintf(stderr, "OutputNodeNames are not specified, using the default outputnodes.\n");
if (m_net->OutputNodes().size() == 0)
LogicError("There is no default output node specified in the network.");
// allocate memory for forward computation
m_net->AllocateAllMatrices({}, outputNodes, nullptr);

outputNodes = m_net->OutputNodes();
}
else
{
for (int i = 0; i < outputNodeNames.size(); i++)
outputNodes.push_back(m_net->GetNodeFromName(outputNodeNames[i]));
}
std::map<std::wstring, Matrix<ElemType>*> inputMatrices = RetrieveInputMatrices(inputNodes);

msra::files::make_intermediate_dirs(outputPath);
std::vector<ofstream*> outputStreams;
for (int i = 0; i < outputNodes.size(); i++)
#ifdef _MSC_VER
@@ -137,16 +154,6 @@
outputStreams.push_back(new ofstream(wtocharpath(outputPath + L"." + outputNodes[i]->NodeName()).c_str()));
#endif

// allocate memory for forward computation
m_net->AllocateAllMatrices({}, outputNodes, nullptr);

// specify feature value nodes
auto& featureNodes = m_net->FeatureNodes();
std::map<std::wstring, Matrix<ElemType>*> inputMatrices;
// BUGBUG: This loop is inconsistent with the above version of this function in that it does not handle label nodes.
for (size_t i = 0; i < featureNodes.size(); i++)
inputMatrices[featureNodes[i]->NodeName()] = &dynamic_pointer_cast<ComputationNode<ElemType>>(featureNodes[i])->Value();

// evaluate with minibatches
dataReader.StartMinibatchLoop(mbSize, 0, numOutputSamples);

@@ -160,8 +167,7 @@
size_t actualMBSize;
while (DataReaderHelpers::GetMinibatchIntoNetwork(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
{
// BUGBUG: This loop is inconsistent with the above version of this function in that it does not handle label nodes.
ComputationNetwork::BumpEvalTimeStamp(featureNodes);
ComputationNetwork::BumpEvalTimeStamp(inputNodes);

for (int i = 0; i < outputNodes.size(); i++)
{
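The comment above the second WriteOutput overload suggests factoring the two overloads around a shared function that takes the actual writing operation as a lambda. A hedged sketch of that suggestion, reusing the helpers introduced in this commit (WriteOutputCore and writeNode are hypothetical, not part of the change):

template <class WriteFn>
void WriteOutputCore(IDataReader<ElemType>& dataReader, size_t mbSize,
                     const std::vector<std::wstring>& outputNodeNames,
                     size_t numOutputSamples, WriteFn writeNode)
{
    std::vector<ComputationNodeBasePtr> outputNodes = DetermineOutputNodes(outputNodeNames);
    std::vector<ComputationNodeBasePtr> inputNodes = DetermineInputNodes(outputNodes);

    // allocate memory for forward computation
    m_net->AllocateAllMatrices({}, outputNodes, nullptr);
    std::map<std::wstring, Matrix<ElemType>*> inputMatrices = RetrieveInputMatrices(inputNodes);

    // evaluate with minibatches
    dataReader.StartMinibatchLoop(mbSize, 0, numOutputSamples);
    m_net->StartEvaluateMinibatchLoop(outputNodes);

    size_t actualMBSize;
    while (DataReaderHelpers::GetMinibatchIntoNetwork(dataReader, m_net, nullptr, false, false, inputMatrices, actualMBSize))
    {
        ComputationNetwork::BumpEvalTimeStamp(inputNodes);
        for (size_t i = 0; i < outputNodes.size(); i++)
            writeNode(outputNodes[i], actualMBSize); // e.g. DataWriter::SaveData or an ofstream write
    }
}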
