From 264f3c7aa4954e5ad2e1dfca80b088afbac4bbde Mon Sep 17 00:00:00 2001 From: Frank Seide Date: Wed, 14 Oct 2015 08:22:17 -0700 Subject: [PATCH] further simplified RowStackNode, e.g. we actually do not need m_startRowIndices[num children] --- .../ComputationNode.h | 4 +- .../LinearAlgebraNodes.h | 39 ++++++------------- 2 files changed, 13 insertions(+), 30 deletions(-) diff --git a/MachineLearning/CNTKComputationNetworkLib/ComputationNode.h b/MachineLearning/CNTKComputationNetworkLib/ComputationNode.h index 738d50609df4..a9c52fe5545f 100644 --- a/MachineLearning/CNTKComputationNetworkLib/ComputationNode.h +++ b/MachineLearning/CNTKComputationNetworkLib/ComputationNode.h @@ -1146,9 +1146,9 @@ namespace Microsoft { namespace MSR { namespace CNTK { inline ComputationNodePtr Inputs(const size_t childIndex) const // TODO: rename to Input { -#ifdef DEBUG // profile shows this is range check very expensive in release mode, skip it +#ifdef _DEBUG // profile shows this is range check very expensive in release mode, skip it if (childIndex >= m_children.size()) - InvalidArgument ("childIndex is out of range."); + LogicError("Inputs: childIndex is out of range."); #endif return UpCast(m_children[childIndex]); } diff --git a/MachineLearning/CNTKComputationNetworkLib/LinearAlgebraNodes.h b/MachineLearning/CNTKComputationNetworkLib/LinearAlgebraNodes.h index ac75b8387bb8..75d31570cd90 100644 --- a/MachineLearning/CNTKComputationNetworkLib/LinearAlgebraNodes.h +++ b/MachineLearning/CNTKComputationNetworkLib/LinearAlgebraNodes.h @@ -385,11 +385,8 @@ namespace Microsoft { namespace MSR { namespace CNTK { // stacks multiple inputs on top of each other // ----------------------------------------------------------------------- - //this node is used to extract part of the input by rows as the output - // TODO: Really? RowStack indicates something different. - //it has to be continuous segments of rows since each column is treated as one sample template - class RowStackNode : public ComputationNode // note: not deriving from NumInputs<> like most other nodes since this one takes a variable number of inputs + class RowStackNode : public ComputationNode // note: not deriving from NumInputs<> like most other nodes, because this one takes a variable number of inputs { typedef ComputationNode Base; UsingComputationNodeMembersBoilerplate; static const std::wstring TypeName() { return L"RowStack"; } @@ -401,20 +398,16 @@ namespace Microsoft { namespace MSR { namespace CNTK { virtual void CopyTo(const ComputationNodePtr nodeP, const std::wstring& newName, const CopyNodeFlags flags) const { Base::CopyTo(nodeP, newName, flags); - auto node = dynamic_pointer_cast>(nodeP); - if (flags & CopyNodeFlags::copyNodeChildren) { - node->m_children = m_children; + auto node = dynamic_pointer_cast>(nodeP); node->m_startRowIndices = m_startRowIndices; } } virtual void ComputeInputPartial(const size_t inputIndex) { - if (inputIndex >= ChildrenSize()) - InvalidArgument("RowStack-ComputeInputPartial: inputIndex out of range."); - ComputeInputPartialS(Inputs(inputIndex)->GradientValues(), GradientValues(), m_startRowIndices[inputIndex], m_startRowIndices[inputIndex + 1] - m_startRowIndices[inputIndex]); + ComputeInputPartialS(Inputs(inputIndex)->GradientValues(), GradientValues(), m_startRowIndices[inputIndex]); } virtual void /*ComputationNode::*/ComputeInputPartial(const size_t inputIndex, const FrameRange & frameRange) override @@ -422,50 +415,40 @@ namespace Microsoft { namespace MSR { namespace CNTK { Matrix sliceInputGrad = Inputs(inputIndex)->GradientSlice(frameRange/*TODO: delete this:*/.Check_t(GetNumParallelSequences(), m_pMBLayout)); Matrix sliceOutputGrad = GradientSlice(frameRange/*TODO: delete this:*/.Check_t(GetNumParallelSequences(), m_pMBLayout)); - ComputeInputPartialS(sliceInputGrad, sliceOutputGrad, m_startRowIndices[inputIndex], m_startRowIndices[inputIndex+1] - m_startRowIndices[inputIndex]); + ComputeInputPartialS(sliceInputGrad, sliceOutputGrad, m_startRowIndices[inputIndex]); } - /*TODO: merge with call site*/void ComputeInputPartialS(Matrix& inputGradientValues, const Matrix& gradientValues, const size_t startIndex, const size_t numRows) + /*TODO: merge with call site*/void ComputeInputPartialS(Matrix& inputGradientValues, const Matrix& gradientValues, const size_t startIndex) { - inputGradientValues.AddWithRowSliceValuesOf(gradientValues, startIndex, numRows); + inputGradientValues.AddWithRowSliceValuesOf(gradientValues, startIndex, inputGradientValues.GetNumRows()); } virtual void /*ComputationNode::*/EvaluateThisNode(const FrameRange & frameRange) override { - -#if 1 // assign as row slices, as that allows us to use the ValueSlice() function for (size_t i = 0; i < ChildrenSize(); i++) ValueSlice(frameRange).AssignRowSliceValuesOf(Inputs(i)->ValueSlice(frameRange), m_startRowIndices[i], Inputs(i)->GetNumRows()); -#else - Matrix sliceFunctionValues = ValueSlice(frameRange/*TODO: delete this:*/.Check_t(GetNumParallelSequences(), m_pMBLayout)); - sliceFunctionValues.AssignRowStackValuesOf(m_inputMatrices, frameRange.t() * GetNumParallelSequences(), GetNumParallelSequences()); -#endif } virtual void /*ComputationNodeBase::*/Validate(bool isFinalValidationPass) override { Base::Validate(isFinalValidationPass); + InferMBLayoutFromInputsForStandardCase(); size_t numCols = Inputs(0)->GetNumCols(); - m_startRowIndices.resize(ChildrenSize()+1); + // count totalRows and form m_startRowIndices[] array, which is the cumulative sum of matrix heights + m_startRowIndices.resize(ChildrenSize()); size_t totalRows = 0; - m_startRowIndices[0] = 0; - - // TODO: why do we need Inputs(xxx)->FunctionValues()]? Why not operate directly on Inputs(.)->FunctionValues()? for (int i = 0; i < ChildrenSize(); i++) { - size_t numRows = Inputs(i)->GetNumRows(); - if (isFinalValidationPass && Inputs(i)->GetNumCols() != numCols) LogicError("RowStack operation: the input node %ls has different number of columns.", Inputs(i)->NodeName().c_str()); - totalRows += numRows; - m_startRowIndices[i + 1] = m_startRowIndices[i] + numRows; + m_startRowIndices[i] = totalRows; + totalRows += Inputs(i)->GetNumRows(); } Resize(totalRows, numCols); - InferMBLayoutFromInputsForStandardCase(); InferImageDimsFromInputs(); }