From abd520633acda03f77bbfe4cc765d59bab93ae9c Mon Sep 17 00:00:00 2001
From: Frank Seide
Date: Fri, 22 Jan 2016 11:25:52 -0800
Subject: [PATCH] deleted the following nodes, as they are either redundant or
 currently not implemented correctly (but may come back in the future):
 StrideTimesNode, ParallelNode, LSTMNode, BatchModeNode, TimeReverseNode;
 also sorted some of the lookup-table-like functions (matching node names)
 alphabetically
---
 Source/CNTK/NetworkDescriptionLanguage.cpp |  198 +-
 .../ComputationNetwork.cpp                 |   11 +-
 .../ComputationNetwork.h                   |    4 +-
 .../ComputationNetworkBuilder.cpp          |  256 +-
 .../ComputationNetworkBuilder.h            |   80 +-
 .../SpecialPurposeNodes.h                  | 3340 ++---------
 6 files changed, 385 insertions(+), 3504 deletions(-)
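The two largest hunks below touch hand-maintained lookup tables: CheckFunction in NetworkDescriptionLanguage.cpp and CreateStandardNode in ComputationNetworkBuilder.cpp are long chains of case-insensitive name comparisons, so keeping the entries alphabetical is what makes duplicate or missing entries visible in review. As a self-contained sketch of what such a table computes (the helper names and set contents here are illustrative stand-ins, not CNTK's API), the same membership test could be written against a case-insensitive set:

    // Hypothetical illustration of the CheckFunction pattern: a membership
    // test of a node-type name against a case-insensitive set of known names.
    #include <algorithm>
    #include <cwctype>
    #include <set>
    #include <string>

    struct CaseInsensitiveLess
    {
        bool operator()(const std::wstring& a, const std::wstring& b) const
        {
            return std::lexicographical_compare(
                a.begin(), a.end(), b.begin(), b.end(),
                [](wchar_t x, wchar_t y) { return std::towlower(x) < std::towlower(y); });
        }
    };

    static bool IsKnownNodeType(const std::wstring& nodeType)
    {
        // a few representative entries; the authoritative list is the one in the patch
        static const std::set<std::wstring, CaseInsensitiveLess> known = {
            L"AveragePooling", L"BatchNormalization", L"Plus", L"Times", L"Delay"};
        return known.count(nodeType) != 0;
    }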
diff --git a/Source/CNTK/NetworkDescriptionLanguage.cpp b/Source/CNTK/NetworkDescriptionLanguage.cpp
index 7d493e6fcac4..0afdc287f90d 100644
--- a/Source/CNTK/NetworkDescriptionLanguage.cpp
+++ b/Source/CNTK/NetworkDescriptionLanguage.cpp
@@ -150,146 +150,72 @@ bool CheckFunction(std::string& p_nodeType, bool* allowUndeterminedVariable)
     bool ret = false;
     if (allowUndeterminedVariable)
         *allowUndeterminedVariable = true; // by default we allow undetermined variables
-    if (EqualInsensitive(nodeType, OperationNameOf(InputValue), L"Input"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(SparseInputValue), L"SparseInput"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(LearnableParameter), L"Parameter"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, L"ImageParameter"))
-        ret = true;
-    //else if (EqualInsensitive(nodeType, OperationNameOf(SparseLearnableParameter), L"SparseParameter"))
-    //    ret = true;
-    else if (EqualInsensitive(nodeType, L"Constant", L"Const"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, L"ImageInput", L"Image"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, L"SparseImageInput", L"SparseImage"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(SumElementsNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(SumColumnElementsNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, L"Scale"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(TransposeNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(TimesNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(TransposeTimesNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(StrideTimesNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(ElementTimesNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, L"RowElementTimes"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, L"ColumnElementTimes"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(DiagTimesNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(CosDistanceNode), L"CosDist"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(KhatriRaoProductNode), L"ColumnwiseCrossProduct"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(PlusNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(MinusNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(NegateNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(RectifiedLinearNode), L"ReLU"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(SigmoidNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(TanhNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(ExpNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(LogNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(CosineNode), L"Cos"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(HardmaxNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(SoftmaxNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(LogSoftmaxNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(SquareErrorNode), L"SE"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(LogisticNode), L"Logistic"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(CrossEntropyWithSoftmaxNode), L"CEWithSM"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(SequenceWithSoftmaxNode), L"SEWithSM"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(CrossEntropyNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(ClassBasedCrossEntropyWithSoftmaxNode), L"CBCEWithSM"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(MatrixL1RegNode), L"L1Reg"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(MatrixL2RegNode), L"L2Reg"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(PerDimMeanVarNormalizationNode), L"PerDimMVNorm"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(PerDimMeanVarDeNormalizationNode), L"PerDimMVDeNorm"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(ErrorPredictionNode), L"ClassificationError"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(DropoutNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(ReshapeNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(RowRepeatNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(DiagonalNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(MeanNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(InvStdDevNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(ConvolutionNode), L"Convolve"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(MaxPoolingNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(AveragePoolingNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(PastValueNode), L"Delay"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(FutureValueNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(RowSliceNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(RowStackNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(LookupTableNode)))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(GMMLogLikelihoodNode), L"GMMLL"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(CosDistanceWithNegativeSamplesNode), L"CosWithNegSamples"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(TimeReverseNode), L"TimeReverse"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(CRFNode), L"CRF"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(DummyCriterionNode), L"DummyCriterion"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(ParallelNode), L"Parallel"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(LSTMNode), L"LSTM"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(StrideTimesNode), L"StrideTimes"))
-        ret = true;
-    else if (EqualInsensitive(nodeType, OperationNameOf(BatchNormalizationNode)))
-        ret = true;
+    if (EqualInsensitive(nodeType, OperationNameOf(AveragePoolingNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(BatchNormalizationNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(CRFNode), L"CRF")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(ClassBasedCrossEntropyWithSoftmaxNode), L"CBCEWithSM")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(ConvolutionNode), L"Convolve")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(CosDistanceNode), L"CosDist")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(CosDistanceWithNegativeSamplesNode), L"CosWithNegSamples")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(CosineNode), L"Cos")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(CrossEntropyNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(CrossEntropyWithSoftmaxNode), L"CEWithSM")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(DiagTimesNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(DiagonalNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(DropoutNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(DummyCriterionNode), L"DummyCriterion")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(ElementTimesNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(ErrorPredictionNode), L"ClassificationError")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(ExpNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(FutureValueNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(GMMLogLikelihoodNode), L"GMMLL")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(HardmaxNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(InputValue), L"Input")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(InvStdDevNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(KhatriRaoProductNode), L"ColumnwiseCrossProduct")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(LearnableParameter), L"Parameter")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(LogNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(LogSoftmaxNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(LogisticNode), L"Logistic")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(LookupTableNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(MatrixL1RegNode), L"L1Reg")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(MatrixL2RegNode), L"L2Reg")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(MaxPoolingNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(MeanNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(MinusNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(NegateNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(PastValueNode), L"Delay")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(PerDimMeanVarDeNormalizationNode), L"PerDimMVDeNorm")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(PerDimMeanVarNormalizationNode), L"PerDimMVNorm")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(PlusNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(RectifiedLinearNode), L"ReLU")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(ReshapeNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(RowRepeatNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(RowSliceNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(RowStackNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(SequenceDecoderNode), L"SEWithSM")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(SequenceWithSoftmaxNode), L"SEWithSM")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(SigmoidNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(SoftmaxNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(SparseInputValue), L"SparseInput")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(SquareErrorNode), L"SE")) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(SumColumnElementsNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(SumElementsNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(TanhNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(TimesNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(TransposeNode))) ret = true;
+    else if (EqualInsensitive(nodeType, OperationNameOf(TransposeTimesNode))) ret = true;
+    else if (EqualInsensitive(nodeType, L"ColumnElementTimes")) ret = true;
+    else if (EqualInsensitive(nodeType, L"Constant", L"Const")) ret = true;
+    else if (EqualInsensitive(nodeType, L"ImageInput", L"Image")) ret = true;
+    else if (EqualInsensitive(nodeType, L"ImageParameter")) ret = true;
+    else if (EqualInsensitive(nodeType, L"RowElementTimes")) ret = true;
+    else if (EqualInsensitive(nodeType, L"Scale")) ret = true;
+    else if (EqualInsensitive(nodeType, L"SparseImageInput", L"SparseImage")) ret = true;
 
     // return the actual node name in the parameter if we found something
     if (ret)
-    {
         p_nodeType = msra::strfun::utf8(nodeType);
-    }
 
     return ret;
 }
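CheckFunction accepts either the canonical operation name (via OperationNameOf) or, for some entries, a legacy alias such as L"Delay" for PastValueNode. One side effect of the re-sorted list is worth noting: SequenceDecoderNode and SequenceWithSoftmaxNode now both carry the alias L"SEWithSM", so if EqualInsensitive canonicalizes an alias match, the earlier entry wins. A minimal sketch of the two-name match, assuming a plain case-insensitive compare (EqualCI and MatchesNodeType are illustrative stand-ins, not CNTK's signatures):

    #include <algorithm>
    #include <cwctype>
    #include <string>

    // case-insensitive string equality, analogous in spirit to EqualInsensitive
    static bool EqualCI(const std::wstring& a, const std::wstring& b)
    {
        return a.size() == b.size() &&
               std::equal(a.begin(), a.end(), b.begin(),
                          [](wchar_t x, wchar_t y) { return std::towlower(x) == std::towlower(y); });
    }

    // a type string matches either the canonical name or an optional legacy alias
    static bool MatchesNodeType(const std::wstring& typed, const std::wstring& canonical,
                                const std::wstring& alias = L"")
    {
        return EqualCI(typed, canonical) || (!alias.empty() && EqualCI(typed, alias));
    }

    // e.g. MatchesNodeType(L"delay", L"PastValue", L"Delay") returns true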
diff --git a/Source/ComputationNetworkLib/ComputationNetwork.cpp b/Source/ComputationNetworkLib/ComputationNetwork.cpp
index 6b38e34d48cc..8b22f57bc41f 100644
--- a/Source/ComputationNetworkLib/ComputationNetwork.cpp
+++ b/Source/ComputationNetworkLib/ComputationNetwork.cpp
@@ -457,7 +457,7 @@ void ComputationNetwork::GetNodesRequiringX(list<ComputationNodeBasePtr>& nodesRequiringX,
     nodesRequiringX.unique();
 }
 
-//return list of nodes that require precomputation and not precomputed yet.
+// return list of nodes that require precomputation and not precomputed yet
 list<ComputationNodeBasePtr> ComputationNetwork::GetNodesRequiringPreComputation(const ComputationNodeBasePtr& rootNode, bool checkComputed)
 {
     list<ComputationNodeBasePtr> nodesRequiringX;
@@ -466,15 +466,6 @@ list<ComputationNodeBasePtr> ComputationNetwork::GetNodesRequiringPreComputation
     return nodesRequiringX;
 }
 
-//return list of nodes that require batch mode and not precomputed yet.
-list<ComputationNodeBasePtr> ComputationNetwork::GetNodesRequiringBatchMode(const ComputationNodeBasePtr& rootNode, bool checkComputed)
-{
-    list<ComputationNodeBasePtr> nodesRequiringX;
-    GetNodesRequiringX<BatchModeNode<float>>(nodesRequiringX, rootNode, checkComputed);
-    GetNodesRequiringX<BatchModeNode<double>>(nodesRequiringX, rootNode, checkComputed);
-    return nodesRequiringX;
-}
-
 // create the m_inputValues[] and m_learnableParameters[] lists
 void ComputationNetwork::CollectInputAndLearnableParameters(const ComputationNodeBasePtr& rootNode)
 {
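After this change, GetNodesRequiringPreComputation is the only remaining instantiation of the GetNodesRequiringX helper; the batch-mode variant goes away together with BatchModeNode. The helper's job is to filter the nodes reachable from a root down to those of a given type whose one-time pass has not yet run. A self-contained sketch with stand-in types (the real code operates on ComputationNodeBasePtr and CNTK's precompute node base; the hasComputed flag below is hypothetical):

    #include <list>
    #include <memory>

    struct NodeBase { virtual ~NodeBase() = default; };
    struct PreComputeNode : NodeBase { bool hasComputed = false; };

    using NodePtr = std::shared_ptr<NodeBase>;

    // keep only nodes of the precompute type that still need their pass
    static std::list<NodePtr> NodesRequiringPreComputation(const std::list<NodePtr>& allNodes,
                                                           bool checkComputed)
    {
        std::list<NodePtr> result;
        for (const auto& node : allNodes)
        {
            auto pre = std::dynamic_pointer_cast<PreComputeNode>(node);
            if (pre && (!checkComputed || !pre->hasComputed))
                result.push_back(node);
        }
        return result;
    }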
diff --git a/Source/ComputationNetworkLib/ComputationNetwork.h b/Source/ComputationNetworkLib/ComputationNetwork.h
index 78f254e58ab5..82f25570dd4e 100644
--- a/Source/ComputationNetworkLib/ComputationNetwork.h
+++ b/Source/ComputationNetworkLib/ComputationNetwork.h
@@ -563,10 +563,8 @@ class ComputationNetwork : public ScriptableObjects::Object, public ScriptableOb
     void GetNodesRequiringX(std::list<ComputationNodeBasePtr>& nodesRequirePreComputation, const ComputationNodeBasePtr& rootNode, bool checkComputed);
 
 public:
-    //return list of nodes that require precomputation and not precomputed yet.
+    // return list of nodes that require precomputation and not precomputed yet
     std::list<ComputationNodeBasePtr> GetNodesRequiringPreComputation(const ComputationNodeBasePtr& rootNode = nullptr, bool checkComputed = true);
-    //return list of nodes that require precomputation and not precomputed yet.
-    std::list<ComputationNodeBasePtr> GetNodesRequiringBatchMode(const ComputationNodeBasePtr& rootNode = nullptr, bool checkComputed = true);
 
     // -----------------------------------------------------------------------
     // unit testing
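The next file compacts CreateStandardNode, the name-to-constructor dispatch behind the "please keep this table sorted" comment: a runtime string selects a templated New<SomeNode<ElemType>>(...) call. A minimal standalone equivalent of that dispatch, written with a sorted map (every type and name below is a stand-in, not the CNTK node hierarchy):

    #include <functional>
    #include <map>
    #include <memory>
    #include <stdexcept>
    #include <string>

    // Minimal stand-ins; the real nodes are CNTK's ComputationNode hierarchy.
    struct Node { virtual ~Node() = default; };
    struct PlusNode : Node {};
    struct TimesNode : Node {};

    std::shared_ptr<Node> CreateByName(const std::wstring& name)
    {
        using Factory = std::function<std::shared_ptr<Node>()>;
        // a map keeps the entries sorted by construction; the hand-written
        // else-if ladder achieves the same only by convention
        static const std::map<std::wstring, Factory> table = {
            {L"Plus",  [] { return std::make_shared<PlusNode>(); }},
            {L"Times", [] { return std::make_shared<TimesNode>(); }},
        };
        auto it = table.find(name);
        if (it == table.end())
            throw std::invalid_argument("attempted to instantiate undefined operation");
        return it->second();
    }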
diff --git a/Source/ComputationNetworkLib/ComputationNetworkBuilder.cpp b/Source/ComputationNetworkLib/ComputationNetworkBuilder.cpp
index aa95e5bf85a3..1f465e91d70f 100644
--- a/Source/ComputationNetworkLib/ComputationNetworkBuilder.cpp
+++ b/Source/ComputationNetworkLib/ComputationNetworkBuilder.cpp
@@ -33,135 +33,67 @@ template <class ElemType, class... _Types>
 static shared_ptr<ComputationNode<ElemType>> CreateStandardNode(const std::wstring& nodeType, _Types&&...
_Args) { // please keep this table sorted - if (nodeType == OperationNameOf(CRFNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(ClassBasedCrossEntropyWithSoftmaxNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(CosDistanceNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(CosDistanceWithNegativeSamplesNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(CosineNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(CrossEntropyNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(CrossEntropyWithSoftmaxNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(SequenceWithSoftmaxNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(DiagonalNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(DiagTimesNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(DropoutNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(DummyCriterionNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(ElementTimesNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(ErrorPredictionNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(ExpNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(FutureValueNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(GMMLogLikelihoodNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(HardmaxNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(InvStdDevNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(KhatriRaoProductNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(LSTMNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(LogNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(LogSoftmaxNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(LookupTableNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(MatrixL1RegNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(MatrixL2RegNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(MeanNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(MinusNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(NegateNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(NoiseContrastiveEstimationNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(ParallelNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(PastValueNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(PerDimMeanVarNormalizationNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(PerDimMeanVarDeNormalizationNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(PlusNode)) - return New>(forward<_Types>(_Args)...); - else 
if (nodeType == OperationNameOf(ReconcileMBLayoutNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(RectifiedLinearNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(ReshapeNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(RowRepeatNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(RowSliceNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(RowStackNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(SequenceDecoderNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(ShiftNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(SigmoidNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(SoftmaxNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(SquareErrorNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(LogisticNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(StrideTimesNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(SumColumnElementsNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(SumElementsNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(TanhNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(TimeReverseNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(TimesNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(TransposeNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(TransposeTimesNode)) - return New>(forward<_Types>(_Args)...); + if (nodeType == OperationNameOf(CRFNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(ClassBasedCrossEntropyWithSoftmaxNode))return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(CosDistanceNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(CosDistanceWithNegativeSamplesNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(CosineNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(CrossEntropyNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(CrossEntropyWithSoftmaxNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(DiagonalNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(DiagTimesNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(DropoutNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(DummyCriterionNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(ElementTimesNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(ErrorPredictionNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(ExpNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(FutureValueNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(GMMLogLikelihoodNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType 
== OperationNameOf(HardmaxNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(InvStdDevNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(KhatriRaoProductNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(LogNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(LogSoftmaxNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(LookupTableNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(MatrixL1RegNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(MatrixL2RegNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(MeanNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(MinusNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(NegateNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(NoiseContrastiveEstimationNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(PastValueNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(PerDimMeanVarNormalizationNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(PerDimMeanVarDeNormalizationNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(PlusNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(ReconcileMBLayoutNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(RectifiedLinearNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(ReshapeNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(RowRepeatNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(RowSliceNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(RowStackNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(SequenceDecoderNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(ShiftNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(SigmoidNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(SoftmaxNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(SquareErrorNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(LogisticNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(SumColumnElementsNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(SumElementsNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(TanhNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(TimesNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(TransposeNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(TransposeTimesNode)) return New>(forward<_Types>(_Args)...); // old names we also support - else if (nodeType == L"ColumnElementTimes") - return New>(forward<_Types>(_Args)...); - else if (nodeType == L"Delay") - return New>(forward<_Types>(_Args)...); - else if (nodeType == L"PerDimMeanVarNormalizationNode") - return 
New>(forward<_Types>(_Args)...); - else if (nodeType == L"PerDimMeanVarDeNormalizationNode") - return New>(forward<_Types>(_Args)...); - else if (nodeType == L"RowElementTimes") - return New>(forward<_Types>(_Args)...); - else if (nodeType == L"Scale") - return New>(forward<_Types>(_Args)...); + else if (nodeType == L"ColumnElementTimes") return New>(forward<_Types>(_Args)...); + else if (nodeType == L"Delay") return New>(forward<_Types>(_Args)...); + else if (nodeType == L"PerDimMeanVarNormalizationNode") return New>(forward<_Types>(_Args)...); + else if (nodeType == L"PerDimMeanVarDeNormalizationNode") return New>(forward<_Types>(_Args)...); + else if (nodeType == L"RowElementTimes") return New>(forward<_Types>(_Args)...); + else if (nodeType == L"Scale") return New>(forward<_Types>(_Args)...); #if 1 - else if (nodeType == OperationNameOf(DeprecatedReshapeNode)) - return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(DeprecatedReshapeNode)) return New>(forward<_Types>(_Args)...); #endif - else - InvalidArgument("Attempted to instantiate undefined operation %ls.", nodeType.c_str()); + else InvalidArgument("Attempted to instantiate undefined operation %ls.", nodeType.c_str()); } // create a new node of a type given as a string, with var args so that this can be used at multiple places @@ -170,23 +102,14 @@ template static shared_ptr> CreateNode(const std::wstring& nodeType, _Types&&... _Args) { // check more types - if (nodeType == OperationNameOf(AveragePoolingNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(BatchNormalizationNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(ConvolutionNode)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(SparseInputValue)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(InputValue)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(LearnableParameter)) - return New>(forward<_Types>(_Args)...); - else if (nodeType == OperationNameOf(MaxPoolingNode)) - return New>(forward<_Types>(_Args)...); - //else if (nodeType == OperationNameOf(SparseLearnableParameter)) return New>(forward<_Types>(_Args)...); - else - return CreateStandardNode(nodeType, forward<_Types>(_Args)...); + if (nodeType == OperationNameOf(AveragePoolingNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(BatchNormalizationNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(ConvolutionNode)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(SparseInputValue)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(InputValue)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(LearnableParameter)) return New>(forward<_Types>(_Args)...); + else if (nodeType == OperationNameOf(MaxPoolingNode)) return New>(forward<_Types>(_Args)...); + else return CreateStandardNode(nodeType, forward<_Types>(_Args)...); } // this function is called from SimpleNetworkBuilder and old NDL @@ -244,14 +167,6 @@ shared_ptr> ComputationNetworkBuilder::Creat return net.AddNodeToNetWithElemType(New>(net.GetDeviceId(), paramName, tensorShape)); } -#if 0 // not functional at present - //sparse matrix size is optionally specified - template shared_ptr> ComputationNetworkBuilder::CreateSparseLearnableParameter(const std::wstring & paramName, const size_t rows, const size_t cols, const 
size_t size) - { - return net.AddNodeToNetWithElemType(New>(net.GetDeviceId(), paramName, rows, cols, size)); - } -#endif - template shared_ptr> ComputationNetworkBuilder::CreateInputNode(const std::wstring& inputName, const size_t rows) { @@ -328,7 +243,7 @@ shared_ptr> ComputationNetworkBuilder::Convo return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName, kernelWidth, kernelHeight, outputChannels, horizontalSubsample, verticalSubsample, imageLayoutKind, zeroPadding, maxTempMemSizeInSamples), - weight, inputValues); + weight, inputValues); } template @@ -336,9 +251,7 @@ shared_ptr> ComputationNetworkBuilder::MaxPo const size_t windowWidth, const size_t windowHeight, const size_t horizontalSubsample, const size_t verticalSubsample, ImageLayoutKind imageLayoutKind, const std::wstring nodeName) { - return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName, - windowWidth, windowHeight, horizontalSubsample, verticalSubsample, imageLayoutKind), - inputValues); + return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName, windowWidth, windowHeight, horizontalSubsample, verticalSubsample, imageLayoutKind), inputValues); } template @@ -346,9 +259,7 @@ shared_ptr> ComputationNetworkBuilder::Avera const size_t windowWidth, const size_t windowHeight, const size_t horizontalSubsample, const size_t verticalSubsample, ImageLayoutKind imageLayoutKind, const std::wstring nodeName) { - return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName, - windowWidth, windowHeight, horizontalSubsample, verticalSubsample, imageLayoutKind), - inputValues); + return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName, windowWidth, windowHeight, horizontalSubsample, verticalSubsample, imageLayoutKind), inputValues); } template @@ -440,17 +351,6 @@ shared_ptr> ComputationNetworkBuilder::Dummy return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), objectives, derivatives, prediction); } -template -shared_ptr> ComputationNetworkBuilder::LSTM(const ComputationNodePtr obs, - const ComputationNodePtr inputGate, - const ComputationNodePtr forgetGate, - const ComputationNodePtr outputGate, - const ComputationNodePtr memoryCellWgt, - const std::wstring nodeName) -{ - return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), obs, inputGate, forgetGate, outputGate, memoryCellWgt); -} - template shared_ptr> ComputationNetworkBuilder::CrossEntropy(const ComputationNodePtr label, const ComputationNodePtr prediction, const std::wstring nodeName) { @@ -571,12 +471,6 @@ shared_ptr> ComputationNetworkBuilder::Eleme return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a, b); } -template -shared_ptr> ComputationNetworkBuilder::StrideTimes(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, const std::wstring nodeName) -{ - return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a, b, c); -} - template shared_ptr> ComputationNetworkBuilder::DiagTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName) { @@ -655,12 +549,6 @@ shared_ptr> ComputationNetworkBuilder::Futur return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName, initHiddenActivity, row_size, timeStep), a); } -template -shared_ptr> ComputationNetworkBuilder::Parallel(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName) -{ - return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), a, b); -} - template 
shared_ptr> ComputationNetworkBuilder::RowSlice(const ComputationNodePtr a, const size_t start_index, const size_t num_rows, const std::wstring nodeName) { @@ -686,12 +574,6 @@ shared_ptr> ComputationNetworkBuilder::GMMLo return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), unnormedPrior, mean, logStddev, feature); } -template -shared_ptr> ComputationNetworkBuilder::TimeReverse(const ComputationNodePtr input, const std::wstring nodeName) -{ - return net.AddNodeToNetAndAttachInputs(New>(net.GetDeviceId(), nodeName), input); -} - template shared_ptr> ComputationNetworkBuilder::LookupTable(const ComputationNodePtr dictionary, const ComputationNodePtr input, const std::wstring nodeName) { diff --git a/Source/ComputationNetworkLib/ComputationNetworkBuilder.h b/Source/ComputationNetworkLib/ComputationNetworkBuilder.h index f0b8ce2f46aa..236d4ff19a69 100644 --- a/Source/ComputationNetworkLib/ComputationNetworkBuilder.h +++ b/Source/ComputationNetworkLib/ComputationNetworkBuilder.h @@ -59,6 +59,8 @@ class ComputationNetworkBuilder ComputationNodePtr CreateComputationNode(const std::wstring& nodeType, const std::wstring& nodeName); // The following functions create nodes and link them to the network and their inputs. // TODO: Do we need both this set and the one above that does not add inputs? Can they share more code? + ComputationNodePtr BatchNormalization(const ComputationNodePtr input, const ComputationNodePtr scale, const ComputationNodePtr bias, + const ComputationNodePtr runMean, const ComputationNodePtr runInvStdDev, bool eval = false, bool spatial = false, double expAvgFactor = 1, ImageLayoutKind imageLayoutKind = ImageLayoutKind::CHW, const std::wstring nodeName = L""); ComputationNodePtr Convolution(const ComputationNodePtr weight, const ComputationNodePtr inputValues, const size_t kernelWidth, const size_t kernelHeight, const size_t outputChannels, @@ -71,63 +73,57 @@ class ComputationNetworkBuilder ComputationNodePtr AveragePooling(const ComputationNodePtr inputValues, const size_t windowWidth, const size_t windowHeight, const size_t horizontalSubsample, const size_t verticalSubsample, ImageLayoutKind imageLayoutKind, const std::wstring nodeName = L""); - ComputationNodePtr ErrorPrediction(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); - ComputationNodePtr PerDimMeanVarNormalization(const ComputationNodePtr feature, const ComputationNodePtr mean, const ComputationNodePtr InvStdDev, const std::wstring nodeName = L""); - ComputationNodePtr PerDimMeanVarDeNormalization(const ComputationNodePtr feature, const ComputationNodePtr mean, const ComputationNodePtr InvStdDev, const std::wstring nodeName = L""); - ComputationNodePtr SquareError(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); - ComputationNodePtr Logistic(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); - ComputationNodePtr Logistic(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, const std::wstring nodeName = L""); - ComputationNodePtr SequenceDecoder(const ComputationNodePtr label, const ComputationNodePtr prediction, const ComputationNodePtr pairscore, const std::wstring nodeName = L""); - ComputationNodePtr CrossEntropyWithSoftmax(const ComputationNodePtr label, const ComputationNodePtr prediction, const std::wstring nodeName = L""); - ComputationNodePtr SequenceWithSoftmax(const ComputationNodePtr label, const ComputationNodePtr prediction, 
const ComputationNodePtr loglikelihood, const std::wstring nodeName = L""); - ComputationNodePtr NoiseContrastiveEstimation(const ComputationNodePtr label, const ComputationNodePtr prediction, const ComputationNodePtr input_weight, const ComputationNodePtr input_bias, const std::wstring nodeName = L"", NCEEvalMode mode = NCEEvalMode::None); - ComputationNodePtr ClassCrossEntropyWithSoftmax(const ComputationNodePtr label, const ComputationNodePtr prediction, const ComputationNodePtr input_weight, const ComputationNodePtr cls_log_post_prob, const std::wstring nodeName = L""); ComputationNodePtr CRF(const ComputationNodePtr label, const ComputationNodePtr postDepScore, const ComputationNodePtr transition_score, const std::wstring nodeName = L""); - ComputationNodePtr DummyCriterion(const ComputationNodePtr objectives, const ComputationNodePtr derivatives, const ComputationNodePtr prediction, const std::wstring nodeName = L""); - ComputationNodePtr LSTM(const ComputationNodePtr obs, const ComputationNodePtr inputGate, const ComputationNodePtr forgetGate, const ComputationNodePtr outputGate, const ComputationNodePtr memoryCellWgt, const std::wstring nodeName = L""); + ComputationNodePtr ClassCrossEntropyWithSoftmax(const ComputationNodePtr label, const ComputationNodePtr prediction, const ComputationNodePtr input_weight, const ComputationNodePtr cls_log_post_prob, const std::wstring nodeName = L""); + ComputationNodePtr Cos(const ComputationNodePtr a, const std::wstring nodeName = L""); + ComputationNodePtr CosDistance(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); ComputationNodePtr CrossEntropy(const ComputationNodePtr label, const ComputationNodePtr prediction, const std::wstring nodeName = L""); + ComputationNodePtr CrossEntropyWithSoftmax(const ComputationNodePtr label, const ComputationNodePtr prediction, const std::wstring nodeName = L""); + ComputationNodePtr DiagTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); + ComputationNodePtr Diagonal(const ComputationNodePtr a, const std::wstring nodeName = L""); + ComputationNodePtr Dropout(const ComputationNodePtr a, const std::wstring nodeName = L""); + ComputationNodePtr DummyCriterion(const ComputationNodePtr objectives, const ComputationNodePtr derivatives, const ComputationNodePtr prediction, const std::wstring nodeName = L""); + ComputationNodePtr ElementTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); + ComputationNodePtr ErrorPrediction(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); + ComputationNodePtr Exp(const ComputationNodePtr a, const std::wstring nodeName = L""); + ComputationNodePtr FutureValue(const ComputationNodePtr a, const float initHiddenActivity, const size_t row_size, size_t timeStep, const std::wstring nodeName = L""); + ComputationNodePtr GMMLogLikelihood(const ComputationNodePtr unnormedPrior, const ComputationNodePtr mean, const ComputationNodePtr logStddev, const ComputationNodePtr feature, const std::wstring nodeName = L""); + ComputationNodePtr Hardmax(const ComputationNodePtr a, const std::wstring nodeName = L""); + ComputationNodePtr InvStdDev(const ComputationNodePtr a, const std::wstring nodeName = L""); + ComputationNodePtr KhatriRaoProduct(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); + ComputationNodePtr Log(const ComputationNodePtr a, const std::wstring nodeName = L""); + 
ComputationNodePtr LogSoftmax(const ComputationNodePtr a, const std::wstring nodeName = L""); + ComputationNodePtr Logistic(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, const std::wstring nodeName = L""); + ComputationNodePtr Logistic(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); + ComputationNodePtr LookupTable(const ComputationNodePtr dictionary, const ComputationNodePtr input, const std::wstring nodeName = L""); ComputationNodePtr MatrixL1Reg(const ComputationNodePtr a, const std::wstring nodeName = L""); ComputationNodePtr MatrixL2Reg(const ComputationNodePtr a, const std::wstring nodeName = L""); ComputationNodePtr Mean(const ComputationNodePtr a, const std::wstring nodeName = L""); - ComputationNodePtr InvStdDev(const ComputationNodePtr a, const std::wstring nodeName = L""); + ComputationNodePtr Minus(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); ComputationNodePtr Negate(const ComputationNodePtr a, const std::wstring nodeName = L""); + ComputationNodePtr NoiseContrastiveEstimation(const ComputationNodePtr label, const ComputationNodePtr prediction, const ComputationNodePtr input_weight, const ComputationNodePtr input_bias, const std::wstring nodeName = L"", NCEEvalMode mode = NCEEvalMode::None); + ComputationNodePtr PastValue(const ComputationNodePtr a, const float initHiddenActivity, const size_t row_size, size_t timeStep, const std::wstring nodeName = L""); + ComputationNodePtr PerDimMeanVarDeNormalization(const ComputationNodePtr feature, const ComputationNodePtr mean, const ComputationNodePtr InvStdDev, const std::wstring nodeName = L""); + ComputationNodePtr PerDimMeanVarNormalization(const ComputationNodePtr feature, const ComputationNodePtr mean, const ComputationNodePtr InvStdDev, const std::wstring nodeName = L""); + ComputationNodePtr Plus(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); ComputationNodePtr RectifiedLinear(const ComputationNodePtr a, const std::wstring nodeName = L""); + ComputationNodePtr Reshape(const ComputationNodePtr a, const TensorShape& imageLayout, const std::wstring nodeName = L""); + ComputationNodePtr RowRepeat(const ComputationNodePtr a, const size_t num_repeat, const std::wstring nodeName = L""); + ComputationNodePtr RowSlice(const ComputationNodePtr a, const size_t start_index, const size_t num_rows, const std::wstring nodeName = L""); + ComputationNodePtr RowStack(const std::vector pinputs, const std::wstring nodeName = L""); + ComputationNodePtr SequenceDecoder(const ComputationNodePtr label, const ComputationNodePtr prediction, const ComputationNodePtr pairscore, const std::wstring nodeName = L""); + ComputationNodePtr SequenceWithSoftmax(const ComputationNodePtr label, const ComputationNodePtr prediction, const ComputationNodePtr loglikelihood, const std::wstring nodeName = L""); ComputationNodePtr Sigmoid(const ComputationNodePtr a, const std::wstring nodeName = L""); - ComputationNodePtr Tanh(const ComputationNodePtr a, const std::wstring nodeName = L""); - ComputationNodePtr Exp(const ComputationNodePtr a, const std::wstring nodeName = L""); - ComputationNodePtr Log(const ComputationNodePtr a, const std::wstring nodeName = L""); - ComputationNodePtr Cos(const ComputationNodePtr a, const std::wstring nodeName = L""); ComputationNodePtr Softmax(const ComputationNodePtr a, const std::wstring nodeName = L""); - ComputationNodePtr Hardmax(const ComputationNodePtr a, 
const std::wstring nodeName = L""); - ComputationNodePtr LogSoftmax(const ComputationNodePtr a, const std::wstring nodeName = L""); + ComputationNodePtr SquareError(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); ComputationNodePtr Sum(const ComputationNodePtr a, const std::wstring nodeName = L""); - ComputationNodePtr Transpose(const ComputationNodePtr matrix, const std::wstring nodeName = L""); + ComputationNodePtr Tanh(const ComputationNodePtr a, const std::wstring nodeName = L""); ComputationNodePtr Times(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); + ComputationNodePtr Transpose(const ComputationNodePtr matrix, const std::wstring nodeName = L""); ComputationNodePtr TransposeTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); - ComputationNodePtr ElementTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); - ComputationNodePtr StrideTimes(const ComputationNodePtr a, const ComputationNodePtr b, const ComputationNodePtr c, const std::wstring nodeName = L""); - ComputationNodePtr DiagTimes(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); - ComputationNodePtr CosDistance(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); - ComputationNodePtr KhatriRaoProduct(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); - ComputationNodePtr Plus(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); - ComputationNodePtr Minus(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); - ComputationNodePtr Dropout(const ComputationNodePtr a, const std::wstring nodeName = L""); - ComputationNodePtr Reshape(const ComputationNodePtr a, const TensorShape& imageLayout, const std::wstring nodeName = L""); #if 1 // legacy ComputationNodePtr DeprecatedReshape(const ComputationNodePtr a, const size_t num_rows, const TensorShape& imageLayout, const std::wstring nodeName = L""); #endif - ComputationNodePtr RowRepeat(const ComputationNodePtr a, const size_t num_repeat, const std::wstring nodeName = L""); - ComputationNodePtr Diagonal(const ComputationNodePtr a, const std::wstring nodeName = L""); - ComputationNodePtr PastValue(const ComputationNodePtr a, const float initHiddenActivity, const size_t row_size, size_t timeStep, const std::wstring nodeName = L""); - ComputationNodePtr FutureValue(const ComputationNodePtr a, const float initHiddenActivity, const size_t row_size, size_t timeStep, const std::wstring nodeName = L""); - ComputationNodePtr Parallel(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L""); - ComputationNodePtr RowSlice(const ComputationNodePtr a, const size_t start_index, const size_t num_rows, const std::wstring nodeName = L""); - ComputationNodePtr RowStack(const std::vector pinputs, const std::wstring nodeName = L""); - ComputationNodePtr GMMLogLikelihood(const ComputationNodePtr unnormedPrior, const ComputationNodePtr mean, const ComputationNodePtr logStddev, const ComputationNodePtr feature, const std::wstring nodeName = L""); - ComputationNodePtr TimeReverse(const ComputationNodePtr input, const std::wstring nodeName = L""); - ComputationNodePtr LookupTable(const ComputationNodePtr dictionary, const ComputationNodePtr input, const std::wstring nodeName = L""); - ComputationNodePtr 
BatchNormalization(const ComputationNodePtr input, const ComputationNodePtr scale, const ComputationNodePtr bias, - const ComputationNodePtr runMean, const ComputationNodePtr runInvStdDev, bool eval = false, bool spatial = false, double expAvgFactor = 1, ImageLayoutKind imageLayoutKind = ImageLayoutKind::CHW, const std::wstring nodeName = L""); }; // create a new from config diff --git a/Source/ComputationNetworkLib/SpecialPurposeNodes.h b/Source/ComputationNetworkLib/SpecialPurposeNodes.h index 16ff486653b4..9d8ef0f403e8 100644 --- a/Source/ComputationNetworkLib/SpecialPurposeNodes.h +++ b/Source/ComputationNetworkLib/SpecialPurposeNodes.h @@ -668,3205 +668,293 @@ class SequenceWithSoftmaxNode : public ComputationNodeNonLooping, publ template class SequenceWithSoftmaxNode; template class SequenceWithSoftmaxNode; -#if 0 //def ENABLE_TENSORVIEW -// ----------------------------------------------------------------------- -// PlusNode (summand1, summand2) -// ----------------------------------------------------------------------- - -template -class PlusNode : public BinaryElementWiseNode -{ - typedef BinaryElementWiseNode Base; - UsingBinaryElementwiseNodeBaseMembers; - static const std::wstring TypeName() - { - return L"Plus"; - } - -public: - DeclareConstructorFromConfigWithNumInputs(PlusNode); - PlusNode(DEVICEID_TYPE deviceId, const wstring& name) - : Base(deviceId, name) - { - } - - virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override - { - Matrix gradientValues = GradientFor(fr); - Matrix functionValues = ValueFor(fr); - Matrix inputGradientValues = Input(inputIndex)->GradientFor(fr.AllowBroadcast()); - -#if DUMPOUTPUT - functionValues.Print("PlusNode"); -#endif - size_t rowsc = Input(inputIndex)->GetNumRows(), colsc = Input(inputIndex)->GetNumColsFor(fr.AllowBroadcast()); - size_t rowsp = this->GetNumRows(), colsp = this->GetNumColsFor(fr); -#if DUMPOUTPUT - fprintf(stderr, "input dimensions %lld x %lld, this node dimensions %lld x %lld\n", rowsc, colsc, rowsp, colsp); - gradientValues.Print("Gradient-in"); - inputGradientValues.Print("child Gradient-in/out"); -#endif - - if (colsc == colsp && rowsc == rowsp) // matching dimensions --this may also trigger for column vector added to a frame, if fr denotes a single frame - { - // BUGBUG: if we reduce from a frame of a MB into a one-column vector, then we must also mask gaps - inputGradientValues += gradientValues; - } - else if (colsc == 1 && rowsc == 1) // child is a scalar - { - MaskMissingGradientColumnsToZero(fr); // reducing over frames, so we must zero out the gaps - inputGradientValues += gradientValues.SumOfElements(); - } - else if (colsc == 1 && colsp != 1) // child is a broadcasting column vector - { - MaskMissingGradientColumnsToZero(fr); // reducing over frames, so we must zero out the gaps - // Special case for convolution node bias. See comment in EvaluateThisNode for more details. - // BUGBUG: This is not composable. For example, MinusNode does not allow this. 
- auto convNode = dynamic_pointer_cast>(m_inputs[0]); - if (convNode != nullptr || (convNode = dynamic_pointer_cast>(m_inputs[1])) != nullptr) - convNode->BackwardBias(gradientValues, inputGradientValues); - else - { - size_t colspExpand = rowsp * colsp / rowsc; - Matrix::MultiplyAndAdd(gradientValues.Reshaped(rowsc, colspExpand), false, ConstOnes(colspExpand, 1, functionValues.GetDeviceId()), false, inputGradientValues); - } - } - else if (rowsc == 1 && rowsp != 1) // child is a broadcasting row vector - { - Matrix::MultiplyAndAdd(ConstOnes(1, rowsp, functionValues.GetDeviceId()), false, gradientValues, false, inputGradientValues); - } - else if (colsc != 1 && colsp % colsc == 0) - { - // the children matrix is [a b] and the parent considers it as [a a a b b b] - // Note: There is no need to mask gaps here because this operation is only allowed on non-MBLayout inputs - size_t ratio = colsp / colsc; - for (size_t i = 0; i < colsc; i++) - { - size_t colspExpand = rowsp * colsp / rowsc / colsc; - Matrix tmp = gradientValues.ColumnSlice(i * ratio, ratio); - tmp.Reshape(rowsc, colspExpand); - Matrix res = inputGradientValues.ColumnSlice(i, 1); - Matrix::MultiplyAndAdd(tmp, false, ConstOnes(colspExpand, 1, functionValues.GetDeviceId()), false, res); - inputGradientValues.ColumnSlice(i, 1).SetValue(res); - } - } - else - RuntimeError("Plus partial: unexpected condition."); -#if DUMPOUTPUT - inputGradientValues.Print("child Gradient-out"); -#endif - } - - virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override - { - Matrix functionValues = ValueFor(fr); - Matrix inputFunctionValues0 = Input(0)->ValueFor(fr.AllowBroadcast()); - Matrix inputFunctionValues1 = Input(1)->ValueFor(fr.AllowBroadcast()); - // Note: If one input is a column vector (no MBLayout) and the other a sequence of frames (MBLayout), then the above will be a slice for the other only. - - size_t rows0 = inputFunctionValues0.GetNumRows(), cols0 = inputFunctionValues0.GetNumCols(); - size_t rows1 = inputFunctionValues1.GetNumRows(), cols1 = inputFunctionValues1.GetNumCols(); - - if ((rows0 == rows1 && cols0 == cols1 /*matching dimensions*/) || ((rows0 == 1 || rows1 == 1) /*one is a broadcasting row vector*/ && cols0 == cols1)) - { - functionValues.AssignSumOf(inputFunctionValues0, inputFunctionValues1); - } - else if (cols0 == 1 && rows1 % rows0 == 0 || cols1 == 1 && rows0 % rows1 == 0) // one is col vec with divisable rows, including scalar --allowing divisable rows can be useful for images - { - // REVIEW alexeyk: this hack is required to handle bias in convolution node which may - // use a format (e.g. NCHW) where bias addition cannot be represented as adding column/row vector to matrix. - // Bias does NOT have to be a vector of size equal to number of output feature map (though it's a common case). - auto convNode = dynamic_pointer_cast>(m_inputs[0]); - if (convNode != nullptr || (convNode = dynamic_pointer_cast>(m_inputs[1])) != nullptr) - { - convNode->AddBias(cols0 == 1 ? inputFunctionValues1 : inputFunctionValues0, - cols0 == 1 ? inputFunctionValues0 : inputFunctionValues1, functionValues); - } - else - { - // None of the input nodes are convolutional. 
- if (cols0 == 1) - { - functionValues.Reshape(rows0, rows1 * cols1 / rows0); - functionValues.AssignSumOf(inputFunctionValues1.Reshaped(rows0, rows1 * cols1 / rows0), inputFunctionValues0); - } - else - { - functionValues.Reshape(rows1, rows0 * cols0 / rows1); - functionValues.AssignSumOf(inputFunctionValues0.Reshaped(rows1, rows0 * cols0 / rows1), inputFunctionValues1); - } - } - functionValues.Reshape(max(rows0, rows1), max(cols0, cols1)); - } - else if (cols1 < cols0 && rows0 == rows1 && cols0 % cols1 == 0) // first summand is a matrix with number of columns that is a multiple of the column number of the second matrix - { - if (m_pMBLayout) - InvalidArgument("%ls %ls operation applied to mismatching number of columns when columns are samples of a minibatch", NodeName().c_str(), OperationName().c_str()); - // the children matrix is [a b] and the parent considers it as [a a a b b b] - // This can be useful for dealing with images. - Matrix tmpMat(inputFunctionValues1.GetDeviceId()); - size_t ratio = cols0 / cols1; - // TODO: Why is this different from MinusNode? - for (size_t i = 0; i < cols1; i++) - { - tmpMat = Matrix::RepMat(inputFunctionValues1.ColumnSlice(i, 1), 1, ratio); - functionValues.ColumnSlice(i * ratio, ratio).SetValue(tmpMat + inputFunctionValues0.ColumnSlice(i * ratio, ratio)); - } - } - else - LogicError("%ls %ls operation's Validate() function let invalid dimensions slip by.", NodeName().c_str(), OperationName().c_str()); -#if DUMPOUTPUT - functionValues.Print("PlusNode"); -#endif - } -}; - -template class PlusNode; -template class PlusNode; - -// ----------------------------------------------------------------------- -// MinusNode (minuend, subtrahend) -// ----------------------------------------------------------------------- - -template -class MinusNode : public BinaryElementWiseNode -{ - typedef BinaryElementWiseNode Base; - UsingBinaryElementwiseNodeBaseMembers; - static const std::wstring TypeName() - { - return L"Minus"; - } - -public: - DeclareConstructorFromConfigWithNumInputs(MinusNode); - MinusNode(DEVICEID_TYPE deviceId, const wstring& name) - : Base(deviceId, name) - { - } - - virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override - { - ElemType sign = inputIndex == 0 ? 
1.0f : -1.0f; - Matrix gradientValues = GradientFor(fr); - - Matrix childGradientValues = Input(inputIndex)->GradientFor(fr.AllowBroadcast()); - - size_t rowsc = Input(inputIndex)->GetNumRows(), colsc = Input(inputIndex)->GetNumColsFor(fr.AllowBroadcast()); - size_t rowsp = this->GetNumRows(), colsp = this->GetNumColsFor(fr); - - if (colsc == colsp && rowsc == rowsp) // matching dimensions - { - // BUGBUG: if we reduce from a frame of a MB into a one-column vector, then we must also mask gaps - if (sign > 0) - childGradientValues += gradientValues; - else - childGradientValues -= gradientValues; - } - else if (colsc == 1 && rowsc == 1) // child is a scalar (1 x 1) - { - MaskMissingGradientColumnsToZero(fr); // reducing over frames, so we must zero out the gaps - if (sign > 0) - childGradientValues += gradientValues.SumOfElements(); - else - childGradientValues -= gradientValues.SumOfElements(); - } - else if (colsc == 1 && colsp != 1) // child is broadcasting column vector - { - size_t colspExpand = rowsp * colsp / rowsc; - MaskMissingGradientColumnsToZero(fr); // reducing over frames, so we must zero out the gaps - Matrix::MultiplyAndWeightedAdd(sign, gradientValues.Reshaped(rowsc, colspExpand), false, ConstOnes(colspExpand, 1, Value().GetDeviceId()), false, 1, childGradientValues); - } - else if (rowsc == 1 && rowsp != 1) // child is a broadcasting row vector - { - Matrix::MultiplyAndWeightedAdd(sign, ConstOnes(1, rowsp, Value().GetDeviceId()), false, gradientValues, false, 1, childGradientValues); - } - else - LogicError("%ls %ls operation's Validate() function let invalid dimensions slip by.", NodeName().c_str(), OperationName().c_str()); - } - - virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override - { - Matrix functionValues = ValueFor(fr); - Matrix inputFunctionValues0 = Input(0)->ValueFor(fr.AllowBroadcast()); - Matrix inputFunctionValues1 = Input(1)->ValueFor(fr.AllowBroadcast()); - - size_t rows0 = inputFunctionValues0.GetNumRows(), cols0 = inputFunctionValues0.GetNumCols(); - size_t rows1 = inputFunctionValues1.GetNumRows(), cols1 = inputFunctionValues1.GetNumCols(); - functionValues.VerifySize(max(rows0, rows1), max(cols0, cols1)); - - if ((rows0 == rows1 && cols0 == cols1 /*match*/) || ((rows0 == 1 || rows1 == 1) /*one is a broadcasting row vector*/ && cols0 == cols1)) - { - functionValues.AssignDifferenceOf(inputFunctionValues0, inputFunctionValues1); - } - else if (cols0 == 1 && rows1 % rows0 == 0) // one is col vec with divisable rows, including scalar - { - functionValues.AssignDifferenceOf(inputFunctionValues0, inputFunctionValues1.Reshaped(rows0, rows1 * cols1 / rows0)); - functionValues.Reshape(max(rows0, rows1), max(cols0, cols1)); - } - else if (cols1 == 1 && rows0 % rows1 == 0) // one is col vec with divisable rows, including scalar - { - functionValues.AssignDifferenceOf(inputFunctionValues0.Reshaped(rows1, rows0 * cols0 / rows1), inputFunctionValues1); - functionValues.Reshape(max(rows0, rows1), max(cols0, cols1)); - } - else - LogicError("%ls %ls operation's Validate() function let invalid dimensions slip by.", NodeName().c_str(), OperationName().c_str()); - } -}; - -template class MinusNode; -template class MinusNode; - -// ----------------------------------------------------------------------- -// ElementTimesNode (factor1, factor2) -// -// This allows broadcasting, and can thus also scale with a row, a column, or a scalar. 
-// ----------------------------------------------------------------------- - -template -class ElementTimesNode : public BinaryElementWiseNode -{ - typedef BinaryElementWiseNode Base; - UsingBinaryElementwiseNodeBaseMembers; - static const std::wstring TypeName() - { - return L"ElementTimes"; - } - -public: - DeclareConstructorFromConfigWithNumInputs(ElementTimesNode); - ElementTimesNode(DEVICEID_TYPE deviceId, const wstring& name) - : Base(deviceId, name) - { - } - - virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override - { - Matrix sliceInput0Grad = Input(inputIndex)->GradientFor(fr); - Matrix sliceOutputGrad = GradientFor(fr); - Matrix sliceInput1Value = Input(1 - inputIndex)->ValueFor(fr); - - // depending on inputIndex, all the input variables change meaning - // inputIndex == 0 (left) - inputGradientValues[0], inputFunctionValues[1] - // inputIndex == 1 (right) - inputGradientValues[1], inputFunctionValues[0] - sliceInput0Grad.AddElementProductOf(sliceOutputGrad, sliceInput1Value); - } - - virtual bool InputUsedInComputingInputNodesGradients(size_t childIndex) const override - { - return true; - } - - virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override - { - Matrix sliceInput0Value = Input(0)->ValueFor(fr); - Matrix sliceInput1Value = Input(1)->ValueFor(fr); - Matrix sliceOutputValue = ValueFor(fr); - - //ForwardPropS(sliceOutputValue, sliceInput0Value, sliceInput1Value); - sliceOutputValue.AssignElementProductOf(sliceInput0Value, sliceInput1Value); - } -}; - -template class ElementTimesNode; -template class ElementTimesNode; - -// ----------------------------------------------------------------------- -// ScaleNode (scalar scaling factor, matrix) -// -// Identical to ElementTimesNode with tensor lib (broadcasting). Can be removed. -// ----------------------------------------------------------------------- - -template -class ScaleNode : public ComputationNode, public NumInputs<2> -{ - typedef ComputationNode Base; - UsingComputationNodeMembersBoilerplate; - static const std::wstring TypeName() - { - return L"Scale"; - } - -public: - DeclareConstructorFromConfigWithNumInputs(ScaleNode); - ScaleNode(DEVICEID_TYPE deviceId, const wstring& name) - : Base(deviceId, name) - { - } - - virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override - { -#if 1 //def ENABLE_TENSORVIEW // This takes a big perf hit since our reduction uses only a single thread in this case. Needs to be fixed. 
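- // Elementwise-product gradient with broadcasting: dL/dInput(i) += dL/dOutput .* Input(1-i).
- // When an input was broadcast (fewer columns than the output), its gradient is an implicit
- // reduction over frames, so minibatch gaps must first be masked to zero below; otherwise
- // garbage columns would be summed into the gradient.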
- size_t rank = DetermineElementwiseTensorRank(); - auto gradient = GradientTensorFor(rank, fr); - auto inputGradient = Input(inputIndex)->GradientTensorFor(rank, fr.AllowBroadcast()); - auto otherInputValue = Input(1 - inputIndex)->ValueTensorFor(rank, fr.AllowBroadcast()); - - // if reduction then mask the respective input(s) (zero out the gaps) - if (Input(inputIndex)->GetNumCols() < GetNumCols()) - MaskMissingGradientColumnsToZero(fr); - if (Input(inputIndex)->GetNumCols() < Input(1 - inputIndex)->GetNumCols()) - Input(1 - inputIndex)->MaskMissingValueColumnsToZero(fr); - - inputGradient.AddElementwiseProductOf(gradient, otherInputValue); -#else - if (inputIndex == 0) // left derivative - { - // this is a reduction over frames, so we must mask gaps to zero - Input(0)->Gradient() += Matrix::InnerProductOfMatrices(MaskedGradientFor(fr), Input(1)->MaskedValueFor(fr)); // element-wise product summed up over all - } - else if (inputIndex == 1) // right derivative - { - Matrix sliceInput1Grad = Input(1)->GradientFor(fr); - Matrix::Multiply1x1AndWeightedAdd(+1.0f, Input(0)->Value() /*1x1*/, GradientFor(fr), 1.0f, sliceInput1Grad); - } -#endif - } - - virtual bool OutputUsedInComputingInputNodesGradients() const override - { - // The ScaleNode does not require its output value for computing - // the gradients of its input nodes - return false; - } - - virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override - { -#if 1 //def ENABLE_TENSORVIEW - static int c = 0; - if (c++ == 0) - { - fprintf(stderr, "#SCALE#\n"); - } - size_t rank = DetermineElementwiseTensorRank(); - auto result = ValueTensorFor(rank, fr); - auto input0 = Input(0)->ValueTensorFor(rank, fr.AllowBroadcast()); - auto input1 = Input(1)->ValueTensorFor(rank, fr.AllowBroadcast()); - result.AssignElementwiseProductOf(input0, input1); -#else - ValueFor(fr).Assign1x1ProductOf(Input(0)->Value() /*1x1*/, Input(1)->ValueFor(fr)); -#endif - } - - virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override - { - Base::Validate(isFinalValidationPass); - InferMBLayoutFromInputsForStandardCase(); - - // left node must be a scalar - if (isFinalValidationPass && (Input(0)->GetNumRows() != 1 || Input(0)->GetNumCols() != 1)) - RuntimeError("The left value of ScaleNode must be a scalar value."); - - SetDims(Input(1)); - } -}; - -template class ScaleNode; -template class ScaleNode; - -// ----------------------------------------------------------------------- -// RowElementTimesNode (left, right) --TODO: what are left and right? -// -// TODO: This is subsumed by ElementTimes with tensor lib. 
-// ----------------------------------------------------------------------- - -template -class RowElementTimesNode : public ComputationNode, public NumInputs<2> -{ - typedef ComputationNode Base; - UsingComputationNodeMembersBoilerplate; - static const std::wstring TypeName() - { - return L"RowElementTimes"; - } - -public: - DeclareConstructorFromConfigWithNumInputs(RowElementTimesNode); - RowElementTimesNode(DEVICEID_TYPE deviceId, const wstring& name) - : Base(deviceId, name) - { - } - - void BackpropToMap(const size_t inputIndex) - { - if (inputIndex > 1) - InvalidArgument("RowElementTimes operation only takes two inputs."); - - if (inputIndex == 0) - { - BackpropToLeftS(Input(1)->Value(), Input(0)->Gradient(), Gradient(), *m_tempMatrix); - } - else - { - BackpropToRightS(Input(0)->Value(), Input(1)->Gradient(), Gradient(), *m_tempMatrix); - } - } - - virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override - { - if (fr.IsAllFrames()) - { - BackpropToMap(inputIndex); - return; - } // TODO: remove these one by one - Matrix sliceInput0Grad = Input(inputIndex)->GradientFor(fr); - Matrix sliceOutputGrad = GradientFor(fr); - - Matrix sliceInput1Value = Input(1 - inputIndex)->ValueFor(fr); - - if (inputIndex == 0) - { - BackpropToLeftS(sliceInput1Value, sliceInput0Grad, sliceOutputGrad, *m_tempMatrix); - } - else - { - BackpropToRightS(sliceInput1Value, sliceInput0Grad, sliceOutputGrad, *m_tempMatrix); - } - } - - virtual bool OutputUsedInComputingInputNodesGradients() const override - { - // The RowElementTimesNode does not require its output value for computing - // the gradients of its input nodes - return false; - } - - //left (input 0) is a matrix - /*TODO: merge with call site*/ void BackpropToLeftS(Matrix& input1FunctionValues, - Matrix& input0GradientValues, - const Matrix& gradientValues, - Matrix& tempMatrix) - { - tempMatrix.SetValue(gradientValues); - tempMatrix.RowElementMultiplyWith(input1FunctionValues); - input0GradientValues += tempMatrix; - -#if NANCHECK - input0GradientValues.HasNan("RowElementTimes"); -#endif - } - - //right (input 1) is a row vector - /*TODO: merge with call site*/ void BackpropToRightS(Matrix& input0FunctionValues, - Matrix& input1GradientValues, - const Matrix& gradientValues, - Matrix& tempMatrix) - { - tempMatrix.AssignInnerProductOf(gradientValues, input0FunctionValues, true); - input1GradientValues += tempMatrix; - -#if NANCHECK - input1GradientValues.HasNan("RowElementTimes"); -#endif - } - void ForwardPropMap() // TODO: This is a stop-gap; in most cases, we should just be able to delete this (but need to review one by one) - { - ForwardPropS(Value(), Input(0)->Value(), Input(1)->Value()); - } - - virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override - { - //if (fr.IsAllFrames()) { ForwardPropMap(); return; } - Matrix sliceInput0Value = Input(0)->ValueFor(fr); - Matrix sliceInput1Value = Input(1)->ValueFor(fr); - Matrix sliceOutputValue = ValueFor(fr); - - ForwardPropS(sliceOutputValue, sliceInput0Value, sliceInput1Value); - } - - /*TODO: merge with call site*/ void ForwardPropS(Matrix& functionValues, const Matrix& input0, const Matrix& input1) - { - functionValues.SetValue(input0); - functionValues.RowElementMultiplyWith(input1); - -#if NANCHECK - functionValues.HasNan("RowElementTimes"); -#endif - } - - virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override - { - Base::Validate(isFinalValidationPass); - InferMBLayoutFromInputsForStandardCase(); - - 
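- // Shape contract assumed by the checks below: input 0 is a full matrix [N x T] and
- // input 1 is a row vector [1 x T] holding one per-column scaling factor.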
- size_t rows0 = Input(0)->GetNumRows(), cols0 = Input(0)->GetNumCols();
- size_t rows1 = Input(1)->GetNumRows(), cols1 = Input(1)->GetNumCols();
- rows0; // referenced to avoid an unused-variable warning
- if (isFinalValidationPass && (cols0 != cols1 || rows1 != 1))
- LogicError("RowElementTimes: Either the second operand is not a row vector or the number of columns of operands does not match.");
-
- SetDims(Input(0));
- }
-
- //request matrices that are needed for gradient computation
- virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool)
- {
- Base::RequestMatricesBeforeBackprop(matrixPool);
- RequestMatrixFromPool(m_tempMatrix, matrixPool);
- }
-
- //release gradient and temp matrices that are no longer needed after all the children's gradients are computed.
- virtual void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool)
- {
- Base::ReleaseMatricesAfterBackprop(matrixPool);
- ReleaseMatrixToPool(m_tempMatrix, matrixPool);
- }
-
-private:
- shared_ptr> m_tempMatrix;
-};
-
-template class RowElementTimesNode;
-template class RowElementTimesNode;
-
-// -----------------------------------------------------------------------
-// ColumnElementTimesNode (left, right) --TODO: what are left and right?
-//
-// TODO: This is subsumed by ElementTimes with tensor lib.
-// -----------------------------------------------------------------------
-
-template
-class ColumnElementTimesNode : public ComputationNode, public NumInputs<2>
-{
- typedef ComputationNode Base;
- UsingComputationNodeMembersBoilerplate;
- static const std::wstring TypeName()
- {
- return L"ColumnElementTimes";
- }
-
-public:
- DeclareConstructorFromConfigWithNumInputs(ColumnElementTimesNode);
- ColumnElementTimesNode(DEVICEID_TYPE deviceId, const wstring& name)
- : Base(deviceId, name)
- {
- }
-
- void BackpropToMap(const size_t inputIndex)
- {
- if (inputIndex > 1)
- InvalidArgument("ColumnElementTimes operation only takes two inputs.");
-
- if (inputIndex == 0)
- {
- BackpropToLeftS(Input(1)->Value(), Input(0)->Gradient(), Gradient(), *m_tempMatrix);
- }
- else
- {
- BackpropToRightS(Input(0)->Value(), Input(1)->Gradient(), Gradient(), *m_tempMatrix);
- }
- }
-
- virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override
- {
- if (fr.IsAllFrames())
- {
- BackpropToMap(inputIndex);
- return;
- } // TODO: remove these one by one
- Matrix sliceOutputGrad = GradientFor(fr);
-
- if (inputIndex == 0)
- {
- Matrix sliceInput0Grad = Input(0)->GradientFor(fr);
-
- BackpropToLeftS(Input(1)->Value(), sliceInput0Grad, sliceOutputGrad, *m_tempMatrix);
- }
- else
- {
- Matrix sliceInput0Value = Input(0)->ValueFor(fr);
- BackpropToRightS(sliceInput0Value, Input(1)->Gradient(), sliceOutputGrad, *m_tempMatrix);
- }
- }
-
- virtual bool OutputUsedInComputingInputNodesGradients() const override
- {
- // The ColumnElementTimesNode does not require its output value for computing
- // the gradients of its input nodes
- return false;
- }
-
- //left (input 0) is a matrix
- /*TODO: merge with call site*/ void BackpropToLeftS(Matrix& input1FunctionValues,
- Matrix& input0GradientValues,
- const Matrix& gradientValues,
- Matrix& tempMatrix)
- {
- tempMatrix.SetValue(gradientValues);
- tempMatrix.ColumnElementMultiplyWith(input1FunctionValues);
- input0GradientValues += tempMatrix;
-
-#if NANCHECK
- input0GradientValues.HasNan("ColumnElementTimes");
-#endif
- }
-
- //right (input 1) is a col vector
- /*TODO: merge with call site*/ void BackpropToRightS(Matrix& input0FunctionValues,
- Matrix& input1GradientValues,
- const Matrix& gradientValues,
- 
Matrix& tempMatrix) - { - tempMatrix.AssignInnerProductOf(gradientValues, input0FunctionValues, false); - input1GradientValues += tempMatrix; - -#if NANCHECK - input1GradientValues.HasNan("ColumnElementTimes"); -#endif - } - void ForwardPropMap() // TODO: This is a stop-gap; in most cases, we should just be able to delete this (but need to review one by one) - { - ForwardPropS(Value(), Input(0)->Value(), Input(1)->Value()); - } - - virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override - { - //if (fr.IsAllFrames()) { ForwardPropMap(); return; } - Matrix sliceInput0Value = Input(0)->ValueFor(fr); - Matrix sliceOutputValue = ValueFor(fr); - - ForwardPropS(sliceOutputValue, sliceInput0Value, Input(1)->Value()); - } - - /*TODO: merge with call site*/ void ForwardPropS(Matrix& functionValues, const Matrix& input0, const Matrix& input1) - { - functionValues.SetValue(input0); - functionValues.ColumnElementMultiplyWith(input1); - -#if NANCHECK - functionValues.HasNan("ColumnElementTimes"); -#endif - } - - virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override - { - Base::Validate(isFinalValidationPass); - InferMBLayoutFromInputsForStandardCase(); - - //derive number of rows if possible - for (size_t index = 0; index < 2; index++) - { - size_t rows = Input(index)->GetNumRows() == 0 ? Input(1 - index)->GetNumRows() : Input(index)->GetNumRows(); - size_t cols = Input(index)->GetNumCols() == 0 ? Input(1 - index)->GetNumCols() : Input(index)->GetNumCols(); - ValidateInferInputDimsFrom(index, rows, cols); - } - - size_t rows0 = Input(0)->GetNumRows(), cols0 = Input(0)->GetNumCols(); - size_t rows1 = Input(1)->GetNumRows(), cols1 = Input(1)->GetNumCols(); - cols0; - if (isFinalValidationPass && (rows0 != rows1 || cols1 != 1)) - LogicError("ColumnElementTimes: Either the second operand is not a column vector or the number of rows of operands does not match."); - - SetDims(Input(0)); - } - - //request matrices that are needed for gradient computation - virtual void RequestMatricesBeforeBackprop(MatrixPool& matrixPool) - { - Base::RequestMatricesBeforeBackprop(matrixPool); - RequestMatrixFromPool(m_tempMatrix, matrixPool); - } - - //release gradient and temp matrices that no longer needed after all the children's gradients are computed. 
- virtual void ReleaseMatricesAfterBackprop(MatrixPool& matrixPool)
- {
- Base::ReleaseMatricesAfterBackprop(matrixPool);
- ReleaseMatrixToPool(m_tempMatrix, matrixPool);
- }
-
-private:
- shared_ptr> m_tempMatrix;
-};
-
-template class ColumnElementTimesNode;
-template class ColumnElementTimesNode;
-
-// -----------------------------------------------------------------------
-// RectifiedLinearNode (input) -- ReLU non-linearity
-// -----------------------------------------------------------------------
-
-template
-class RectifiedLinearNode : public SoftmaxNodeBase
-{
- typedef SoftmaxNodeBase Base;
- UsingSoftmaxNodeBaseMembers;
- static const std::wstring TypeName()
- {
- return L"RectifiedLinear";
- }
-
-public:
- DeclareConstructorFromConfigWithNumInputs(RectifiedLinearNode);
- RectifiedLinearNode(DEVICEID_TYPE deviceId, const wstring& name)
- : Base(deviceId, name)
- {
- }
-
- void BackpropToV(Matrix& gradient, const Matrix& inputFunctionValues, Matrix& inputGradientValues, const Matrix& gradientValues, const Matrix& functionValues) override
- {
- gradient.AssignLinearRectifierDerivativeOf(inputFunctionValues);
-#if DUMPOUTPUT
- inputGradientValues.Print("RectifiedLinearNode-Partial-in");
-#endif
- inputGradientValues.AddElementProductOf(gradientValues, gradient);
-#if DUMPOUTPUT
- inputGradientValues.Print("RectifiedLinearNode-Partial-out");
-#endif
- }
-
- virtual bool OutputUsedInComputingInputNodesGradients() const override
- {
- // The ReLU node does not require its output value for computing
- // the gradients of its input nodes
- return false;
- }
-
- void ForwardPropV(Matrix& functionValues, const Matrix& inputFunctionValues) override
- {
- functionValues.AssignTruncateBottomOf(inputFunctionValues, 0);
-#if DUMPOUTPUT
- functionValues.Print("RectifiedLinearNode");
-#endif
- }
-};
-
-template class RectifiedLinearNode;
-template class RectifiedLinearNode;
-
-// -----------------------------------------------------------------------
-// SigmoidNode (input) -- sigmoid non-linearity
-// -----------------------------------------------------------------------
-
-template
-class SigmoidNode : public SoftmaxNodeBase
-{
- typedef SoftmaxNodeBase Base;
- UsingSoftmaxNodeBaseMembers;
- static const std::wstring TypeName()
- {
- return L"Sigmoid";
- }
-
-public:
- DeclareConstructorFromConfigWithNumInputs(SigmoidNode);
- SigmoidNode(DEVICEID_TYPE deviceId, const wstring& name)
- : Base(deviceId, name)
- {
- }
-
- virtual bool InputUsedInComputingInputNodesGradients(size_t childIndex) const override
- {
- // The Sigmoid node does not require any of its inputs' values for computing
- // the gradients of its input nodes (it uses its output value instead)
- UNREFERENCED_PARAMETER(childIndex);
- return false;
- }
-
- /*virtual*/ void BackpropToV(Matrix& gradient, const Matrix& inputFunctionValues, Matrix& inputGradientValues, const Matrix& gradientValues, const Matrix& functionValues)
- {
- gradient.AssignSigmoidDerivativeOf(functionValues);
- inputGradientValues.AddElementProductOf(gradientValues, gradient);
- }
-
- /*virtual*/ void ForwardPropV(Matrix& functionValues, const Matrix& inputFunctionValues) override
- {
- functionValues.AssignSigmoidOf(inputFunctionValues);
- }
-};
-
-template class SigmoidNode;
-template class SigmoidNode;
-
-// -----------------------------------------------------------------------
-// TanhNode (input) -- tanh non-linearity
-// -----------------------------------------------------------------------
-
-template
-class TanhNode : public SoftmaxNodeBase
-{
- typedef SoftmaxNodeBase Base;
- UsingSoftmaxNodeBaseMembers;
- static const std::wstring TypeName()
- {
- return L"Tanh";
- }
-
-public:
- DeclareConstructorFromConfigWithNumInputs(TanhNode);
- TanhNode(DEVICEID_TYPE deviceId, const wstring& name)
- : Base(deviceId, name)
- {
- }
-
- virtual bool InputUsedInComputingInputNodesGradients(size_t childIndex) const override
- {
- // The TanhNode does not require any of its inputs' values for computing
- // the gradients of its input nodes (it uses its output value instead)
- UNREFERENCED_PARAMETER(childIndex);
- return false;
- }
-
- /*virtual*/ void BackpropToV(Matrix& gradient, const Matrix& inputFunctionValues, Matrix& inputGradientValues, const Matrix& gradientValues, const Matrix& functionValues)
- {
- gradient.AssignElementProductOf(functionValues, functionValues); // v .* v
- gradient.AssignDifferenceOf(1, gradient); // 1 - v.^2
-
- inputGradientValues.AddElementProductOf(gradientValues, gradient); // += d .* (1 - v .* v)
- }
-
- /*virtual*/ void ForwardPropV(Matrix& functionValues, const Matrix& inputFunctionValues) override
- {
- functionValues.AssignTanhOf(inputFunctionValues);
- }
-};
-
-template class TanhNode;
-template class TanhNode;
-
-// -----------------------------------------------------------------------
-// LogNode (input) -- component-wise log() of input
-// -----------------------------------------------------------------------
-
-template
-class LogNode : public SoftmaxNodeBase
-{
- typedef SoftmaxNodeBase Base;
- UsingSoftmaxNodeBaseMembers;
- static const std::wstring TypeName()
- {
- return L"Log";
- }
-
-public:
- DeclareConstructorFromConfigWithNumInputs(LogNode);
- LogNode(DEVICEID_TYPE deviceId, const wstring& name)
- : Base(deviceId, name)
- {
- }
-
- virtual bool OutputUsedInComputingInputNodesGradients() const override
- {
- // The LogNode does not require its output value for computing
- // the gradients of its input nodes
- return false;
- }
-
- /*virtual*/ void BackpropToV(Matrix& gradient, const Matrix& inputFunctionValues, Matrix& inputGradientValues, const Matrix& gradientValues, const Matrix& functionValues)
- {
- gradient.AssignElementInverseOf(inputFunctionValues); // 1/x (x is input to log(x))
- inputGradientValues.AddElementProductOf(gradientValues, gradient);
- // TODO: with tensor lib:
- //inputGradientValues.AddElementDivisionOf(gradientValues, inputFunctionValues); // 1/x (x is input to log(x))
- }
-
- /*virtual*/ void ForwardPropV(Matrix& functionValues, const Matrix& inputFunctionValues) override
- {
- functionValues.AssignLogOf(inputFunctionValues);
- }
-};
-
-template class LogNode;
-template class LogNode;
-
-// -----------------------------------------------------------------------
-// ExpNode (input) -- component-wise exp() of input
-// -----------------------------------------------------------------------
-
-template
-class ExpNode : public SoftmaxNodeBase
-{
- typedef SoftmaxNodeBase Base;
- UsingSoftmaxNodeBaseMembers;
- static const std::wstring TypeName()
- {
- return L"Exp";
- }
-
-public:
- DeclareConstructorFromConfigWithNumInputs(ExpNode);
- ExpNode(DEVICEID_TYPE deviceId, const wstring& name)
- : Base(deviceId, name)
- {
- }
-
- virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override
- {
- assert(inputIndex == 0);
- inputIndex; // referenced to avoid an unused-variable warning
-
- Matrix sliceInputGrad = Input(0)->GradientFor(fr);
- Matrix sliceOutputGrad = GradientFor(fr);
- Matrix sliceInputValue = Input(0)->ValueFor(fr);
-
- m_gradientTemp->AssignExpOf(sliceInputValue); // Exp(x) is its own partial
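- // Since d/dx exp(x) = exp(x), the input gradient accumulates the output gradient
- // times the freshly recomputed exp of the input; the node's own output value is
- // not reused here (see OutputUsedInComputingInputNodesGradients below).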
- sliceInputGrad.AddElementProductOf(sliceOutputGrad, *m_gradientTemp);
- // TODO: with tensor lib:
- // sliceInputGrad.AddElementProductOf(sliceOutputGrad, functionValues);
- // and set OutputUsed
- }
-
- virtual bool OutputUsedInComputingInputNodesGradients() const override
- {
- // The ExpNode does not require its output value for computing
- // the gradients of its input nodes
- return false;
- }
-
- virtual void BackpropToV(Matrix& gradient, const Matrix& inputFunctionValues, Matrix& inputGradientValues, const Matrix& gradientValues, const Matrix& functionValues) override
- {
- NOT_IMPLEMENTED;
- } // not needed
-
- void ForwardPropV(Matrix& functionValues, const Matrix& inputFunctionValues) override
- {
- functionValues.AssignExpOf(inputFunctionValues);
- }
-};
-
-template class ExpNode;
-template class ExpNode;
-
-// -----------------------------------------------------------------------
-// CosineNode (input) -- component-wise cos() of input
-// -----------------------------------------------------------------------
-
-template
-class CosineNode : public SoftmaxNodeBase
-{
- typedef SoftmaxNodeBase Base;
- UsingSoftmaxNodeBaseMembers;
- static const std::wstring TypeName()
- {
- return L"Cosine";
- }
-
-public:
- DeclareConstructorFromConfigWithNumInputs(CosineNode);
- CosineNode(DEVICEID_TYPE deviceId, const wstring& name)
- : Base(deviceId, name)
- {
- }
-
- virtual bool OutputUsedInComputingInputNodesGradients() const override
- {
- // The CosineNode does not require its output value for computing
- // the gradients of its input nodes
- return false;
- }
-
- /*virtual*/ void BackpropToV(Matrix& gradient, const Matrix& inputFunctionValues, Matrix& inputGradientValues, const Matrix& gradientValues, const Matrix& functionValues)
- {
- gradient.AssignNegativeSineOf(inputFunctionValues); // -sin(x) (x is input to Cosine(x))
- inputGradientValues.AddElementProductOf(gradientValues, gradient);
- // TODO: tensor lib: make a joint kernel, since neg sin is never used for anything else
- }
-
- /*virtual*/ void ForwardPropV(Matrix& functionValues, const Matrix& inputFunctionValues) override
- {
- functionValues.AssignCosineOf(inputFunctionValues);
- }
-};
-
-template class CosineNode;
-template class CosineNode;
-#endif
-
-// -----------------------------------------------------------------------
-// DummyCriterionNode (objectives, derivatives, prediction)
-// -----------------------------------------------------------------------
-
-// This training criterion node expects its derivatives and objectives to be
-// computed outside of this node; both are fed to the node as input features.
-// It has 3 inputs:
-// 1. feature node that feeds objectives
-// 2. feature node that feeds derivatives
-// 3. neural network output
-//
-// This node is useful in sequence training for speech recognition, so that
-// lattice computation (which may rely on other software, such as Kaldi)
-// can be separated from the neural network training.
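To make the data flow concrete before the implementation: forward passes the externally computed objective (input 0) through unchanged, and backprop adds the externally computed derivatives (input 1), scaled by the incoming 1 x 1 gradient, to the network output's gradient (input 2). The following is a minimal self-contained sketch of that contract in plain C++; it is illustrative only, not CNTK code, and the type and member names are invented for the example:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Hypothetical stand-in for the node's three inputs (objective, derivatives,
    // prediction) and its forward/backward contract.
    struct DummyCriterionSketch
    {
        double objective;                // input 0: externally computed objective (1 x 1)
        std::vector<double> derivatives; // input 1: externally computed d(objective)/d(output)

        double Forward() const { return objective; } // forward just passes the objective through

        // topGradient is the 1 x 1 gradient flowing into the criterion (typically 1.0);
        // the derivative input, scaled by it, is accumulated into the output's gradient,
        // mirroring the Multiply1x1AndWeightedAdd call in BackpropToNonLooping below.
        void Backward(double topGradient, std::vector<double>& outputGradient) const
        {
            assert(derivatives.size() == outputGradient.size());
            for (std::size_t i = 0; i < derivatives.size(); i++)
                outputGradient[i] += topGradient * derivatives[i];
        }
    };

This is also why Validate below insists that the objective and derivative inputs are InputValue nodes: both quantities are produced by the external lattice machinery, not by the network.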
-
-template
-class DummyCriterionNode : public ComputationNodeNonLooping /*ComputationNode*/, public NumInputs<3>
-{
- typedef ComputationNodeNonLooping Base;
- UsingComputationNodeMembersBoilerplate;
- static const std::wstring TypeName()
- {
- return L"DummyCriterion";
- }
-
-public:
- DeclareConstructorFromConfigWithNumInputs(DummyCriterionNode);
- DummyCriterionNode(DEVICEID_TYPE deviceId, const wstring& name)
- : Base(deviceId, name)
- {
- }
-
- virtual void BackpropToNonLooping(size_t inputIndex) override
- {
- FrameRange fr(Input(0)->GetMBLayout());
- if (inputIndex == 0)
- LogicError("DummyCriterionNode: derivatives with respect to objective features are not necessary and not implemented yet.\n");
- else if (inputIndex == 1)
- LogicError("DummyCriterionNode: derivatives with respect to derivative features are not necessary and not implemented yet.\n");
- else if (inputIndex == 2)
- {
- auto gradient = Input(2)->GradientFor(fr);
- //Matrix::ScaleAndAdd(Gradient().Get00Element(), Input(1)->ValueFor(fr), gradient);
- Matrix::Multiply1x1AndWeightedAdd(+1.0f, Gradient() /*1x1*/, Input(1)->ValueFor(fr), 1.0f, gradient);
- }
- }
-
- virtual bool OutputUsedInComputingInputNodesGradients() const override
- {
- return false;
- }
-
- virtual void /*ComputationNodeNonLooping::*/ ForwardPropNonLooping() override
- {
- Value().VerifySize(1, 1);
- Input(0)->Value().VerifySize(1, 1);
- Value().SetValue(Input(0)->Value());
-#if NANCHECK
- Value().HasNan("DummyCriterionNode");
-#endif
- }
-
- virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
- {
- Base::Validate(isFinalValidationPass);
- m_pMBLayout = nullptr; // this node does not hold mini-batch data
-
- if (Input(0)->OperationName() != L"InputValue")
- LogicError("DummyCriterionNode criterion requires the first input to be computed objectives.");
- if (Input(1)->OperationName() != L"InputValue")
- LogicError("DummyCriterionNode criterion requires the second input to be computed derivatives.");
- if (isFinalValidationPass)
- {
- if (Input(0)->GetSampleMatrixNumRows() != 1)
- LogicError("DummyCriterionNode criterion requires the first input to have dimension 1.");
- if (Input(0)->GetSampleMatrixNumRows() == 0 || Input(1)->GetSampleMatrixNumRows() == 0 || Input(2)->GetSampleMatrixNumRows() == 0)
- LogicError("DummyCriterionNode operation: one of the operands has 0 elements.");
- if (Input(1)->GetSampleMatrixNumRows() != Input(2)->GetSampleMatrixNumRows())
- LogicError("The Matrix dimension in the DummyCriterionNode operation does not match.");
- }
-
- SetDims(TensorShape(1), false);
- }
-};
-
-template class DummyCriterionNode;
-template class DummyCriterionNode;
-
-// -----------------------------------------------------------------------
-// SequenceDecoderNode (label, position_dependent_score, transition_score)
-// this node does sequence decoding only
-// it corresponds to a decoder
-// - label : output label vector of [0:T-1]
-// - position_dependent_score : score from the position-dependent node;
-// in the R-CRF case, it is the RNN output score before softmax
-// - transition_score : score from the transition node;
-// in the R-CRF case, it is the transition probability between labels
-// -----------------------------------------------------------------------
-
-template
-class SequenceDecoderNode : public ComputationNodeNonLooping /*ComputationNode*/, public NumInputs<3>
-{
- typedef ComputationNodeNonLooping Base;
- UsingComputationNodeMembersBoilerplate;
- static const std::wstring TypeName()
- {
- return L"SequenceDecoderNode";
L"SequenceDecoderNode"; - } - -private: - // TODO: member variables go to the end - Matrix mAlpha; - Matrix mBacktrace; - - int mStartLab; // the starting output label - int mEndLab; // the ending output label, if avaliable - ElemType m_default_activity; - -public: - DeclareConstructorFromConfigWithNumInputs(SequenceDecoderNode); - SequenceDecoderNode(DEVICEID_TYPE deviceId, const wstring& name) - : Base(deviceId, name), - mAlpha(deviceId), - mBacktrace(deviceId), - mStartLab(-1), - mEndLab(-1) - { - } - - static void DecideStartEndingOutputLab(const Matrix& lbls, int& stt, int& stp) - { - if (stt != -1 && stp != -1) - return; /// have computed before - - int iNumPos = lbls.GetNumCols(); - - int firstLbl = -1; - for (int ik = 0; ik < lbls.GetNumRows(); ik++) - if (lbls(ik, 0) != 0) - { - firstLbl = ik; - break; - } - - int lastLbl = -1; - for (int ik = 0; ik < lbls.GetNumRows(); ik++) - if (lbls(ik, iNumPos - 1) != 0) - { - lastLbl = ik; - break; - } - - stt = firstLbl; - stp = lastLbl; - }; - - virtual void BackpropToNonLooping(size_t /*inputIndex*/) override //scaled by 2*number of elements in the Matrix - { - LogicError("SequenceDecoder is used for evaluation only."); - } - - virtual bool OutputUsedInComputingInputNodesGradients() const override - { - return false; - } - virtual bool InputUsedInComputingInputNodesGradients(size_t /*childIndex*/) const override - { - return false; - } - - /// compute posterior probability of label y at position t - virtual void /*ComputationNodeNonLooping::*/ ForwardPropNonLooping() override - { - DecideStartEndingOutputLab(Input(0)->Value(), mStartLab, mEndLab); - ForwardPropS(mAlpha, mBacktrace, Value(), Input(1)->Value(), - Input(2)->Value(), mStartLab, mEndLab); - } - - // compute forward backward algorithm - void ForwardPropS(Matrix& alpha, Matrix& backtrace, Matrix& functionValues, const Matrix& pos_scores, const Matrix& pair_scores, const size_t stt, const size_t stp) - { - /// to-do, each slice is for one sentence - /// to-do, number of slices correspond to number of frames - /// this implementation only supports one sentence per minibatch - - /// change to other values so can support multiple sentences in each minibatch - ForwardCompute(alpha, backtrace, pos_scores, pair_scores, stt); - BackwardCompute(functionValues, backtrace, stp); - }; - - /// compute forward backward algorithm - static void ForwardCompute(Matrix& alpha, - Matrix& backtrace, - const Matrix& pos_scores, const Matrix& pair_scores, - const size_t stt) - { - /// to-do, shift more than 1 to support muliple sentences per minibatch - int iNumPos = pos_scores.GetNumCols(); - int iNumLab = pos_scores.GetNumRows(); - size_t iTmp = 0; - - /// need to have - alpha.Resize(iNumLab, iNumPos); - backtrace.Resize(iNumLab, iNumPos); - - for (int t = 0; t < iNumPos; t++) - { - for (int k = 0; k < iNumLab; k++) - { - ElemType fTmp = (ElemType) LZERO; - if (t > 1) - { - for (int j = 0; j < iNumLab; j++) - { - ElemType fAlpha = alpha(j, t - 1) + pair_scores(k, j); - if (fAlpha > fTmp) - { - fTmp = fAlpha; - iTmp = j; - } - } - fTmp += pos_scores(k, t); /// include position dependent score - } - else - { - /// with constrain that the first word is labeled as a given symbol - iTmp = stt; - fTmp = 0; - if (t == 1) - { - fTmp = alpha(iTmp, t - 1); - fTmp += pair_scores(k, iTmp); - fTmp += pos_scores(k, t); - } - else - { - fTmp = (k == stt) ? 
pos_scores(k, t) : (ElemType) LZERO; - } - } - alpha(k, t) = fTmp; - backtrace(k, t) = (ElemType) iTmp; - } - } - }; - - /// compute backward algorithm - static void BackwardCompute( - Matrix& decodedpath, - const Matrix& backtrace, const size_t stp) - { - int iNumPos = backtrace.GetNumCols(); - int iNumLab = backtrace.GetNumRows(); - - decodedpath.Resize(iNumLab, iNumPos); - decodedpath.SetValue(0); - - size_t lastlbl = stp; - decodedpath(lastlbl, iNumPos - 1) = 1; - - for (int t = iNumPos - 1; t > 0; t--) - { - lastlbl = (size_t) backtrace(lastlbl, t); - decodedpath(lastlbl, t - 1) = 1; - } - }; - - /// need to feed in pseudo label data, which tells the decoder what is the beginning - /// and ending output symbol. these symbols will constrain the search space - virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override - { - Base::Validate(isFinalValidationPass); - InferMBLayoutFromInputsForStandardCase(); - - if (isFinalValidationPass) - if (!(Input(1)->GetSampleMatrixNumRows() == Input(2)->GetSampleMatrixNumRows() && // position dependent and pair scores have same number of labels - Input(0)->GetSampleMatrixNumRows() == Input(1)->GetSampleMatrixNumRows() && - Input(0)->GetSampleMatrixNumCols() == Input(1)->GetSampleMatrixNumCols() && // position dependent and pair scores have the same observation numbers - Input(2)->GetSampleMatrixNumCols() == Input(2)->GetSampleMatrixNumRows())) - { - LogicError("The Matrix dimension in the SequenceDecoderNode operation does not match."); - } - // BUGBUG: No SetDims()? - m_sampleLayout = TensorShape(); - } -}; - -template class SequenceDecoderNode; -template class SequenceDecoderNode; - -// ----------------------------------------------------------------------- -// StrideTimesNode (left, right, stride/*0=row, 1=col*/) -// TODO: why is 'stride' an Input and not just an initialization parameter? -// ----------------------------------------------------------------------- - -/** - Has a stride in particular dimensions of left matrix when doing times operation. - Example 1: column stride s - A in d x [s x T1] - B in T1 x s - C = A x B in d x s, and each element is computed as - c_{i,k} = \sum_j a_{i,j*s+k} b_{j,k} - where s is the stride in column. - - Example 2: - A in [s x T1] x d - B in d x s - C = A x B in T1 x s, and each element is computed as - c_{i,k} = \sum_j a_{i*s+k,j} b_{j,k} - where s is the stride in rows. - - Notice that s is equal to k. - */ -template -class StrideTimesNode : public ComputationNode, public NumInputs<3> -{ - typedef ComputationNode Base; - UsingComputationNodeMembersBoilerplate; - static const std::wstring TypeName() - { - return L"StrideTimes"; - } - - size_t m_strideDim; // the dimension index on which stride works - size_t m_stride; // the stride -private: - void UpdateStride(const Matrix& input1) - { - m_stride = input1.GetNumCols(); - } - -public: - DeclareConstructorFromConfigWithNumInputs(StrideTimesNode); - StrideTimesNode(DEVICEID_TYPE deviceId, const wstring& name) - : Base(deviceId, name), - m_stride(1) - { - } - // BUGBUG: This node needs to serialize and CopyTo m_stride - - virtual void /*ComputationNode::*/ BackpropTo(const size_t inputIndex, const FrameRange& fr) override - { - if (fr.IsAllFrames()) - { - NOT_IMPLEMENTED; - return; - } // TODO: remove these one by one. And why is this not implemented? 
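- // Input layout reminder: inputs 0 and 1 are the two factors; input 2 is a 1 x 1
- // constant that selects the stride dimension (0 = row, 1 = column) and receives no
- // gradient, hence the early return for inputIndex == 2 below. For example, with
- // column stride s, each output element is c_{i,k} = sum_j a_{i,j*s+k} * b_{j,k}.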
- if (inputIndex > 2) - InvalidArgument("StrideTimes operation only takes three inputs."); - else if (inputIndex == 2) - return; // that's a constant - - Matrix sliceOutputGrad = GradientFor(fr); - - if (m_strideDim == 1) // column stride - { - if (inputIndex == 0) //left derivative - { - Matrix sliceInput1Value = Input(1)->ValueFor(fr); - - //BackpropToLeft1(sliceInput1Value, Input(0)->Gradient(), sliceOutputGrad); - - size_t r = Input(0)->GetSampleMatrixNumRows(); - size_t T1 = Input(0)->GetSampleMatrixNumCols() / GetNumParallelSequences(); // TODO: if T1 == GetNumTimeSteps() then we can simplify code below. - Matrix mTmp1(r, T1, sliceInput1Value.GetDeviceId()); - - // process sequence by sequence - for (size_t k = 0; k < GetNumParallelSequences(); k++) - { - mTmp1.SetValue(0); - auto mTmp2 = sliceInput1Value.ColumnSlice(k, 1); - auto mTmp3 = sliceOutputGrad.ColumnSlice(k, 1); - - BackpropToLeft1(mTmp2, mTmp1, mTmp3); - - for (size_t t = 0; t < T1; t++) - { - Input(0)->Gradient().ColumnSlice(t * GetNumParallelSequences() + k, 1) += mTmp1.ColumnSlice(t, 1); - } - } - } - else //right derivative - { - Matrix sliceInput1Grad = Input(1)->GradientFor(fr); - - //BackpropToRight(Input(0)->Value(), sliceInput1Grad, sliceOutputGrad); - - // process sequence by sequence - for (size_t k = 0; k < GetNumParallelSequences(); k++) - { - size_t r = Input(0)->GetSampleMatrixNumRows(); - size_t T1 = Input(0)->GetSampleMatrixNumCols() / GetNumParallelSequences(); // TODO: if T1 == GetNumTimeSteps() then we can simplify code below. - Matrix mTmp1(r, T1, sliceOutputGrad.GetDeviceId()); - for (size_t t = 0; t < T1; t++) - { - mTmp1.ColumnSlice(t, 1).SetValue(Input(0)->Value().ColumnSlice(t * GetNumParallelSequences() + k, 1)); - } - auto mTmp2 = sliceInput1Grad.ColumnSlice(k, 1); - auto mTmp3 = sliceOutputGrad.ColumnSlice(k, 1); - - BackpropToRight(mTmp1, mTmp2, mTmp3); - } - } - } - else if (m_strideDim == 0) // row stride - { - if (inputIndex == 0) //left derivative - { - Matrix sliceInput1Value = Input(1)->ValueFor(fr); - - for (size_t k = 0; k < GetNumParallelSequences(); k++) - { - size_t d = Input(1)->GetSampleMatrixNumRows(); - size_t T1 = Input(0)->GetSampleMatrixNumRows() / GetNumParallelSequences(); - Matrix mTmp1(sliceInput1Value.GetDeviceId()); - mTmp1.Resize(d, T1); - Matrix mTmp2 = sliceInput1Value.ColumnSlice(k, 1); - Matrix mTmp3 = sliceOutputGrad.ColumnSlice(k, 1); - BackpropToLeft(mTmp2, mTmp1, mTmp3); - - Matrix mTmp4(sliceInput1Value.GetDeviceId()); - for (size_t t = 0; t < T1; t++) - { - mTmp4 = mTmp1.ColumnSlice(t, 1); - mTmp4.Reshape(1, d); - Input(0)->Gradient().AddToRowSliceValuesOf(mTmp4, t * GetNumParallelSequences() + k, 1); - } - } - } - else //right derivative - { - Matrix sliceInput1Grad = Input(1)->GradientFor(fr); - - for (size_t k = 0; k < GetNumParallelSequences(); k++) - { - size_t d = Input(1)->GetSampleMatrixNumRows(); - size_t T1 = Input(0)->GetSampleMatrixNumRows() / GetNumParallelSequences(); - - Matrix mTmp0(sliceOutputGrad.GetDeviceId()); - mTmp0.Resize(1, d); - - Matrix mTmp1(sliceOutputGrad.GetDeviceId()); - mTmp1.Resize(T1, d); - for (size_t t = 0; t < T1; t++) - { - mTmp0.SetValue(0); - mTmp0.AddWithRowSliceValuesOf(Input(0)->Value(), t * GetNumParallelSequences() + k, 1); - mTmp1.AssignToRowSliceValuesOf(mTmp0, t, 1); - } - Matrix mTmp2 = sliceInput1Grad.ColumnSlice(k, 1); - Matrix mTmp3 = sliceOutputGrad.ColumnSlice(k, 1); - - BackpropToRight(mTmp1, mTmp2, mTmp3); - } - } - } - } - - // TODO: the following two functions only differ in the order of argument use 
in the final MultiplyAndAdd() --is that intended?? - static /*TODO: merge with call site*/ void BackpropToLeft1(const Matrix& inputFunctionValues, Matrix& inputGradientValues, const Matrix& gradientValues) - { -#if DUMPOUTPUT - gradientValues.Print("Gradient-in"); - inputGradientValues.Print("child Gradient-in/out"); - inputFunctionValues.Print("child Function values"); -#endif - //currently we only support one combination when the input is sparse. - if (inputFunctionValues.GetMatrixType() == SPARSE && inputGradientValues.GetMatrixType() == DENSE && gradientValues.GetMatrixType() == DENSE) - inputGradientValues.SwitchToMatrixType(SPARSE, MatrixFormat::matrixFormatSparseBlockCol, false); - - Matrix::MultiplyAndAdd(gradientValues, false, inputFunctionValues, true, inputGradientValues); -#if DUMPOUTPUT - inputGradientValues.Print("child Gradient-out"); -#endif - } - - static /*TODO: merge with call site*/ void BackpropToLeft(Matrix& inputFunctionValues, Matrix& inputGradientValues, const Matrix& gradientValues) - { -#if DUMPOUTPUT - gradientValues.Print("Gradient-in"); - inputGradientValues.Print("child Gradient-in/out"); - inputFunctionValues.Print("child Function values"); -#endif - //currently we only support one combination when the input is sparse. - if (inputFunctionValues.GetMatrixType() == SPARSE && inputGradientValues.GetMatrixType() == DENSE && gradientValues.GetMatrixType() == DENSE) - inputGradientValues.SwitchToMatrixType(SPARSE, MatrixFormat::matrixFormatSparseBlockCol, false); - - Matrix::MultiplyAndAdd(inputFunctionValues, false, gradientValues, true, inputGradientValues); - -#if DUMPOUTPUT - inputGradientValues.Print("child Gradient-out"); -#endif - } - - static /*TODO: merge with call site*/ void BackpropToRight(Matrix& inputFunctionValues, Matrix& inputGradientValues, const Matrix& gradientValues) - { -#if DUMPOUTPUT - gradientValues.Print("Gradient-in"); - inputGradientValues.Print("child Gradient-in/out"); - inputFunctionValues.Print("child Function values"); -#endif - Matrix::MultiplyAndAdd(inputFunctionValues, true, gradientValues, false, inputGradientValues); -#if DUMPOUTPUT - inputGradientValues.Print("child Gradient-out"); -#endif - } - - virtual bool OutputUsedInComputingInputNodesGradients() const override - { - return false; - } - - virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override - { - size_t rows0 = Input(0)->GetSampleMatrixNumRows(); - Matrix sliceInput1Value = Input(1)->ValueFor(fr); - UpdateStride(sliceInput1Value); - - if (m_strideDim == 0) - SetDims(TensorShape(rows0 / GetNumParallelSequences()), HasMBLayout()); - else - SetDims(Input(0)->GetSampleLayout(), HasMBLayout()); - - Matrix sliceOutputValue = ValueFor(fr); - - // (TODO: these following assignments are leftovers of refactoring and can be short-circuited) - Matrix& functionValues = sliceOutputValue; - const Matrix& input0 = Input(0)->Value(); - const Matrix& input1 = sliceInput1Value; - -/** - A in d x [s x T1] - B in T1 x s - C = A x B in d x s, and each element is computed as - c_{i,k} = \sum_j a_{i,j*s+k} b_{j,k} - C in d x s - where s is the stride in column. - - Example 2: - A in [s x T1] x d - B in d x s - C = A x B in T1 x s, and each element is computed as - c_{i,k} = \sum_j a_{i*s+k,j} b_{j,k} - where s is the stride in rows. 
- C in T1 x s - - strideDim : 0 or 1 (meaning to apply to row or column) - */ -#if DUMPOUTPUT - input0.Print("StrideTimesNode - Input0"); -#endif - assert(m_strideDim == 0 || m_strideDim == 1); - Matrix mTmp1(input0.GetDeviceId()); - Matrix mTmp2(input0.GetDeviceId()); - if (m_strideDim == 1) // 1 = col stride; the example 1 case at column - { - assert(m_stride == input1.GetNumCols()); - size_t T1 = input0.GetNumCols() / m_stride; - assert(T1 == input1.GetNumRows()); - size_t d = input0.GetNumRows(); - functionValues.Resize(d, m_stride); - for (size_t k = 0; k < m_stride; k++) - { - mTmp1.Resize(d, T1); - for (size_t j = 0; j < T1; j++) - { - mTmp1.ColumnSlice(j, 1).SetValue(input0.ColumnSlice(j * m_stride + k, 1)); - } - - mTmp2 = input1.ColumnSlice(k, 1); - functionValues.ColumnSlice(k, 1).AssignProductOf(mTmp1, false, mTmp2, false); - } - } - else if (m_strideDim == 0) // 0 = row stride; the example 2 case at row - { - assert(m_stride == input1.GetNumCols()); - size_t T1 = input0.GetNumRows() / m_stride; - size_t d = input1.GetNumRows(); - assert(d == input0.GetNumCols()); - functionValues.Resize(T1, m_stride); - mTmp1.Resize(d, T1); - for (size_t k = 0; k < m_stride; k++) - { - for (size_t j = 0; j < T1; j++) - { - mTmp1.ColumnSlice(j, 1).AssignRowSliceValuesOf(input0, k + j * m_stride, 1); - } - - mTmp2 = input1.ColumnSlice(k, 1); - functionValues.ColumnSlice(k, 1).AssignProductOf(mTmp1, true, mTmp2, false); - } - } -#if NANCHECK - functionValues.HasNan("StrideTimes"); -#endif -#if DUMPOUTPUT - functionValues.Print("StrideTimesNode"); -#endif - } - - /** - three inputs - input0: left matrix - input1: right matrix - stridedim: single element no gradient matrix, 0 row stride / 1 column stride - */ - virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override - { - Base::Validate(isFinalValidationPass); - LinkToMBLayout(Input(1)->GetMBLayout()); // retains the layout of the right input - - if (Input(2)->Value().GetNumElements() != 1) - RuntimeError("%ls %ls operation: Input(2) should be a single element matrix and have the value 0 (row) or 1 (col).", NodeName().c_str(), OperationName().c_str()); - m_strideDim = (size_t) Input(2)->Value().Get00Element(); - if (m_strideDim != 0 && m_strideDim != 1) - RuntimeError("%ls %ls operation: Input(2) should be a single element matrix and have the value 0 (row) or 1 (col).", NodeName().c_str(), OperationName().c_str()); - //if (Input(2)->m_needGradient) // disabled because this is a flag that belongs to Network. 
Node should simply not propagate anything into it
- // RuntimeError("StrideTimes: No gradient update should be on input(2).");
-
- size_t rows0 = Input(0)->GetSampleMatrixNumRows(), cols0 = Input(0)->GetSampleMatrixNumCols();
- size_t rows1 = Input(1)->GetSampleMatrixNumRows();
-
- if (m_strideDim == 0) // by row
- {
- if (isFinalValidationPass && rows1 != cols0)
- RuntimeError("The Matrix dimension in the StrideTimes operation in dim %d does not match for cols %d in A and rows %d in B.", (int) m_strideDim, (int) cols0, (int) rows1);
- size_t T1 = rows0 / m_stride;
- SetDims(TensorShape(T1), HasMBLayout());
- //after multiplication the structure is lost
- }
-
- else // by col
- {
- if (isFinalValidationPass && cols0 != rows1 * m_stride)
- RuntimeError("The Matrix dimension in the StrideTimes operation in dim %d does not match for cols %d in A and row number %d in B.", (int) m_strideDim, (int) cols0, (int) rows1);
- SetDims(TensorShape(rows0), HasMBLayout());
- //after multiplication the structure is lost
- }
- }
-};
-
-template class StrideTimesNode;
-template class StrideTimesNode;
-
-// -----------------------------------------------------------------------
-// ParallelNode (input0, input1)
-// TODO: How is this different from RowStack?
-// -----------------------------------------------------------------------
-
-/**
- parallel node to join two streams into one
-
- joins the outputs of parallel child nodes; performs no operation other than
- copying each child's output into the corresponding rows of the output
- input(0) : [nDim0 X T]
- input(1) : [nDim1 X T]
- output : [[nDim0 + nDim1] X T]
- */
-template
-class ParallelNode : public ComputationNodeNonLooping /*ComputationNode*/, public NumInputs<2>
-{
- typedef ComputationNodeNonLooping Base;
- UsingComputationNodeMembersBoilerplate;
- static const std::wstring TypeName()
- {
- return L"Parallel";
- }
-
-public:
- DeclareConstructorFromConfigWithNumInputs(ParallelNode);
- ParallelNode(DEVICEID_TYPE deviceId, const wstring& name)
- : Base(deviceId, name)
- {
- }
-
- virtual void BackpropToNonLooping(size_t inputIndex) override
- {
- if (inputIndex > 1)
- InvalidArgument("Parallel operation only takes two inputs.");
- ComputationNodePtr child = Input(inputIndex);
- size_t startidx = (inputIndex == 0) ? 0 : Input(0)->GetSampleMatrixNumRows();
- size_t nrows = child->GetSampleMatrixNumRows();
-
- // TODO: why is this needed? If it is, it should be solved more centrally.
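- // Gradient routing by row range: input 0 owns rows [0, rows0) of the stacked output
- // and input 1 owns rows [rows0, rows0 + nrows); e.g. for inputs of 3 and 5 rows, the
- // children receive rows 0..2 and 3..7 of this node's gradient, respectively. The
- // matching row slice is copied out below and accumulated in BackpropToS.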
- if (child->Gradient().GetNumRows() != child->GetSampleMatrixNumRows() || child->Gradient().GetNumCols() != GetSampleMatrixNumCols()) - { - child->Gradient().Resize(child->GetSampleMatrixNumRows(), child->GetSampleMatrixNumCols()); - child->Gradient().SetValue(0); - } - - Matrix tmpMat(m_deviceId); - tmpMat.AssignRowSliceValuesOf(Gradient(), startidx, nrows); - - BackpropToS(tmpMat, child->Gradient()); - } - - virtual bool OutputUsedInComputingInputNodesGradients() const override - { - // The ParallelNode does not require its output value for computing - // the gradients of its input nodes - return false; - } - - virtual bool InputUsedInComputingInputNodesGradients(size_t childIndex) const override - { - // The ParallelNode does not require any of it's input's values for computing - // the gradients of its input nodes - UNREFERENCED_PARAMETER(childIndex); - return false; - } - - /*TODO: merge with call site*/ void BackpropToS(Matrix& gradientValues, Matrix& inputGradientValues) - { - inputGradientValues += gradientValues; - } - - virtual void /*ComputationNodeNonLooping::*/ ForwardPropNonLooping() override - { - ForwardPropS(Value(), Input(0)->Value(), Input(1)->Value()); - } - - /*TODO: merge with call site*/ void ForwardPropS(Matrix& functionValues, Matrix& inputFunctionValues0, Matrix& inputFunctionValues1) - { - size_t rows0 = inputFunctionValues0.GetNumRows(), cols0 = inputFunctionValues0.GetNumCols(); - size_t rows1 = inputFunctionValues1.GetNumRows(), cols1 = inputFunctionValues1.GetNumCols(); - - if (cols0 != cols1) - LogicError("ParallelNode: column dimension mismatched!"); - - functionValues.Resize(rows0 + rows1, cols0); - functionValues.SetValue(0); - - functionValues.AssignToRowSliceValuesOf(inputFunctionValues0, 0, rows0); - functionValues.AssignToRowSliceValuesOf(inputFunctionValues1, rows0, rows1); - } - - /// input(0) : [nDim1 X T] - /// input(1) : [nDim2 X T] - /// output : [[nDim1 + nDim2] X T] - virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override - { - Base::Validate(isFinalValidationPass); - InferMBLayoutFromInputsForStandardCase(); - - size_t rows1 = Input(1)->GetSampleMatrixNumRows(); - - size_t rows0 = Input(0)->GetSampleMatrixNumRows(); - - size_t rows = rows0 + rows1; - - SetDims(TensorShape(rows), HasMBLayout()); - m_sampleLayout = GetInputSampleLayout(0); - // BUGBUG: Inconsistent with 'rows' - } - -public: - virtual bool UnitTest() - { - size_t nT = 3; - size_t nInput0 = 3; - size_t nInput1 = 3; - - Matrix f0(m_deviceId), func(m_deviceId), f1(m_deviceId); - - f0 = Input(0)->Value(); - f1 = Input(1)->Value(); - func = Value(); - - Input(0)->SetDims1(nInput0, nT); - Input(0)->UpdateFunctionValuesSize(); - Input(0)->Value().SetValue(0); - Input(0)->Value()(0, 0) = 1; - Input(0)->Value()(0, 1) = 2; - Input(0)->Value()(0, 2) = 3; - - Input(1)->SetDims1(nInput1, nT); - Input(1)->UpdateFunctionValuesSize(); - Input(1)->Value().SetValue(0); - Input(1)->Value()(0, 0) = 4; - Input(1)->Value()(0, 1) = 5; - Input(1)->Value()(0, 2) = 6; - SetDims1(nInput0 + nInput1, nT); - UpdateFunctionValuesSize(); - - ForwardProp(FrameRange(m_pMBLayout)); - - /// check with expected values - if (!ISCLOSE(Value()(0, 0), 1, EPSILON) || - !ISCLOSE(Value()(0, 1), 2, EPSILON) || - !ISCLOSE(Value()(0, 2), 3, EPSILON) || - !ISCLOSE(Value()(3, 0), 4, EPSILON) || - !ISCLOSE(Value()(3, 1), 5, EPSILON) || - !ISCLOSE(Value()(3, 2), 6, EPSILON)) - return false; - Value().TransferToDeviceIfNotThere(m_deviceId, true); - - Gradient().Resize(nInput0 + nInput1, nT); - 
Gradient().SetValue(0); - Input(0)->Gradient().Resize(nInput0, nT); - Input(1)->Gradient().Resize(nInput1, nT); - Input(0)->Gradient().SetValue(0); - Input(1)->Gradient().SetValue(0); - Gradient()(0, 0) = 1; - Gradient()(0, 1) = 2; - Gradient()(0, 2) = 3; - Gradient()(3, 0) = 4; - Gradient()(3, 1) = 5; - Gradient()(3, 2) = 6; - - BackpropTo(0, FrameRange(m_pMBLayout)); - BackpropTo(1, FrameRange(m_pMBLayout)); - - /// check with expected values - if (!ISCLOSE(Input(0)->Gradient()(0, 0), 1, EPSILON) || !ISCLOSE(Input(0)->Gradient()(0, 1), 2, EPSILON) || !ISCLOSE(Input(0)->Gradient()(0, 2), 3, EPSILON) || !ISCLOSE(Input(1)->Gradient()(0, 0), 4, EPSILON) || !ISCLOSE(Input(1)->Gradient()(0, 1), 5, EPSILON) || !ISCLOSE(Input(1)->Gradient()(0, 2), 6, EPSILON)) - return false; - - Input(0)->Gradient().TransferToDeviceIfNotThere(m_deviceId, true); - Input(1)->Gradient().TransferToDeviceIfNotThere(m_deviceId, true); - - return true; - } -}; - -template class ParallelNode; -template class ParallelNode; - -// ----------------------------------------------------------------------- -// LSTMNode (obs, inputGate, forgetGate, outputGate, memoryCellWgt) -// deprecated early implementation of LSTM operating on minibatches directly -// - input(0) : child with dimension [inputdim x T] -// - input(1) : input gate [outputdim x [inputdim + outputdim + 2]] bi, Wxi, Whi, Wci -// - input(2) : forget gate [outputdim x [inputdim + outputdim + 2]] for bf, Wxf, Whf, Wcf -// - input(3) : output gate [outputdim x [inputdim + outputdim + 2]] for bo, Wxo, Who, and Wco -// - input(4) : memory cell weight [outputdim x [inputdim + outputdim + 1]] for bc, Wxc, and Whc -// - output : dimension [outputdim x T] -// ----------------------------------------------------------------------- - -/** - LSTM specific node. This node uses matrix operations to have LSTM functionality. - It avoids using general recurrent loop operations in the network operations in ComputationNetwork. - - Developed by Kaisheng Yao - Used in the following works: - K. Yao, G. Zweig, "Sequence to sequence neural net models for graphone to phoneme conversion", in Interspeech 2015 - */ -template -class LSTMNode : public ComputationNodeNonLooping /*ComputationNode*/, public NumInputs<5> -{ - typedef ComputationNodeNonLooping Base; - UsingComputationNodeMembersBoilerplate; - static const std::wstring TypeName() - { - return L"LSTM"; - } - - // BUGBUG: These flags no longer exist outside. I moved this here to make it compile, but this node is no longer functional. - enum class MinibatchPackingFlags : char // (note: not using unsigned char because these go into a matrix, and we use Matrix, since we use it as a data holder) - { - None = 0, - SequenceStart = 1 << 0, // binary 0001 frame is first of an utterance - SequenceEnd = 1 << 1, // binary 0010 frame is last of an utterance - NoFeature = 1 << 2, // binary 0100 frame has no feature (e.g. a gap due to BPTT) - NoLabel = 1 << 3, // binary 1000 frame has no label - - NoInput = NoFeature | NoLabel, // Note: Once we refactorized the reader, NoInput will no longer needed. 
- SequenceStartOrNoFeature = SequenceStart | NoFeature, - SequenceEndOrNoFeature = SequenceEnd | NoFeature, - SequenceStartOrEndOrNoFeature = SequenceStart | SequenceEnd | NoFeature, - }; - -public: - DeclareConstructorFromConfigWithNumInputs(LSTMNode); - LSTMNode(DEVICEID_TYPE deviceId, const wstring& name) - : Base(deviceId, name), - m_State(deviceId), - m_PastState(deviceId), - m_PastOutput(deviceId), - m_Gi(deviceId), - m_Gf(deviceId), - m_Go(deviceId), - grdToObs(deviceId), - grdToInputGate(deviceId), - grdToForgetGate(deviceId), - grdToOutputGate(deviceId), - grdToCellWgt(deviceId), - tanhObs(deviceId), - tanhState(deviceId), - m_tempMatrix(deviceId), - mSlicePrevState(deviceId), - mSlicePrevOutput(deviceId), - grdBeforeInputGate(deviceId), - grdBeforeForget(deviceId), - grdBeforeGo(deviceId), - grdToCell(deviceId), - grdBeforeTanhInputGate(deviceId), - m_obs_error_from_future_minibatch(deviceId), - m_state_error_from_future_minibatch(deviceId), - mLastState(deviceId), - mLastOutput(deviceId), - m_inputDim(0), - m_outputDim(0), - m_use_errors_from_future_minibatch(false), - m_DefaultState((ElemType) DEFAULT_HIDDEN_ACTIVATION) - { - } - - virtual void Save(File& fstream) const override - { - Base::Save(fstream); - fstream << m_inputDim << m_outputDim; - fstream << m_DefaultState; - } - - virtual void Load(File& fstream, size_t modelVersion) override - { - Base::Load(fstream, modelVersion); - if (modelVersion == 2) - fstream >> m_inputDim >> m_outputDim; - fstream >> m_DefaultState; - } - - virtual void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override - { - Base::CopyTo(nodeP, newName, flags); - if (flags & CopyNodeFlags::copyNodeValue) - { - auto node = dynamic_pointer_cast>(nodeP); - node->m_inputDim = m_inputDim; - node->m_outputDim = m_outputDim; - - node->m_State = m_State; // hidden state activity - node->m_PastState = m_PastState; // state activity in the previous minibatch - node->m_PastOutput = m_PastOutput; // output in the previou minibatch - - node->m_Gi = m_Gi; // input gate activity - node->m_Gf = m_Gf; // forget gate activity - node->m_Go = m_Go; // output gate activity - - node->mSlicePrevOutput = mSlicePrevOutput; - node->mSlicePrevState = mSlicePrevState; - - node->m_use_errors_from_future_minibatch = m_use_errors_from_future_minibatch; - - node->m_DefaultState = m_DefaultState; - } - } - - virtual void BackpropToNonLooping(size_t inputIndex) override - { - if (inputIndex > 4) - InvalidArgument("LSTM operation only takes five inputs."); - - size_t nT = Input(0)->GetSampleMatrixNumCols(); - size_t inputDim = Input(0)->GetSampleMatrixNumRows(); - size_t outputDim = Input(1)->GetSampleMatrixNumRows(); - - if (m_GradientComputed == false) - { - if (GetSampleMatrixNumCols() != Gradient().GetNumCols() || - GetSampleMatrixNumRows() != Gradient().GetNumRows()) - { - RuntimeError("LSTMNode::GradientValue size doesn't match to the function value size"); - } - - // reset gradients - grdToObs.Resize(inputDim, nT); - grdToObs.SetValue(0); - grdToInputGate.Resize(Input(1)->GetSampleMatrixNumRows(), Input(1)->GetSampleMatrixNumCols()); - grdToInputGate.SetValue(0); - grdToForgetGate.Resize(Input(2)->GetSampleMatrixNumRows(), Input(2)->GetSampleMatrixNumCols()); - grdToForgetGate.SetValue(0); - grdToOutputGate.Resize(Input(3)->GetSampleMatrixNumRows(), Input(3)->GetSampleMatrixNumCols()); - grdToOutputGate.SetValue(0); - grdToCellWgt.Resize(Input(4)->GetSampleMatrixNumRows(), Input(4)->GetSampleMatrixNumCols()); - 
grdToCellWgt.SetValue(0); - - Matrix slicePrevOutput(m_deviceId), slicePrevState(m_deviceId); - Matrix grdToPrevOutput(m_deviceId), grdToPrevState(m_deviceId); - Matrix stateError(m_deviceId); - slicePrevState.Resize(outputDim, GetNumParallelSequences()); - slicePrevOutput.Resize(outputDim, GetNumParallelSequences()); - slicePrevOutput.SetValue(0); - - stateError.Resize(slicePrevState.GetNumRows(), slicePrevState.GetNumCols()); - - grdToPrevOutput.Resize(slicePrevOutput.GetNumRows(), slicePrevOutput.GetNumCols()); - grdToPrevState.Resize(slicePrevState.GetNumRows(), slicePrevState.GetNumCols()); - grdToPrevOutput.SetValue(0); - grdToPrevState.SetValue(0); - - for (int timeIdxInSeq = nT - GetNumParallelSequences(); timeIdxInSeq >= 0; timeIdxInSeq -= GetNumParallelSequences()) - { - FrameRange fr(m_pMBLayout, timeIdxInSeq); - Matrix sliceObs = Input(0)->ValueFor(fr); - Matrix sliceOutput = ValueFor(fr); - Matrix sliceState = DataFor(m_State, fr); - - Matrix sliceGi = DataFor(m_Gi, fr); - Matrix sliceGf = DataFor(m_Gf, fr); - Matrix sliceGo = DataFor(m_Go, fr); - - Matrix sliceTanhState = DataFor(tanhState, fr); - Matrix sliceTanhObs = DataFor(tanhObs, fr); - - Matrix error = GradientFor(fr); - - Matrix grdToObsSlice(this->m_deviceId); - -#ifdef DEBUG_DECODER - fprintf(stderr, "original output error [%ld] norm = %.8e\n", timeIdxInSeq, error.FrobeniusNorm()); -#endif - - PrepareThisErrorsBeforeBackProp(timeIdxInSeq, nT, error, stateError, grdToPrevOutput, grdToPrevState, - m_obs_error_from_future_minibatch, m_state_error_from_future_minibatch, GetNumParallelSequences(), nullptr /*&m_pMBLayout->GetM()*/ /*BUGBUG: no longer functional*/); - -#ifdef DEBUG_DECODER - fprintf(stderr, "output error [%ld] norm = %.8e\n", timeIdxInSeq, error.FrobeniusNorm()); - fprintf(stderr, "state error [%ld] norm = %.8e\n", timeIdxInSeq, stateError.FrobeniusNorm()); -#endif - - grdToPrevOutput.Resize(slicePrevOutput.GetNumRows(), slicePrevOutput.GetNumCols()); - grdToPrevState.Resize(slicePrevState.GetNumRows(), slicePrevState.GetNumCols()); - grdToPrevOutput.SetValue(0); - grdToPrevState.SetValue(0); - - PrepareHistory(timeIdxInSeq, mSlicePrevOutput, mSlicePrevState, Value(), m_State, m_PastOutput, m_PastState, GetNumParallelSequences(), m_DefaultState, nullptr /*&m_pMBLayout->GetM()*/ /*BUGBUG: no longer functional*/); - - ComputeInputGradientWrtGates( - error, - sliceObs, - grdToObsSlice, - Input(1)->Value(), - grdToInputGate, - Input(2)->Value(), - grdToForgetGate, - Input(3)->Value(), - grdToOutputGate, - Input(4)->Value(), - grdToCellWgt, - mSlicePrevOutput, - mSlicePrevState, - stateError, - sliceState, - sliceTanhState, - sliceTanhObs, - sliceGi, - sliceGf, - sliceGo, - grdToPrevOutput, - grdToPrevState, - m_tempMatrix); - DataFor(grdToObs, fr).SetValue(grdToObsSlice); - - PrepareErrors(timeIdxInSeq, grdToPrevOutput, grdToPrevState, GetNumParallelSequences(), nullptr /*&m_pMBLayout->GetM()*/ /*BUGBUG: no longer functional*/); - } -#ifdef DEBUG_DECODER - fprintf(stderr, "after error prop b_c norm = %.8e\n", Input(4)->Value().ColumnSlice(0, 1).FrobeniusNorm()); -#endif - m_obs_error_from_future_minibatch = grdToPrevOutput; - m_state_error_from_future_minibatch = grdToPrevState; - -#ifdef DEBUG_DECODER - fprintf(stderr, "pass error to encoder error = %.4e state error = %.4e\n", m_obs_error_from_future_minibatch.FrobeniusNorm(), m_state_error_from_future_minibatch.FrobeniusNorm()); -#endif - m_GradientComputed = true; - } - - if (inputIndex == 0) //derivative with regard to the observation - { - if 
(Input(inputIndex)->Gradient().HasNoElements()) - Input(inputIndex)->Gradient().SetValue(grdToObs); - else - Input(inputIndex)->Gradient() += grdToObs; - } - - if (inputIndex == 1) - { - if (Input(inputIndex)->Gradient().HasNoElements()) - Input(inputIndex)->Gradient().SetValue(grdToInputGate); - else - Input(inputIndex)->Gradient() += grdToInputGate; - } - - if (inputIndex == 2) - { - if (Input(inputIndex)->Gradient().HasNoElements()) - Input(inputIndex)->Gradient().SetValue(grdToForgetGate); - else - Input(inputIndex)->Gradient() += grdToForgetGate; - } - - if (inputIndex == 3) - { - if (Input(inputIndex)->Gradient().HasNoElements()) - Input(inputIndex)->Gradient().SetValue(grdToOutputGate); - else - Input(inputIndex)->Gradient() += grdToOutputGate; - } - - if (inputIndex == 4) - { - if (Input(inputIndex)->Gradient().HasNoElements()) - Input(inputIndex)->Gradient().SetValue(grdToCellWgt); - else - Input(inputIndex)->Gradient() += grdToCellWgt; - } -#ifdef DEBUG_DECODER - fprintf(stderr, "LSTM gradient[%d] norm = %.8e\n", inputIndex, Input(inputIndex)->Gradient().FrobeniusNorm()); -#endif - } - - static void WINAPI GradientOfTanh(const Matrix& functionValues, - const Matrix& gradientOut, - Matrix& inputGradientValues, - Matrix& extTmp) - { - Matrix mTmp(inputGradientValues.GetDeviceId()); - extTmp.AssignElementProductOf(functionValues, functionValues); // v .* v - mTmp.AssignDifferenceOf(1, extTmp); // 1-v^2 - if (inputGradientValues.GetNumRows() != functionValues.GetNumRows() || - inputGradientValues.GetNumCols() != functionValues.GetNumCols()) - LogicError("LSTMNode::GradientOfTanh : inputGradientValues need to be pre-allocated!"); - inputGradientValues.AddElementProductOf(gradientOut, mTmp); // d .* ((1-v) .* v)) - } - - static void WINAPI ComputeInputGradientWrtGates( - const Matrix& outGrd, // the error to h_t from upper layer - const Matrix& obs, - Matrix& grdToObs, - const Matrix& mInputGate, - Matrix& grdToInputGate, - const Matrix& mForgetGate, - Matrix& grdToForgetGate, - const Matrix& mOutputGate, - Matrix& grdToOutputGate, - const Matrix& mCellWgt, - Matrix& grdToCellWgt, - const Matrix& prevOutput, - const Matrix& prevState, - const Matrix& stateError, // the error propagated to cell from t+1 - const Matrix& state, - const Matrix& tanhState, - const Matrix& tanhBeforeApplyingInputGating, - const Matrix& gi, - const Matrix& gf, - const Matrix& go, - Matrix& grdToPrevOutput, - Matrix& grdToPrevState, - Matrix& tmpMat) - { - int inputDim = obs.GetNumRows(); - int outputDim = mOutputGate.GetNumRows(); - - assert(grdToPrevOutput.FrobeniusNorm() == 0); - assert(grdToPrevState.FrobeniusNorm() == 0); - assert(state.FrobeniusNorm() > 0); - Matrix Who = mOutputGate.ColumnSlice(1 + inputDim, outputDim); - Matrix Wco = mOutputGate.ColumnSlice(1 + inputDim + outputDim, 1); - Matrix Wxo = mOutputGate.ColumnSlice(1, inputDim); - Matrix grdToWho = grdToOutputGate.ColumnSlice(1 + inputDim, outputDim); - Matrix grdToWco = grdToOutputGate.ColumnSlice(1 + inputDim + outputDim, 1); - Matrix grdToWxo = grdToOutputGate.ColumnSlice(1, inputDim); - Matrix grdTobo = grdToOutputGate.ColumnSlice(0, 1); - - Matrix Whf = mForgetGate.ColumnSlice(1 + inputDim, outputDim); - Matrix Wcf = mForgetGate.ColumnSlice(1 + inputDim + outputDim, 1); - Matrix Wxf = mForgetGate.ColumnSlice(1, inputDim); - Matrix grdToWhf = grdToForgetGate.ColumnSlice(1 + inputDim, outputDim); - Matrix grdToWcf = grdToForgetGate.ColumnSlice(1 + inputDim + outputDim, 1); - Matrix grdToWxf = grdToForgetGate.ColumnSlice(1, inputDim); - 
Matrix grdTobf = grdToForgetGate.ColumnSlice(0, 1); - - Matrix Wxc = mCellWgt.ColumnSlice(1, inputDim); - Matrix Whc = mCellWgt.ColumnSlice(1 + inputDim, outputDim); - Matrix grdToWxc = grdToCellWgt.ColumnSlice(1, inputDim); - Matrix grdToWhc = grdToCellWgt.ColumnSlice(1 + inputDim, outputDim); - Matrix grdTobc = grdToCellWgt.ColumnSlice(0, 1); - - Matrix Whi = mInputGate.ColumnSlice(1 + inputDim, outputDim); - Matrix Wci = mInputGate.ColumnSlice(1 + inputDim + outputDim, 1); - Matrix Wxi = mInputGate.ColumnSlice(1, inputDim); - Matrix grdToWhi = grdToInputGate.ColumnSlice(1 + inputDim, outputDim); - Matrix grdToWci = grdToInputGate.ColumnSlice(1 + inputDim + outputDim, 1); - Matrix grdToWxi = grdToInputGate.ColumnSlice(1, inputDim); - Matrix grdTobi = grdToInputGate.ColumnSlice(0, 1); - - // error backpropagate to output gate - Matrix grdToGo(tmpMat.GetDeviceId()), gradientOfSigmoid(tmpMat.GetDeviceId()); - Matrix grdBeforeGo(tmpMat.GetDeviceId()), grdBeforeInputGate(tmpMat.GetDeviceId()); - Matrix grdToCell(tmpMat.GetDeviceId()); - - tmpMat.AssignElementProductOf(outGrd, tanhState); // error to o_t - gradientOfSigmoid.AssignSigmoidDerivativeOf(go); - grdBeforeGo.AssignElementProductOf(tmpMat, gradientOfSigmoid); // error before softmax -#ifdef DEBUG_DECODER - fprintf(stderr, "output gate error = %.4e\n", grdBeforeGo(0, 0)); -#endif - Matrix::MultiplyAndAdd(Who, true, grdBeforeGo, false, grdToPrevOutput); // error to previous output - Matrix::MultiplyAndAdd(Wxo, true, grdBeforeGo, false, grdToObs); // error to observation - tmpMat = grdBeforeGo; - tmpMat.ColumnElementMultiplyWith(Wco); - grdToCell = tmpMat; // error to memory cell - - Matrix::MultiplyAndAdd(grdBeforeGo, false, prevOutput, true, grdToWho); // gradient to Who - Matrix::MultiplyAndAdd(grdBeforeGo, false, obs, true, grdToWxo); // gradient to Wxo - tmpMat.AssignInnerProductOf(grdBeforeGo, state, false); - grdToWco += tmpMat; // to Wco - for (size_t i = 0; i < grdBeforeGo.GetNumCols(); i++) - { - grdTobo += grdBeforeGo.ColumnSlice(i, 1); // gradient to bo - } - - grdToGo.AssignElementProductOf(outGrd, go); // error to tanh - GradientOfTanh(tanhState, grdToGo, grdToCell, tmpMat); // error to memory cell - grdToCell += stateError; // add error to memory cell from t+1 -#ifdef DEBUG_DECODER - fprintf(stderr, "previous state[0] = %.4e norm = %.4e\n", prevState(0, 0), prevState.FrobeniusNorm()); - fprintf(stderr, "state error = %.4e\n", grdToCell(0, 0)); - fprintf(stderr, "state error norm = %.4e\n", grdToCell.FrobeniusNorm()); -#endif - // error backpropagate to memory cells - grdToPrevState.AssignElementProductOf(gf, grdToCell); // error to previous memory cell - // be careful, need to double check if errors are missing - - Matrix grdBeforeForget(tmpMat.GetDeviceId()); - tmpMat.AssignElementProductOf(prevState, grdToCell); // error to f_t - gradientOfSigmoid.AssignSigmoidDerivativeOf(gf); - grdBeforeForget.AssignElementProductOf(gradientOfSigmoid, tmpMat); // error before forget gate -#ifdef DEBUG_DECODER - fprintf(stderr, "forget gate error = %.4e\n", grdBeforeForget(0, 0)); -#endif - - Matrix::MultiplyAndAdd(Whf, true, grdBeforeForget, false, grdToPrevOutput); // error to previous output - tmpMat = grdBeforeForget; - tmpMat.ColumnElementMultiplyWith(Wcf); - grdToPrevState += tmpMat; // error to previous state - - Matrix::MultiplyAndAdd(Wxf, true, grdBeforeForget, false, grdToObs); // error to observation - - Matrix::MultiplyAndAdd(grdBeforeForget, false, prevOutput, true, grdToWhf); // gradient to Whf - 
tmpMat.AssignInnerProductOf(grdBeforeForget, prevState, false); - grdToWcf += tmpMat; // gradient to Wcf - - Matrix::MultiplyAndAdd(grdBeforeForget, false, obs, true, grdToWxf); // gradient to Wxf - for (size_t i = 0; i < grdBeforeForget.GetNumCols(); i++) - grdTobf += grdBeforeForget.ColumnSlice(i, 1); // gradient to bf - - // error backpropagate to input gate - tmpMat.AssignElementProductOf(tanhBeforeApplyingInputGating, grdToCell); - gradientOfSigmoid.AssignSigmoidDerivativeOf(gi); - grdBeforeInputGate.AssignElementProductOf(gradientOfSigmoid, tmpMat); // error before input gate -#ifdef DEBUG_DECODER - fprintf(stderr, "input gate error = %.4e\n", grdBeforeInputGate(0, 0)); -#endif - - Matrix::MultiplyAndAdd(Whi, true, grdBeforeInputGate, false, grdToPrevOutput); // error to previous output - tmpMat = grdBeforeInputGate; - tmpMat.ColumnElementMultiplyWith(Wci); - grdToPrevState += tmpMat; // error to previous state - -#ifdef DEBUG_DECODER - fprintf(stderr, "to previous state error = %.4e\n", grdToPrevState(0, 0)); - fprintf(stderr, "to previous state error norm = %.4e\n", grdToPrevState.FrobeniusNorm()); -#endif - Matrix::MultiplyAndAdd(Wxi, true, grdBeforeInputGate, false, grdToObs); // error to observation - - Matrix::MultiplyAndAdd(grdBeforeInputGate, false, prevOutput, true, grdToWhi); // gradient to Whi - tmpMat.AssignInnerProductOf(grdBeforeInputGate, prevState, false); - grdToWci += tmpMat; // gradient to Wci - Matrix::MultiplyAndAdd(grdBeforeInputGate, false, obs, true, grdToWxi); // gradient to Wxi - for (size_t i = 0; i < grdBeforeInputGate.GetNumCols(); i++) - grdTobi += grdBeforeInputGate.ColumnSlice(i, 1); // gradient to bi - - // error backpropagate to inputs - Matrix grdTmp2(tmpMat.GetDeviceId()); - Matrix grdBeforeTanhInputGate(tmpMat.GetDeviceId()); - grdTmp2.AssignElementProductOf(gi, grdToCell); - grdBeforeTanhInputGate.Resize(tanhBeforeApplyingInputGating.GetNumRows(), tanhBeforeApplyingInputGating.GetNumCols()); - GradientOfTanh(tanhBeforeApplyingInputGating, grdTmp2, grdBeforeTanhInputGate, tmpMat); // error to memory cell - Matrix::MultiplyAndAdd(Wxc, true, grdBeforeTanhInputGate, false, grdToObs); // error to observation -#ifdef DEBUG_DECODER - fprintf(stderr, "to observation error = %.4e\n", grdToObs(0, 0)); -#endif - - Matrix::MultiplyAndAdd(Whc, true, grdBeforeTanhInputGate, false, grdToPrevOutput); // error to previous output - Matrix::MultiplyAndAdd(grdBeforeTanhInputGate, false, obs, true, grdToWxc); // gradient to Wxc - - Matrix::MultiplyAndAdd(grdBeforeTanhInputGate, false, prevOutput, true, grdToWhc); // gradient to Whc - for (size_t i = 0; i < grdBeforeTanhInputGate.GetNumCols(); i++) - grdTobc += grdBeforeTanhInputGate.ColumnSlice(i, 1); // gradient to bc - } - - /** - get the segmentation information, SENTENECE_BEGIN, ((int) MinibatchPackingFlags::None), ((int) MinibatchPackingFlags::NoInput) - for time at t and stream of streamid - */ - int GetSegInfo(size_t t, size_t streamid) - { - if (streamid >= GetNumParallelSequences()) - LogicError("GetSegInfo: stream id %d is larger than the number of streams %d", (int) streamid, (int) GetNumParallelSequences()); - - Matrix thisCol; // BUGBUG: These flags no longer exist. This code is no longer functional. 
-        //size_t nT = Input(0)->GetSampleMatrixNumCols();
-        //if (t >= nT)
-        //    LogicError("GetSegInfo: time %d is larger than the total number of observations %d", (int) t, (int) nT);
-        //int utt_t = (int) t / GetNumParallelSequences();
-        //auto thisCol = m_pMBLayout->GetFrame(utt_t).first;
-        thisCol.Reshape(1, GetNumParallelSequences());
-        return (int) thisCol.ColumnSlice(streamid, 1).Get00Element();
-    }
-
-    /**
-    save the last hidden-layer activity and output
-    */
-    void SaveLastStateActity()
-    {
-        size_t nT = Input(0)->GetSampleMatrixNumCols();
-        size_t outputDim = Input(1)->GetSampleMatrixNumRows();
-
-        // save the hidden activities and output for the next minibatch
-        mLastOutput.Resize(outputDim, GetNumParallelSequences());
-        mLastState.Resize(outputDim, GetNumParallelSequences());
-
-        for (size_t i = 0; i < GetNumParallelSequences(); i++)
-        {
-            for (int t = nT - GetNumParallelSequences() + i; t >= 0; t -= GetNumParallelSequences())
-            {
-                if (GetSegInfo(t, i) == ((int) MinibatchPackingFlags::None))
-                {
-                    mLastOutput.ColumnSlice(i, 1).SetValue(Value().ColumnSlice(t, 1));
-                    mLastState.ColumnSlice(i, 1).SetValue(m_State.ColumnSlice(t, 1));
-                    break;
-                }
-            }
-        }
-    }
-
-    virtual void /*ComputationNodeNonLooping::*/ ForwardPropNonLooping() override
-    {
-        size_t nT = Input(0)->GetSampleMatrixNumCols();
-        size_t outputDim = Input(1)->GetSampleMatrixNumRows();
-
-        {
-            SetDims1(outputDim, nT);
-            Value().SetValue(NAN); // set to this extreme value so that, if anything goes wrong later, the problem can be easily spotted
-            m_State.Resize(outputDim, nT);
-            m_State.SetValue(NAN); // likewise
-            m_Gi.Resize(outputDim, nT);
-            m_Gi.SetValue(NAN); // likewise
-            m_Gf.Resize(outputDim, nT);
-            m_Gf.SetValue(NAN); // likewise
-            m_Go.Resize(outputDim, nT);
-            m_Go.SetValue(NAN); // likewise
-            tanhState.Resize(outputDim, nT);
-            tanhState.SetValue(NAN); // likewise
-            tanhObs.Resize(outputDim, nT);
-            tanhObs.SetValue(NAN); // likewise
- - if (m_PastState.IsEmpty() || m_PastState.GetNumCols() != GetNumParallelSequences()) - { - m_PastState.Resize(outputDim, GetNumParallelSequences()); - m_PastState.SetValue(m_DefaultState); - } - if (m_PastOutput.IsEmpty() || m_PastOutput.GetNumCols() != GetNumParallelSequences()) - { - m_PastOutput.Resize(outputDim, GetNumParallelSequences()); - } - -#ifdef DEBUG_DECODER - if (m_PastOutput.IsEmpty() == false) - fprintf(stderr, "LSTM node %ls past output norm = %.8e\n", this->NodeName().c_str(), m_PastOutput.FrobeniusNorm()); - if (m_PastState.IsEmpty() == false) - fprintf(stderr, "LSTM node %ls past state norm = %.8e\n", this->NodeName().c_str(), m_PastState.FrobeniusNorm()); -#endif - - for (size_t timeIdxInSeq = 0; timeIdxInSeq < nT; timeIdxInSeq += GetNumParallelSequences()) - { - FrameRange fr(m_pMBLayout, timeIdxInSeq); - Matrix sliceObs = Input(0)->ValueFor(fr); - Matrix sliceOutput = ValueFor(fr); - Matrix sliceState = DataFor(m_State, fr); - - Matrix sliceGi = DataFor(m_Gi, fr); - Matrix sliceGf = DataFor(m_Gf, fr); - Matrix sliceGo = DataFor(m_Go, fr); - - Matrix sliceTanhState = DataFor(tanhState, fr); - Matrix sliceTanhInput = DataFor(tanhObs, fr); - - PrepareHistory(timeIdxInSeq, mSlicePrevOutput, mSlicePrevState, Value(), m_State, m_PastOutput, m_PastState, GetNumParallelSequences(), m_DefaultState, nullptr /*&m_pMBLayout->GetM()*/ /*BUGBUG: no longer functional*/); - - ForwardPropS(Input(1)->Value(), Input(2)->Value(), Input(3)->Value(), Input(4)->Value(), - sliceObs, mSlicePrevOutput, mSlicePrevState, sliceOutput, sliceState, sliceGi, sliceGf, sliceGo, sliceTanhState, sliceTanhInput, m_tempMatrix); - } - - // save the hidden activities and output for the next minibatch - SaveLastStateActity(); - -#ifdef DEBUG_DECODER - if (mLastOutput.IsEmpty() == false) - fprintf(stderr, "LSTM node %ls last output norm = %.8e\n", this->NodeName().c_str(), mLastOutput.FrobeniusNorm()); - if (mLastState.IsEmpty() == false) - fprintf(stderr, "LSTM node %ls last state norm = %.8e\n", this->NodeName().c_str(), mLastState.FrobeniusNorm()); -#endif - -#ifdef DEBUG_DECODER - ElemType tmpnorm = Value().FrobeniusNorm(); - if (ISCLOSE(tmpnorm, 0.834251, 0.002)) - fprintf(stderr, "check!"); - fprintf(stderr, "LSTM function norm = %.8e\n", tmpnorm); - for (size_t i = 0; i < 5; i++) - fprintf(stderr, "LSTM input[%d] norm = %.8e ", i, Input(i)->Value().FrobeniusNorm()); - fprintf(stderr, "\n"); -#endif - - m_GradientComputed = false; - } - } - - /** - Prepare history for LSTMnode - - This function returns state and output from the previous time instance. For recurrent network, the initial state needs to be set in the case of sentence begining, which is carried over from sentenceBegin. In case of sentence begining, the state activity is set to an initial value. The sentenceBegin has element of ((int) MinibatchPackingFlags::SequenceStart), ((int) MinibatchPackingFlags::None) and ((int) MinibatchPackingFlags::NoInput), which are 0, 1, and -1, respectively. - To compute the initial value, we use - prevState = sentenceBegin * delayedActivation + ~sentenceBegin * initialStateValue - and ~sentenceBegin is computed as -1*(sentenceBegin - 1), assuming that sentenceBegin is either 0 or 1. For example, when sentenceBegin == 1, ~sentenceBegin == 0. 
-    The previous-time output has no initial value, so it is computed as
-        prevOutput = sentenceBegin * pastOutput
-    */
-    // prepare prevstate and prevoutput
-    static void WINAPI PrepareHistory(
-        size_t timeIdxInSeq,
-        Matrix<ElemType>& slicePrevOutput,
-        Matrix<ElemType>& slicePrevState,
-        const Matrix<ElemType>& output,
-        const Matrix<ElemType>& state,
-        const Matrix<ElemType>& pastOutput,
-        const Matrix<ElemType>& pastState,
-        size_t nsamples, const ElemType& initStateValue, const Matrix<ElemType>* sentenceBegin)
-    {
-        size_t nRow = pastOutput.GetNumRows();
-        size_t nStream = sentenceBegin->GetNumRows();
-
-        assert(nStream == nsamples);
-
-        int utt_t = (int) floor(timeIdxInSeq / nsamples);
-        if (slicePrevOutput.IsEmpty() || slicePrevOutput.GetNumRows() != nRow || slicePrevOutput.GetNumCols() != nsamples)
-            slicePrevOutput.Resize(nRow, nsamples);
-        if (slicePrevState.IsEmpty() || slicePrevState.GetNumRows() != nRow || slicePrevState.GetNumCols() != nsamples)
-            slicePrevState.Resize(nRow, nsamples);
-
-        if (sentenceBegin->GetNumRows() != nsamples)
-            LogicError("Number of rows should be the same as the number of data streams");
-
-        Matrix<ElemType> colBegin(sentenceBegin->GetDeviceId());
-        colBegin.SetValue(sentenceBegin->ColumnSlice(utt_t, 1));
-        Matrix<ElemType> colSeg(colBegin.GetDeviceId());
-        colSeg.Resize(nStream, nStream);
-        // will reset to 0 if the position is a sentence beginning
-        // will keep the output if it is not a sentence beginning
-        colBegin.InplaceTruncateBottom(((int) MinibatchPackingFlags::SequenceStart));
-        colBegin.InplaceTruncateTop(((int) MinibatchPackingFlags::None));
-#if 1
-        initStateValue;
-        pastState;
-        pastOutput;
-        state;
-        output;
-        LogicError("PrepareHistory: finish this");
-#else
-        // BUGBUG: we need to upcast float to double here
-        colSeg.SetDiagonalValue(colBegin);
-
-        Matrix<ElemType> newPrevOutput(colBegin.GetDeviceId());
-        Matrix<ElemType> newPrevState(colBegin.GetDeviceId());
-        if (utt_t == 0)
-        {
-            // this is the beginning of this minibatch
-            Matrix<ElemType>::Multiply(pastOutput.ColumnSlice(0, nsamples), false, colSeg, false, newPrevOutput);
-            Matrix<ElemType>::Multiply(pastState.ColumnSlice(0, nsamples), false, colSeg, false, newPrevState);
-        }
-        else
-        {
-            // this is inside the minibatch
-            FrameRange fr(timeIdxInSeq, nsamples);
-            Matrix<ElemType>::Multiply(DataFor(output, fr /*TODO: delete the next two parameters*/, fr.t() - nsamples, nsamples), false, colSeg, false, newPrevOutput);
-            Matrix<ElemType>::Multiply(DataFor(state, fr /*TODO: delete the next two parameters*/, fr.t() - nsamples, nsamples), false, colSeg, false, newPrevState);
-        }
-
-        Base::SetToInitStateValueForResetSeg(sentenceBegin->ColumnSlice(utt_t, 1), nStream, initStateValue, newPrevState);
-
-        slicePrevOutput.ColumnSlice(0, nsamples).SetValue(newPrevOutput);
-        slicePrevState.ColumnSlice(0, nsamples).SetValue(newPrevState);
-#endif
-    }
-
-    // prepare the errors propagated back from the future minibatch
-    void PrepareThisErrorsBeforeBackProp(
-        size_t timeIdxInSeq,
-        size_t nT, // number of columns
-        Matrix<ElemType>& error,
-        Matrix<ElemType>& stateError,
-        const Matrix<ElemType>& grdToPrevOutput,
-        const Matrix<ElemType>& grdToPrevState,
-        const Matrix<ElemType>& obs_error_from_future_minibatch,
-        const Matrix<ElemType>& state_error_from_future_minibatch,
-        size_t nsamples, const Matrix<ElemType>* sentenceBegin)
-    {
-        int utt_t = (int) floor(timeIdxInSeq / nsamples);
-        int total_utt_t = (int) floor(nT / nsamples);
-
-        error += grdToPrevOutput;
-        stateError = grdToPrevState;
-
-        if (m_use_errors_from_future_minibatch)
-        {
-            for (size_t utt_id = 0; utt_id < nsamples; utt_id++)
-            {
-                // if this stream uses errors from the future minibatch
-                if ((GetSegInfo(timeIdxInSeq, utt_id) == ((int) MinibatchPackingFlags::None) && utt_t == total_utt_t - 1) // last time
-                    || (utt_t < total_utt_t - 1 && GetSegInfo(timeIdxInSeq, utt_id) == ((int) MinibatchPackingFlags::None) && GetSegInfo(timeIdxInSeq + nsamples, utt_id) == ((int) MinibatchPackingFlags::NoInput)) // future observation is no observation
-                    )
-                {
-                    error.ColumnSlice(utt_id, 1) += obs_error_from_future_minibatch.ColumnSlice(utt_id, 1);
-                    stateError.ColumnSlice(utt_id, 1) += state_error_from_future_minibatch.ColumnSlice(utt_id, 1);
-                }
-            }
-        }
-
-#if 1
-        sentenceBegin;
-        LogicError("PrepareThisErrorsBeforeBackProp: finish this");
-#else
-        Matrix<ElemType> colBegin(sentenceBegin->GetDeviceId());
-        colBegin.SetValue(sentenceBegin->ColumnSlice(utt_t, 1));
-        colBegin.InplaceTruncateBottom(((int) MinibatchPackingFlags::NoInput));
-        colBegin.InplaceTruncateTop(((int) MinibatchPackingFlags::SequenceStart));
-        colBegin += fabs((ElemType)((int) MinibatchPackingFlags::NoInput)); // raise this so that -1 -> 0 and therefore
-        Matrix<ElemType> colSeg(colBegin.GetDeviceId());
-        colSeg.Resize(nsamples, nsamples);
-        colSeg.SetDiagonalValue(colBegin);
-
-        // multiply the errors with the mask
-        Matrix<ElemType> newOutputError(colBegin.GetDeviceId());
-        Matrix<ElemType> newStateError(colBegin.GetDeviceId());
-
-        Matrix<ElemType>::Multiply(error, false, colSeg, false, newOutputError);
-        Matrix<ElemType>::Multiply(stateError, false, colSeg, false, newStateError);
-
-        error.ColumnSlice(0, nsamples).SetValue(newOutputError);
-        stateError.ColumnSlice(0, nsamples).SetValue(newStateError);
-#endif
-    }
-
-    // prepare the errors (mask out boundary frames)
-    static void WINAPI PrepareErrors(
-        size_t timeIdxInSeq,
-        Matrix<ElemType>& errors,
-        Matrix<ElemType>& stateError,
-        size_t nsamples, const Matrix<ElemType>* sentenceBegin)
-    {
-        int utt_t = (int) floor(timeIdxInSeq / nsamples);
-        Matrix<ElemType> colBegin(sentenceBegin->GetDeviceId());
-#if 1
-        errors;
-        stateError;
-        utt_t;
-        LogicError("PrepareErrors: finish this");
-#else
-        colBegin.SetValue(sentenceBegin->ColumnSlice(utt_t, 1));
-        // will reset to 0 if the position is a sentence beginning
-        // will keep the output if it is not a sentence beginning
-        colBegin.InplaceTruncateBottom(((int) MinibatchPackingFlags::SequenceStart));
-        colBegin.InplaceTruncateTop(((int) MinibatchPackingFlags::None));
-
-        Matrix<ElemType> colSeg(colBegin.GetDeviceId());
-        colSeg.Resize(nsamples, nsamples);
-        colSeg.SetDiagonalValue(colBegin);
-
-        // multiply the errors with the mask
-        Matrix<ElemType> newOutputError(colBegin.GetDeviceId());
-        Matrix<ElemType> newStateError(colBegin.GetDeviceId());
-
-        Matrix<ElemType>::Multiply(errors, false, colSeg, false, newOutputError);
-        Matrix<ElemType>::Multiply(stateError, false, colSeg, false, newStateError);
-
-        errors.ColumnSlice(0, nsamples).SetValue(newOutputError);
-        stateError.ColumnSlice(0, nsamples).SetValue(newStateError);
-#endif
-    }
-
-    /*TODO: merge with call site*/ void ForwardPropS(
-        const Matrix<ElemType>& mInputGate,
-        const Matrix<ElemType>& mForgetGate, const Matrix<ElemType>& mOutputGate,
-        const Matrix<ElemType>& mCellWgt,
-        const Matrix<ElemType>& obs,
-        const Matrix<ElemType>& prevOutput,
-        const Matrix<ElemType>& prevState,
-        Matrix<ElemType>& output,
-        Matrix<ElemType>& state,
-        Matrix<ElemType>& gi,
-        Matrix<ElemType>& gf,
-        Matrix<ElemType>& go,
-        Matrix<ElemType>& tanhState,
-        Matrix<ElemType>& tanhObs,
-        Matrix<ElemType>& tmp)
-    {
-        int inputDim = obs.GetNumRows();
-        int outputDim = mOutputGate.GetNumRows();
-
-        // for input gate
-        Matrix<ElemType>::Multiply(mInputGate.ColumnSlice(1, inputDim), false, obs, false, gi);
-        Matrix<ElemType>::MultiplyAndAdd(mInputGate.ColumnSlice(1 + inputDim, outputDim), false, prevOutput, false, gi);
-        gi += mInputGate.ColumnSlice(0, 1);
-        tmp = prevState;
-        tmp.ColumnElementMultiplyWith(mInputGate.ColumnSlice(1 + inputDim + outputDim, 1));
-        gi += tmp;
-        gi.AssignSigmoidOf(gi);
-
-        // for
forget gate - Matrix::Multiply(mForgetGate.ColumnSlice(1, inputDim), false, obs, false, gf); - Matrix::MultiplyAndAdd(mForgetGate.ColumnSlice(1 + inputDim, outputDim), false, prevOutput, false, gf); - gf += mForgetGate.ColumnSlice(0, 1); - tmp = prevState; - tmp.ColumnElementMultiplyWith(mForgetGate.ColumnSlice(1 + inputDim + outputDim, 1)); - gf += tmp; - gf.AssignSigmoidOf(gf); - - // for cell state - Matrix::Multiply(mCellWgt.ColumnSlice(1, inputDim), false, obs, false, state); - Matrix::MultiplyAndAdd(mCellWgt.ColumnSlice(1 + inputDim, outputDim), false, prevOutput, false, state); - state += mCellWgt.ColumnSlice(0, 1); -#ifdef DEBUG_DECODER -// fprintf(stderr, "W_xc norm = %.8e\n", mCellWgt.ColumnSlice(1, inputDim).FrobeniusNorm()); -// fprintf(stderr, "W_hc norm = %.8e\n", mCellWgt.ColumnSlice(1 + inputDim, outputDim).FrobeniusNorm()); -// fprintf(stderr, "b_c norm = %.8e\n", mCellWgt.ColumnSlice(0, 1).FrobeniusNorm()); -#endif - tanhObs.AssignTanhOf(state); - state.AssignElementProductOf(gi, tanhObs); - state.AddElementProductOf(gf, prevState); - - // for output gate - Matrix::Multiply(mOutputGate.ColumnSlice(1, inputDim), false, obs, false, go); - Matrix::MultiplyAndAdd(mOutputGate.ColumnSlice(1 + inputDim, outputDim), false, prevOutput, false, go); - go += mOutputGate.ColumnSlice(0, 1); - tmp = state; - tmp.ColumnElementMultiplyWith(mOutputGate.ColumnSlice(1 + inputDim + outputDim, 1)); - go += tmp; - go.AssignSigmoidOf(go); - - // to return output - tanhState.AssignTanhOf(state); - output.AssignElementProductOf(go, tanhState); - } - - // input(0) : child with dimension [inputdim x T] - // input(1) : input gate [outputdim x [inputdim + outputdim + 2]] bi, Wxi, Whi, Wci - // input(2) : forget gate [outputdim x [inputdim + outputdim + 2]] for bf, Wxf, Whf, Wcf - // input(3) : output gate [outputdim x [inputdim + outputdim + 2]] for bo, Wxo, Who, and Wco - // input(4) : memory cell weight [outputdim x [inputdim + outputdim + 1]] for bc, Wxc, and Whc - // output : dimension [outputdim x T] - virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override - { - Base::Validate(isFinalValidationPass); - InferMBLayoutFromInputsForStandardCase(); - - if (Input(0)->Value().GetMatrixType() == SPARSE) - LogicError("LSTMNode: input to LSTM has to be dense matrix. Consider adding a project layer using lookuptable before LSTM node. 
"); - -#if 0 - // TODO: use dynamic_pointer_cast instead - if (Input(1)->OperationName() != OperationNameOf(LearnableParameter) || - Input(2)->OperationName() != OperationNameOf(LearnableParameter) || - Input(3)->OperationName() != OperationNameOf(LearnableParameter) || - Input(4)->OperationName() != OperationNameOf(LearnableParameter)) - LogicError("LSTM validation: need to have learnable parameters "); -#endif - - //if (Input(0)->GetSampleMatrixNumRows() == 0) - // LogicError("LSTM validation: input size is zero!"); - - //if (Input(1)->GetSampleMatrixNumRows() == 0 || - // Input(2)->GetSampleMatrixNumRows() == 0 || - // Input(3)->GetSampleMatrixNumRows() == 0 || - // Input(4)->GetSampleMatrixNumRows() == 0) - // LogicError("LSTM validation : parameter size is zero!"); - - size_t nindim = Input(0)->GetSampleMatrixNumRows(); - size_t noutdim = Input(1)->GetSampleMatrixNumRows(); - //size_t nT = Input(0)->GetSampleMatrixNumCols(); - size_t nCol = nindim + noutdim + 2; - if (isFinalValidationPass) - { - if (Input(1)->GetSampleMatrixNumCols() != nCol) - { - LogicError("LSTM validation : dimension mismatched between child and inputGate"); - } - if (Input(2)->GetSampleMatrixNumCols() != nCol) - { - LogicError("LSTM validation : dimension mismatched between child and forgetGate"); - } - if (Input(3)->GetSampleMatrixNumCols() != nCol) - { - LogicError("LSTM validation : dimension mismatched between child and outputGate"); - } - - if (noutdim != Input(2)->GetSampleMatrixNumRows() || - noutdim != Input(3)->GetSampleMatrixNumRows() || - noutdim != Input(4)->GetSampleMatrixNumRows()) - { - LogicError("LSTM validation: output dimension mismatched!"); - } - } - - SetDims(TensorShape(noutdim), true); - Value().SetValue(NAN); // set to this extrem value so, if anything wrong in later procedure, problems can be easily spotted. 
- } - - bool UnitTest() - { - { - size_t nT = 3; - size_t nInput = 2; - size_t nHidden = 3; - size_t nOutput = 3; - - // backup - Matrix f0(m_deviceId), f1(m_deviceId), f2(m_deviceId), f3(m_deviceId), f4(m_deviceId), func(m_deviceId), f5(m_deviceId); - Matrix target(m_deviceId); - Matrix giWeight, ghWeight, goWeight; - ElemType initStateValue = m_DefaultState; - auto pMBLayout = make_shared(); - pMBLayout->Init(1, nT); - //Matrix & boundary = pMBLayout->m_sentenceBoundaryFlags; - //vector & minibatchPackingFlags = pMBLayout->m_minibatchPackingFlags; - //boundary.ColumnSlice(0, 1).SetValue(((int) MinibatchPackingFlags::SequenceStart)); - //minibatchPackingFlags[1] = MinibatchPackingFlags::SequenceStart; - pMBLayout->AddSequence(NEW_SEQUENCE_ID, 0, 0, nT); - Base::LinkToMBLayout(pMBLayout); - - f0 = Input(0)->Value(); - f1 = Input(1)->Value(); - f2 = Input(2)->Value(); - f3 = Input(3)->Value(); - f4 = Input(4)->Value(); - func = Value(); - - target.Resize(nOutput, nT); - for (size_t i = 0; i < nT; i++) - target(0, i) = 1; - - Input(0)->SetDims1(nInput, nT); - Input(0)->Value().SetValue(ConstOnes(nInput, nT, m_deviceId)); - Input(0)->Value().SetValue((ElemType) 0.1); - Input(1)->SetDims1(nHidden, nInput + nOutput + 2); - Input(1)->Value().SetValue((ElemType) 0.1); - Input(2)->SetDims1(nHidden, nInput + nHidden + 2); - Input(2)->Value().SetValue((ElemType) 0.1); - Input(3)->SetDims1(nOutput, nInput + nHidden + 2); - Input(3)->Value().SetValue((ElemType) 0.1); - Input(4)->SetDims1(nOutput, nHidden + nInput + 1); - Input(4)->Value().SetValue((ElemType) 0.1); - SetDims1(nOutput, nT); - - m_DefaultState = 0.0; - ForwardProp(FrameRange(m_pMBLayout)); - - // check with expected values - if (!ISCLOSE(Value()(0, 0), 0.0335975, EPSILON) || - !ISCLOSE(Value()(0, 1), 0.05485132, EPSILON) || - !ISCLOSE(Value()(0, 2), 0.06838435, EPSILON) || - !(Value()(0, 0) == Value()(1, 0))) - throw("LSTMNode forward computation error"); - - Value().TransferToDeviceIfNotThere(m_deviceId, true); - - Gradient().Resize(nOutput, nT); - Gradient().SetValue(1.0); - for (size_t i = 0; i < 5; i++) - { - Input(i)->Gradient().Resize(Input(i)->GetSampleMatrixNumRows(), Input(i)->GetSampleMatrixNumCols()); - Input(i)->Gradient().SetValue(0); - } - for (size_t i = 0; i < 5; i++) - BackpropTo(i, FrameRange(m_pMBLayout)); - - // check with expected values - if (!ISCLOSE(Input(1)->Gradient()(0, 0), 0.07843818, EPSILON) // bi - || !ISCLOSE(Input(1)->Gradient()(0, 1), 0.00784382, EPSILON) // Wxi - || !ISCLOSE(Input(1)->Gradient()(0, 3), 0.00192997, EPSILON) // Whi - || !ISCLOSE(Input(1)->Gradient()(0, 6), 0.00362767, EPSILON) // Wci - ) - throw("LSTMNode gradient error on input gates"); - if (!ISCLOSE(Input(2)->Gradient()(0, 0), 0.02738655, EPSILON) // bf - || !ISCLOSE(Input(2)->Gradient()(0, 1), 0.00273866, EPSILON) // Wxf - || !ISCLOSE(Input(2)->Gradient()(0, 3), 0.00120922, EPSILON) // Whf - || !ISCLOSE(Input(2)->Gradient()(0, 6), 0.00227184, EPSILON) // Wcf - ) - throw("LSTMNode gradient error on forget gates"); - if (!ISCLOSE(Input(3)->Gradient()(0, 0), 0.07801557, EPSILON) // bo - || !ISCLOSE(Input(3)->Gradient()(0, 1), 0.00780156, EPSILON) // Wxo - || !ISCLOSE(Input(3)->Gradient()(0, 3), 0.00268089, EPSILON) // Who - || !ISCLOSE(Input(3)->Gradient()(0, 6), 0.00809852, EPSILON) // Wco - ) - throw("LSTMNode gradient error on output gates"); - if (!ISCLOSE(Input(4)->Gradient()(0, 0), 1.3075038, EPSILON) // bc - || !ISCLOSE(Input(4)->Gradient()(0, 1), 0.13075038, EPSILON) // Wxc - || !ISCLOSE(Input(4)->Gradient()(0, 3), 0.03080355, 
EPSILON) // Whc - ) - throw("LSTMNode gradient error on memory cells"); - - for (size_t i = 0; i < 5; i++) - { - - Input(i)->Gradient().TransferToDeviceIfNotThere(m_deviceId, true); - } - m_DefaultState = initStateValue; - } - - fprintf(stderr, "LSTMNode unit test passed!\n"); - return true; - } - - virtual void DumpNodeInfo(const bool printValues, File& fstream) const override - { - Base::DumpNodeInfo(printValues, fstream); - fstream << L"Input[Width:" << m_inputDim << L"] \n"; - fstream << L"Hidden[Width:" << m_outputDim << L"] Output[Width:" << m_outputDim << L"] \n"; - } - -public: - bool GetHistory(Matrix& hist, bool bLastTime) - { - size_t tRow = m_PastOutput.GetNumRows(); - size_t tCol = m_PastOutput.GetNumCols(); - size_t rCol = m_PastState.GetNumCols(); - - DEVICEID_TYPE device = hist.GetDeviceId(); - hist.TransferFromDeviceToDevice(device, m_deviceId, true); - hist.Resize(tRow, tCol + rCol); - - if (bLastTime) - { - hist.ColumnSlice(0, tCol).SetValue(mLastOutput); - hist.ColumnSlice(tCol, rCol).SetValue(mLastState); - } - else - { - hist.ColumnSlice(0, tCol).SetValue(m_PastOutput); - hist.ColumnSlice(tCol, rCol).SetValue(m_PastState); - } - - hist.TransferFromDeviceToDevice(m_deviceId, device, true); - return true; - } - - void SetHistory(const Matrix& hist) - { - size_t tRow = hist.GetNumRows(); - size_t tCol = hist.GetNumCols(); - size_t eCols = tCol / 2; - - DEVICEID_TYPE device = hist.GetDeviceId(); - hist.TransferFromDeviceToDevice(device, m_deviceId, true); - - m_PastOutput.Resize(tRow, eCols); - m_PastState.Resize(tRow, eCols); - m_PastOutput.SetValue(hist.ColumnSlice(0, eCols)); - m_PastState.SetValue(hist.ColumnSlice(eCols, eCols)); - - hist.TransferFromDeviceToDevice(m_deviceId, device, true); - } - - virtual void GetErrorsToPreviousMinibatch(Matrix& hist) - { - size_t tRow = m_obs_error_from_future_minibatch.GetNumRows(); - size_t tCol = m_obs_error_from_future_minibatch.GetNumCols(); - size_t rCol = m_state_error_from_future_minibatch.GetNumCols(); - - DEVICEID_TYPE device = hist.GetDeviceId(); - - hist.TransferFromDeviceToDevice(device, m_deviceId, true); - hist.Resize(tRow, tCol + rCol); - - hist.ColumnSlice(0, tCol).SetValue(m_obs_error_from_future_minibatch); - hist.ColumnSlice(tCol, rCol).SetValue(m_state_error_from_future_minibatch); - - hist.TransferFromDeviceToDevice(m_deviceId, device, true); - } - - virtual void SetErrorsFromFutureMinibatch(Matrix& hist) - { - size_t tCol = hist.GetNumCols(); - size_t rCol = tCol / 2; - - DEVICEID_TYPE device = hist.GetDeviceId(); - - hist.TransferFromDeviceToDevice(device, m_deviceId, true); - - m_obs_error_from_future_minibatch.SetValue(hist.ColumnSlice(0, rCol)); - m_state_error_from_future_minibatch.SetValue(hist.ColumnSlice(rCol, rCol)); - - m_use_errors_from_future_minibatch = true; - - hist.TransferFromDeviceToDevice(m_deviceId, device, true); - } - -protected: - size_t m_inputDim; - size_t m_outputDim; - - Matrix m_State; // hidden state activity - Matrix m_PastState; // state activity in the previous minibatch - Matrix m_PastOutput; // output in the previou minibatch - - Matrix mLastState; // last state activity - Matrix mLastOutput; // last output - - Matrix m_Gi; // input gate activity - Matrix m_Gf; // forget gate activity - Matrix m_Go; // output gate activity - - Matrix grdToObs, grdToInputGate, grdToForgetGate, grdToOutputGate, grdToCellWgt; - Matrix tanhState, tanhObs; - - Matrix m_tempMatrix; // temp matrix for speed-up - - bool m_GradientComputed; // true if LSTM node has computed gradients, set to false if 
forward computation has just finished
-
-    Matrix<ElemType> mSlicePrevOutput, mSlicePrevState;
-
-    Matrix<ElemType> grdBeforeInputGate, grdBeforeForget, grdBeforeGo, grdToCell, grdBeforeTanhInputGate;
-
-public:
-    // errors from future minibatch
-    Matrix<ElemType> m_obs_error_from_future_minibatch;
-    Matrix<ElemType> m_state_error_from_future_minibatch;
-    bool m_use_errors_from_future_minibatch;
-
-    ElemType m_DefaultState;
-};
-
-template class LSTMNode<float>;
-template class LSTMNode<double>;
-
 // -----------------------------------------------------------------------
-// BatchModeNode
+// DummyCriterionNode (objectives, derivatives, prediction)
 // -----------------------------------------------------------------------
 
-/**
-    BatchModeNode is a derivative of ComputationNode.
-    It additionally checks whether its data needs to be processed in batch before its parent is processed.
-    This is used for beam-search decoding: a batch-mode node must be processed before other nodes.
-    It differs from PreComputeNode in that PreComputeNode runs its precomputation over the entire corpus,
-    whereas this runs before the forward computation of all nodes.
-*/
+// This training criterion node needs derivatives and objectives to be
+// computed outside of the node. Derivatives and objectives will be fed to the
+// node as input features. It has 3 inputs:
+// 1. feature node that feeds objectives
+// 2. feature node that feeds derivatives
+// 3. neural network output
+//
+// This node is useful in sequence training for speech recognition, so that
+// we can separate the lattice computation (which may rely on other software,
+// such as Kaldi) from the neural-network training.
+
 template <class ElemType>
-class BatchModeNode : public ComputationNodeNonLooping<ElemType> /*ComputationNode*/
+class DummyCriterionNode : public ComputationNodeNonLooping<ElemType> /*ComputationNode*/, public NumInputs<3>
 {
-    // all nodes requiring precomputation should derive from this class
     typedef ComputationNodeNonLooping<ElemType> Base;
-    UsingComputationNodeMembers;
+    UsingComputationNodeMembersBoilerplate;
+    static const std::wstring TypeName()
+    {
+        return L"DummyCriterion";
+    }
 
 public:
-    //virtual ComputationNodeBase * NewThis(DEVICEID_TYPE deviceId, const wstring & name) = 0;
-    //DeclareConstructorFromConfigWithNumInputs(BatchModeNode);
-    BatchModeNode(DEVICEID_TYPE deviceId, const wstring& name)
-        : Base(deviceId, name),
-          m_memory(deviceId)
+    DeclareConstructorFromConfigWithNumInputs(DummyCriterionNode);
+    DummyCriterionNode(DEVICEID_TYPE deviceId, const wstring& name)
+        : Base(deviceId, name)
     {
     }
 
-    virtual bool HasComputed() const = 0;
-    virtual void MarkComputed(const bool hasComputed) = 0;
+    virtual void BackpropToNonLooping(size_t inputIndex) override
+    {
+        FrameRange fr(Input(0)->GetMBLayout());
+        if (inputIndex == 0)
+            LogicError("DummyCriterionNode: derivatives with respect to objective features are not necessary, not implemented yet.\n");
+        else if (inputIndex == 1)
+            LogicError("DummyCriterionNode: derivatives with respect to derivative features are not necessary, not implemented yet.\n");
+        else if (inputIndex == 2)
+        {
+            auto gradient = Input(2)->GradientFor(fr);
+            Matrix<ElemType>::Multiply1x1AndWeightedAdd(+1.0f, Gradient() /*1x1*/, Input(1)->ValueFor(fr), 1.0f, gradient);
+        }
+    }
 
-    virtual void Save(File& fstream) const override
+    virtual bool OutputUsedInComputingInputNodesGradients() const override
     {
-        Base::Save(fstream);
-        fstream << m_hasComputed;
-        fstream << Value();
+        return false;
     }
 
-    virtual void Load(File& fstream, size_t modelVersion) override
+    virtual void /*ComputationNodeNonLooping::*/ ForwardPropNonLooping() override
     {
-        Base::Load(fstream, modelVersion);
-        fstream >> m_hasComputed;
-        LoadValue(fstream);
+        Value().VerifySize(1, 1);
+        Input(0)->Value().VerifySize(1, 1);
+        Value().SetValue(Input(0)->Value());
+#if NANCHECK
+        Value().HasNan("DummyCriterionNode");
+#endif
    }
 
-    virtual void DumpNodeInfo(const bool printValues, File& fstream) const override
+    virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
     {
-        Base::DumpNodeInfo(printValues, fstream);
+        Base::Validate(isFinalValidationPass);
+        m_pMBLayout = nullptr; // this node does not hold mini-batch data
 
-        const size_t BUFLEN = 4096;
-        WCHAR str[BUFLEN];
-        swprintf(str, BUFLEN, L"[%s%s] ", string(GetSampleLayout()).c_str(), HasMBLayout() ? " x *" : "");
-        fstream << wstring(str);
-        swprintf(str, BUFLEN, L"HasComputed=%ls", HasComputed() ? L"true" : L"false");
-        fstream << wstring(str);
+        if (Input(0)->OperationName() != L"InputValue")
+            LogicError("DummyCriterionNode criterion requires the first input to be computed objectives.");
+        if (Input(1)->OperationName() != L"InputValue")
+            LogicError("DummyCriterionNode criterion requires the second input to be computed derivatives.");
+        if (isFinalValidationPass)
+        {
+            if (Input(0)->GetSampleMatrixNumRows() != 1)
+                LogicError("DummyCriterionNode criterion requires the first input to have dimension 1.");
+            if (Input(0)->GetSampleMatrixNumRows() == 0 || Input(1)->GetSampleMatrixNumRows() == 0 || Input(2)->GetSampleMatrixNumRows() == 0)
+                LogicError("DummyCriterionNode operation: one of the operands has 0 elements.");
+            if (Input(1)->GetSampleMatrixNumRows() != Input(2)->GetSampleMatrixNumRows())
+                LogicError("The Matrix dimension in the DummyCriterionNode operation does not match.");
+        }
 
-        PrintNodeValuesToFile(printValues, fstream);
+        SetDims(TensorShape(1), false);
     }
-
-protected:
-    Matrix<ElemType> m_memory; // the memory of input or output
-    bool m_hasComputed;
 };
 
-// add this at the start of each derived class, to get access to the members of ComputationNode
-// See #define of 'UsingComputationNodeMembersBoilerplate' for more explanation.
-#define UsingBatchModeNodeMembers           \
-    UsingComputationNodeMembersBoilerplate; \
-                                            \
-protected:                                  \
-    using Base::m_memory;                   \
-    using Base::m_hasComputed;              \
-                                            \
-public:                                     \
-    using Base::HasComputed;                \
-    using Base::MarkComputed
+template class DummyCriterionNode<float>;
+template class DummyCriterionNode<double>;
 
 // -----------------------------------------------------------------------
-// TimeReverseNode (input)
-// BUGBUG: This must actually implement reversing the layout.
-// Challenge: This reverses the layout. If we time-reverse back, we'd reverse the layout again.
-//            We will get the original layout. Unfortunately, it is not the same layout pointer.
-//            To turn it back to the same layout pointer, insert a ReconcileMBLayout node.
+// SequenceDecoderNode (label, position_dependent_score, transition_score)
+// this node does sequence decoding only
+// it corresponds to a decoder
+//  - label : output label vector of [0:T-1]
+//  - position_dependent_score : score from the position-dependent node,
+//    in the R-CRF case, it is the RNN output score before softmax
+//  - transition score : score from the transition node,
+//    in the R-CRF case, it is the transition probability between labels
 // -----------------------------------------------------------------------
 
-/**
-    Developed by Kaisheng Yao.
-    This node is used in the following work:
-    K. Yao and G. Zweig, "Sequence-to-Sequence Neural Net Models for Grapheme-to-Phoneme Conversion", submitted to INTERSPEECH 2015
-*/
 template <class ElemType>
-class TimeReverseNode : public BatchModeNode<ElemType>, public NumInputs<1>
+class SequenceDecoderNode : public ComputationNodeNonLooping<ElemType> /*ComputationNode*/, public NumInputs<3>
 {
-    typedef BatchModeNode<ElemType> Base;
-    UsingBatchModeNodeMembers;
+    typedef ComputationNodeNonLooping<ElemType> Base;
+    UsingComputationNodeMembersBoilerplate;
     static const std::wstring TypeName()
     {
-        return L"TimeReverse";
+        return L"SequenceDecoderNode";
    }
 
+private:
+    // TODO: member variables go to the end
+    Matrix<ElemType> mAlpha;
+    Matrix<ElemType> mBacktrace;
+
+    int mStartLab; // the starting output label
+    int mEndLab;   // the ending output label, if available
+    ElemType m_default_activity;
+
 public:
-    DeclareConstructorFromConfigWithNumInputs(TimeReverseNode);
-    TimeReverseNode(DEVICEID_TYPE deviceId, const wstring& name)
-        : BatchModeNode<ElemType>(deviceId, name)
+    DeclareConstructorFromConfigWithNumInputs(SequenceDecoderNode);
+    SequenceDecoderNode(DEVICEID_TYPE deviceId, const wstring& name)
+        : Base(deviceId, name),
+          mAlpha(deviceId),
+          mBacktrace(deviceId),
+          mStartLab(-1),
+          mEndLab(-1)
     {
     }
 
-    virtual void CopyTo(ComputationNodeBasePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const override
+    static void DecideStartEndingOutputLab(const Matrix<ElemType>& lbls, int& stt, int& stp)
     {
-        Base::CopyTo(nodeP, newName, flags);
-        if (flags & CopyNodeFlags::copyNodeValue)
-        {
-            auto node = dynamic_pointer_cast<TimeReverseNode<ElemType>>(nodeP);
-            // TODO: m_memory is never used inside this class, just assigned. Can it not be assigned?
-            node->m_memory = m_memory;
-        }
-    }
+        if (stt != -1 && stp != -1)
+            return; /// have been computed before
 
-    virtual bool HasComputed() const
-    {
-        return m_hasComputed;
-    }
-    virtual void MarkComputed(const bool hasComputed)
-    {
-        m_hasComputed = hasComputed;
-    }
+        int iNumPos = lbls.GetNumCols();
 
-    virtual void BackpropToNonLooping(size_t inputIndex) override
-    {
-        assert(inputIndex == 0);
-        inputIndex;
-        VerifyDims(Input(0));
+        int firstLbl = -1;
+        for (int ik = 0; ik < lbls.GetNumRows(); ik++)
+            if (lbls(ik, 0) != 0)
+            {
+                firstLbl = ik;
+                break;
+            }
 
-        size_t nT = GetNumTimeSteps();
-        for (size_t t = 0; t < nT; t++)
-        {
-            Matrix<ElemType> g = GradientFor(FrameRange(GetMBLayout(), t));
-            Matrix<ElemType> ig = Input(0)->GradientFor(FrameRange(Input(0)->GetMBLayout(), nT - 1 - t));
-            ig += g;
-        }
+        int lastLbl = -1;
+        for (int ik = 0; ik < lbls.GetNumRows(); ik++)
+            if (lbls(ik, iNumPos - 1) != 0)
+            {
+                lastLbl = ik;
+                break;
+            }
+
+        stt = firstLbl;
+        stp = lastLbl;
+    };
+
+    virtual void BackpropToNonLooping(size_t /*inputIndex*/) override // scaled by 2*number of elements in the Matrix
+    {
+        LogicError("SequenceDecoder is used for evaluation only.");
     }
 
     virtual bool OutputUsedInComputingInputNodesGradients() const override
     {
-        // The TimeReverseNode does not require its output value for computing
-        // the gradients of its input nodes
         return false;
     }
-
-    virtual bool InputUsedInComputingInputNodesGradients(size_t childIndex) const override
+    virtual bool InputUsedInComputingInputNodesGradients(size_t /*childIndex*/) const override
     {
-        // The TimeReverseNode does not require any of its inputs' values for computing
-        // the gradients of its input nodes
-        UNREFERENCED_PARAMETER(childIndex);
         return false;
     }
 
+    /// compute the posterior probability of label y at position t
     virtual void /*ComputationNodeNonLooping::*/ ForwardPropNonLooping() override
     {
-        // BUGBUG: We must flip the layout, too.
-        if (GetNumParallelSequences() != 1)
-            LogicError("%ls %ls operation not implemented for multiple parallel sequences. It does not flip the layout either, i.e. it only works for a single utterance.", NodeName().c_str(), OperationName().c_str());
-        if (!m_hasComputed)
-        {
-            // this assumes this reverse node is called once, so it can set, instead of add to, the function values
-            SetDims(Input(0));
-            UpdateFunctionValuesSize();
-
-            size_t nT = GetNumTimeSteps();
-            for (size_t t = 0; t < nT; t++)
-            {
-                Matrix<ElemType> v = Input(0)->ValueFor(FrameRange(Input(0)->GetMBLayout(), t));
-                ValueFor(FrameRange(GetMBLayout(), nT - 1 - t)).SetValue(v);
-            }
-
-#if NANCHECK
-            Value().HasNan("TimeReverse");
-#endif
-#if DUMPOUTPUT
-            Value().Print("TimeReverseNode");
-#endif
-
-            m_memory.SetValue(Value());
-        }
-        // TODO: don't we need to set m_hasComputed? Or what is it for?
+        DecideStartEndingOutputLab(Input(0)->Value(), mStartLab, mEndLab);
+        ForwardPropS(mAlpha, mBacktrace, Value(), Input(1)->Value(),
+                     Input(2)->Value(), mStartLab, mEndLab);
     }
 
-    virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
+    // run the decode: a Viterbi-style forward pass followed by a backtrace
+    void ForwardPropS(Matrix<ElemType>& alpha, Matrix<ElemType>& backtrace, Matrix<ElemType>& functionValues, const Matrix<ElemType>& pos_scores, const Matrix<ElemType>& pair_scores, const size_t stt, const size_t stp)
     {
-        Base::Validate(isFinalValidationPass);
-        InferMBLayoutFromInputsForStandardCase();
-        if (isFinalValidationPass && !m_pMBLayout)
-            RuntimeError("%ls %ls operation makes no sense without an MB layout.", NodeName().c_str(), OperationName().c_str());
+        /// to-do: each slice is for one sentence
+        /// to-do: the number of slices corresponds to the number of frames
+        /// this implementation only supports one sentence per minibatch
 
-        SetDims(Input(0));
-    }
+        /// change to other values so it can support multiple sentences in each minibatch
+        ForwardCompute(alpha, backtrace, pos_scores, pair_scores, stt);
+        BackwardCompute(functionValues, backtrace, stp);
+    };
 
-public:
-    bool UnitTest()
+    /// forward (Viterbi) pass: fill the score and back-pointer tables
+    static void ForwardCompute(Matrix<ElemType>& alpha,
+                               Matrix<ElemType>& backtrace,
+                               const Matrix<ElemType>& pos_scores, const Matrix<ElemType>& pair_scores,
+                               const size_t stt)
     {
-        size_t nT = 3;
-        size_t nInput = 3;
-        size_t nOutput = nInput;
-
-        Input(0)->SetDims1(nInput, nT);
-        Input(0)->UpdateFunctionValuesSize();
-        Input(0)->Value().SetValue(0);
-        Input(0)->Value()(0, 0) = 1;
-        Input(0)->Value()(0, 1) = 2;
-        Input(0)->Value()(0, 2) = 3;
-        SetDims1(nOutput, nT);
-        UpdateFunctionValuesSize();
-        Input(0)->Value().TransferToDeviceIfNotThere(m_deviceId, true);
-        ForwardProp(FrameRange(m_pMBLayout));
-
-        /// check with expected values
-        if (!ISCLOSE(Value()(0, 0), 3, EPSILON) ||
-            !ISCLOSE(Value()(0, 1), 2, EPSILON) ||
-            !ISCLOSE(Value()(0, 2), 1, EPSILON))
+        /// to-do: shift by more than 1 to support multiple sentences per minibatch
+        int iNumPos = pos_scores.GetNumCols();
+        int iNumLab = pos_scores.GetNumRows();
+        size_t iTmp = 0;
+
+        /// need to have
+        alpha.Resize(iNumLab, iNumPos);
+        backtrace.Resize(iNumLab, iNumPos);
+
+        for (int t = 0; t < iNumPos; t++)
         {
-            return false;
+            for (int k = 0; k < iNumLab; k++)
+            {
+                ElemType fTmp = (ElemType) LZERO;
+                if (t > 1)
+                {
+                    for (int j = 0; j < iNumLab; j++)
+                    {
+                        ElemType fAlpha = alpha(j, t - 1) + pair_scores(k, j);
+                        if (fAlpha > fTmp)
+                        {
+                            fTmp = fAlpha;
+                            iTmp = j;
+                        }
+                    }
+                    fTmp += pos_scores(k, t); /// include the position-dependent score
+                }
+                else
+                {
+                    /// with the constraint that the first word is labeled with a given symbol
+                    iTmp = stt;
+                    fTmp = 0;
+                    if (t == 1)
+                    {
+                        fTmp = alpha(iTmp, t - 1);
+                        fTmp += pair_scores(k, iTmp);
+                        fTmp += pos_scores(k, t);
+                    }
+                    else
+                    {
+                        fTmp = (k == stt) ? pos_scores(k, t) : (ElemType) LZERO;
+                    }
+                }
+                alpha(k, t) = fTmp;
+                backtrace(k, t) = (ElemType) iTmp;
+            }
         }
+    };
 
-        Value().TransferToDeviceIfNotThere(m_deviceId, true);
+    /// backtrace pass: recover the one-hot decoded path
+    static void BackwardCompute(
+        Matrix<ElemType>& decodedpath,
+        const Matrix<ElemType>& backtrace, const size_t stp)
+    {
+        int iNumPos = backtrace.GetNumCols();
+        int iNumLab = backtrace.GetNumRows();
 
-        Input(0)->Gradient().Resize(nOutput, nT);
-        Input(0)->Gradient().SetValue(1.0);
-        Gradient().Resize(nOutput, nT);
-        Gradient().SetValue(0);
-        Gradient()(0, 0) = 1;
-        Gradient()(0, 1) = 2;
-        Gradient()(0, 2) = 3;
-        Gradient().TransferToDeviceIfNotThere(m_deviceId, true);
+        decodedpath.Resize(iNumLab, iNumPos);
+        decodedpath.SetValue(0);
 
-        BackpropTo(0, FrameRange(m_pMBLayout));
+        size_t lastlbl = stp;
+        decodedpath(lastlbl, iNumPos - 1) = 1;
 
-        /// check with expected values
-        if (!ISCLOSE(Input(0)->Gradient()(0, 0), 4, EPSILON) ||
-            !ISCLOSE(Input(0)->Gradient()(0, 1), 3, EPSILON) ||
-            !ISCLOSE(Input(0)->Gradient()(0, 2), 2, EPSILON))
+        for (int t = iNumPos - 1; t > 0; t--)
         {
-            return false;
+            lastlbl = (size_t) backtrace(lastlbl, t);
+            decodedpath(lastlbl, t - 1) = 1;
         }
+    };
 
-        Input(0)->Gradient().TransferToDeviceIfNotThere(m_deviceId, true);
-        Gradient().TransferToDeviceIfNotThere(m_deviceId, true);
+    /// need to feed in pseudo-label data, which tells the decoder what the beginning
+    /// and ending output symbols are. These symbols will constrain the search space.
+    virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
+    {
+        Base::Validate(isFinalValidationPass);
+        InferMBLayoutFromInputsForStandardCase();
 
-        return true;
+        if (isFinalValidationPass)
+            if (!(Input(1)->GetSampleMatrixNumRows() == Input(2)->GetSampleMatrixNumRows() && // position-dependent and pair scores have the same number of labels
+                  Input(0)->GetSampleMatrixNumRows() == Input(1)->GetSampleMatrixNumRows() &&
+                  Input(0)->GetSampleMatrixNumCols() == Input(1)->GetSampleMatrixNumCols() && // position-dependent and pair scores have the same observation numbers
+                  Input(2)->GetSampleMatrixNumCols() == Input(2)->GetSampleMatrixNumRows()))
+            {
+                LogicError("The Matrix dimension in the SequenceDecoderNode operation does not match.");
+            }
+        // BUGBUG: No SetDims()?
+        m_sampleLayout = TensorShape();
     }
 };
 
-template class TimeReverseNode<float>;
-template class TimeReverseNode<double>;
+template class SequenceDecoderNode<float>;
+template class SequenceDecoderNode<double>;
 } } }
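
Two of the pieces kept by this patch may be easier to review with the logic spelled out in isolation. First, the DummyCriterionNode gradient path: the node's own gradient is a 1x1 scalar, and BackpropToNonLooping scales the externally supplied derivative features (input 1) by that scalar and accumulates them into the gradient of the network output (input 2). A minimal sketch of that accumulation, with hypothetical names and plain vectors standing in for Matrix<ElemType>:

#include <cstddef>
#include <vector>

// grad_out += scale * deriv -- the Multiply1x1AndWeightedAdd pattern above
void AccumulateScaledDerivatives(float scale /* the node's 1x1 gradient */,
                                 const std::vector<float>& deriv /* input(1) values */,
                                 std::vector<float>& gradOut /* input(2) gradient */)
{
    for (std::size_t i = 0; i < gradOut.size(); i++)
        gradOut[i] += scale * deriv[i];
}

Second, SequenceDecoderNode is a plain Viterbi decode: ForwardCompute fills a score table alpha and a back-pointer table backtrace column by column, and BackwardCompute walks the back-pointers from the end label to emit a one-hot path. The following self-contained sketch mirrors that logic outside of CNTK; the names, the std::vector tables, and the kLZero floor are illustrative assumptions, and it simplifies the node's special-cased handling of t == 0 and t == 1 by pinning only the first frame to the start label.

#include <cstddef>
#include <vector>

using Table = std::vector<std::vector<float>>; // [label][position]

static const float kLZero = -1e30f; // stand-in for the LZERO log-domain floor

// Forward pass: alpha[k][t] = max_j (alpha[j][t-1] + pair[k][j]) + pos[k][t],
// with frame 0 pinned to the start label stt.
void ViterbiForward(const Table& pos, const Table& pair, std::size_t stt,
                    Table& alpha, std::vector<std::vector<std::size_t>>& back)
{
    const std::size_t nLab = pos.size();
    const std::size_t nPos = pos[0].size();
    alpha.assign(nLab, std::vector<float>(nPos, kLZero));
    back.assign(nLab, std::vector<std::size_t>(nPos, stt));
    for (std::size_t k = 0; k < nLab; k++)
        alpha[k][0] = (k == stt) ? pos[k][0] : kLZero; // pin the first frame
    for (std::size_t t = 1; t < nPos; t++)
        for (std::size_t k = 0; k < nLab; k++)
        {
            float best = kLZero;
            std::size_t arg = stt;
            for (std::size_t j = 0; j < nLab; j++)
            {
                const float f = alpha[j][t - 1] + pair[k][j];
                if (f > best) { best = f; arg = j; }
            }
            alpha[k][t] = best + pos[k][t]; // include the position-dependent score
            back[k][t] = arg;               // remember the best predecessor
        }
}

// Backtrace from the end label stp, mirroring BackwardCompute above.
std::vector<std::size_t> ViterbiBacktrace(
    const std::vector<std::vector<std::size_t>>& back, std::size_t stp)
{
    const std::size_t nPos = back[0].size();
    std::vector<std::size_t> path(nPos);
    path[nPos - 1] = stp;
    for (std::size_t t = nPos - 1; t > 0; t--)
        path[t - 1] = back[path[t]][t];
    return path;
}

Converting path back to the node's one-hot decodedpath matrix is then a single pass that sets decodedpath(path[t], t) = 1 for each position t.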