Commit

Integrate wdarling/newopnodes into master
Project Philly committed Mar 17, 2016
2 parents 4205c47 + 1d3aa94 commit 98ffc6f
Showing 8 changed files with 71 additions and 89 deletions.
1 change: 1 addition & 0 deletions Source/ActionsLib/NetworkDescriptionLanguage.cpp
@@ -204,6 +204,7 @@ bool CheckFunction(std::string& p_nodeType, bool* allowUndeterminedVariable)
else if (EqualInsensitive(nodeType, OperationNameOf(SigmoidNode))) ret = true;
else if (EqualInsensitive(nodeType, OperationNameOf(SoftmaxNode))) ret = true;
else if (EqualInsensitive(nodeType, OperationNameOf(SparseInputValue), L"SparseInput")) ret = true;
else if (EqualInsensitive(nodeType, OperationNameOf(SqrtNode))) ret = true;
else if (EqualInsensitive(nodeType, OperationNameOf(SquareErrorNode), L"SE")) ret = true;
else if (EqualInsensitive(nodeType, OperationNameOf(SumColumnElementsNode))) ret = true;
else if (EqualInsensitive(nodeType, OperationNameOf(SumElementsNode))) ret = true;
5 changes: 3 additions & 2 deletions Source/CNTK/BrainScript/CNTKCoreLib/CNTK.core.bs
@@ -18,8 +18,7 @@ Chr(c) = new StringFunction [ what = 'Chr' ; arg = c ]
Floor(x) = new NumericFunction [ what = 'Floor' ; arg = x ]
Length(x) = new NumericFunction [ what = 'Length' ; arg = x ]
Ceil(x) = -Floor(-x)
Round(x) = Floor(x+0.5)
Abs(x) = if x >= 0 then x else -x
Round(x) = Floor(x+0.5)
Sign(x) = if x > 0 then 1 else if x < 0 then -1 else 0
Min(a,b) = if a < b then a else b
Max(a,b) = if a > b then a else b
@@ -62,6 +61,7 @@ ColumnwiseCrossProduct = KhatriRaoProduct // deprecated
ClassificationError = ErrorPrediction
Delay = PastValue
BatchNormalization(input, scale, bias, runMean, runInvStdDev, eval, spatial, normalizationTimeConstant = 0, epsilon = 0.00001, useCntkEngine = true, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'BatchNormalization' ; inputs = (input : scale : bias : runMean : runInvStdDev) /*plus the function args*/ ]
Abs(x, tag='') = new ComputationNode [ operation = 'Abs' ; inputs = x /*plus the function args*/ ]
ClassBasedCrossEntropyWithSoftmax(labelClassDescriptorVectorSequence, mainInputInfo, mainWeight, classLogProbsBeforeSoftmax, tag='') = new ComputationNode [ operation = 'ClassBasedCrossEntropyWithSoftmax' ; inputs = (labelClassDescriptorVectorSequence : mainInputInfo : mainWeight : classLogProbsBeforeSoftmax) /*plus the function args*/ ]
ColumnElementTimes(aVectorSequence, anotherVectorSequence, tag='') = new ComputationNode [ operation = 'ColumnElementTimes' ; inputs = (aVectorSequence : anotherVectorSequence) /*plus the function args*/ ]
CosDistance(aVectorSequence, anotherVectorSequence, tag='') = new ComputationNode [ operation = 'CosDistance' ; inputs = (aVectorSequence : anotherVectorSequence) /*plus the function args*/ ]
@@ -92,6 +92,7 @@ Scale(scalarScalingFactor, matrix, tag='') = new ComputationNode [ operation = '
Sigmoid(z, tag='') = new ComputationNode [ operation = 'Sigmoid' ; inputs = z /*plus the function args*/ ]
Softmax(z, tag='') = new ComputationNode [ operation = 'Softmax' ; inputs = z /*plus the function args*/ ]
Hardmax(z, tag='') = new ComputationNode [ operation = 'Hardmax' ; inputs = z /*plus the function args*/ ]
Sqrt(z, tag='') = new ComputationNode [ operation = 'Sqrt' ; inputs = z /*plus the function args*/ ]
SquareError(aMatrix, anotherMatrix, tag='') = new ComputationNode [ operation = 'SquareError' ; inputs = (aMatrix : anotherMatrix) /*plus the function args*/ ]
SumColumnElements(z, tag='') = new ComputationNode [ operation = 'SumColumnElements' ; inputs = z /*plus the function args*/ ]
SumElements(matrix, tag='') = new ComputationNode [ operation = 'SumElements' ; inputs = matrix /*plus the function args*/ ]
7 changes: 7 additions & 0 deletions Source/ComputationNetworkLib/ComputationNetworkBuilder.cpp
@@ -87,6 +87,7 @@ static shared_ptr<ComputationNode<ElemType>> CreateStandardNode(const std::wstri
#endif
else if (nodeType == OperationNameOf(SigmoidNode)) return New<SigmoidNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(SoftmaxNode)) return New<SoftmaxNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(SqrtNode)) return New<SqrtNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(SquareErrorNode)) return New<SquareErrorNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(LogisticNode)) return New<LogisticNode<ElemType>>(forward<_Types>(_Args)...);
else if (nodeType == OperationNameOf(SumColumnElementsNode)) return New<SumColumnElementsNode<ElemType>>(forward<_Types>(_Args)...);
@@ -464,6 +465,12 @@ shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::LogSo
return net.AddNodeToNetAndAttachInputs(New<LogSoftmaxNode<ElemType>>(net.GetDeviceId(), nodeName), a);
}

template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Sqrt(const ComputationNodePtr a, const std::wstring nodeName)
{
return net.AddNodeToNetAndAttachInputs(New<SqrtNode<ElemType>>(net.GetDeviceId(), nodeName), a);
}

template <class ElemType>
shared_ptr<ComputationNode<ElemType>> ComputationNetworkBuilder<ElemType>::Sum(const ComputationNodePtr a, const std::wstring nodeName)
{
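A minimal usage sketch of the factory method added above. Everything around the Sqrt call (the network object, the input-creation overload, and all names) is an assumption for illustration, not part of this commit:

```cpp
// Hypothetical wiring: 'net' is an existing ComputationNetwork.
// Sqrt() creates a SqrtNode on the network's device, attaches 'features'
// as its single input, and registers the node with the network.
ComputationNetworkBuilder<float> builder(*net);
auto features = builder.CreateInputNode(L"features", TensorShape(128)); // assumed overload
auto root     = builder.Sqrt(features, L"sqrtOfFeatures");
```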
1 change: 1 addition & 0 deletions Source/ComputationNetworkLib/ComputationNetworkBuilder.h
@@ -123,6 +123,7 @@ class ComputationNetworkBuilder
ComputationNodePtr SequenceWithSoftmax(const ComputationNodePtr label, const ComputationNodePtr prediction, const ComputationNodePtr loglikelihood, const std::wstring nodeName = L"");
ComputationNodePtr Sigmoid(const ComputationNodePtr a, const std::wstring nodeName = L"");
ComputationNodePtr Softmax(const ComputationNodePtr a, const std::wstring nodeName = L"");
ComputationNodePtr Sqrt(const ComputationNodePtr a, const std::wstring nodeName = L"");
ComputationNodePtr SquareError(const ComputationNodePtr a, const ComputationNodePtr b, const std::wstring nodeName = L"");
ComputationNodePtr Sum(const ComputationNodePtr a, const std::wstring nodeName = L"");
ComputationNodePtr Tanh(const ComputationNodePtr a, const std::wstring nodeName = L"");
52 changes: 0 additions & 52 deletions Source/ComputationNetworkLib/LinearAlgebraNodes.h
@@ -114,58 +114,6 @@ class MinusNode : public BinaryElementWiseNode<ElemType>
template class MinusNode<float>;
template class MinusNode<double>;

// -----------------------------------------------------------------------
// NegateNode (input)
// computes the negative of its input
// -----------------------------------------------------------------------

template <class ElemType>
class NegateNode : public ComputationNode<ElemType>, public NumInputs<1>
{
typedef ComputationNode<ElemType> Base; UsingComputationNodeMembersBoilerplate;
static const std::wstring TypeName() { return L"Negate"; }

public:
DeclareConstructorFromConfigWithNumInputs(NegateNode);
NegateNode(DEVICEID_TYPE deviceId, const wstring& name)
: Base(deviceId, name)
{
}

virtual void /*ComputationNode::*/ BackpropTo(const size_t /*inputIndex*/, const FrameRange& fr) override
{
Input(0)->GradientFor(fr) -= GradientFor(fr);
}

virtual bool OutputUsedInComputingInputNodesGradients() const override
{
// The NegateNode does not require its output value for computing
// the gradients of its input nodes
return false;
}

virtual bool InputUsedInComputingInputNodesGradients(size_t childIndex) const override
{
// The NegateNode does not require any of its inputs' values for computing
// the gradients of its input nodes
UNREFERENCED_PARAMETER(childIndex);
return false;
}

virtual void /*ComputationNode::*/ ForwardProp(const FrameRange& fr) override
{
ValueFor(fr).AssignDifferenceOf(0, Input(0)->ValueFor(fr));
}

virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
{
ValidateUnaryMap(isFinalValidationPass);
}
};

template class NegateNode<float>;
template class NegateNode<double>;
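
Although the hand-written class is deleted, its backward rule survives in the macro-based Negate node added to NonlinearityNodes.h below: the gradient of y = -x is just the negated output gradient, which needs neither the input nor the output value (hence UnaryGradient). A self-contained sketch of that equivalence:

```cpp
#include <cassert>

int main()
{
    // For y = -x: dL/dx = -dL/dy, independent of both x and y.
    float outputGrad = 3.0f;
    float oldStyle = 0.0f;          // gradient buffer, initially zero
    oldStyle -= outputGrad;         // removed NegateNode::BackpropTo: GradientFor(fr) -= ...
    float newStyle = -outputGrad;   // new path: DoUnaryOpOf applies opNegate to the output gradient
    assert(oldStyle == newStyle);
    return 0;
}
```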

// -----------------------------------------------------------------------
// TimesNodeBase (A, B, outputRank=1)
// shared code of TimesNode and TransposeTimesNode (which transposes A)
91 changes: 56 additions & 35 deletions Source/ComputationNetworkLib/NonlinearityNodes.h
@@ -27,7 +27,14 @@ namespace Microsoft { namespace MSR { namespace CNTK {
only inputs (but not function values) are used.
// -----------------------------------------------------------------------

template <class ElemType, ElementWiseOperator opForward, ElementWiseOperator opBackward, bool gradientFromOutput>
enum GradientOperationType
{
UnaryGradient,
BinaryWithInputGradient,
BinaryWithOutputGradient
};

template <class ElemType, ElementWiseOperator opForward, ElementWiseOperator opBackward, GradientOperationType opType>
class UnaryElementWiseWithOpCodeNodeBase : public ComputationNode<ElemType>, public NumInputs<1>
{
typedef ComputationNode<ElemType> Base;
@@ -56,11 +63,23 @@ class UnaryElementWiseWithOpCodeNodeBase : public ComputationNode<ElemType>, pub
size_t rank = DetermineElementwiseTensorRank();
auto sliceOutputGrad = GradientTensorFor(rank, fr); // propagate from this one...
auto sliceInputGrad = Input(0)->GradientTensorFor(rank, fr); // ...to this one
auto sliceValue = gradientFromOutput ? ValueTensorFor(rank, fr) : // using input or output value
Input(0)->ValueTensorFor(rank, fr);
// If gradient can be computed from output rather than input, then that's better for mem sharing (and faster in most cases).
// Not possible for Cos().
sliceInputGrad.DoBinaryOpOf(1, sliceOutputGrad, sliceValue, 1, opBackward);

// we expect a constant conditional expression here -- suppress the warning that leads to an error
#pragma warning( push )
#pragma warning( disable : 4127 )
if (opType == UnaryGradient)
{
sliceInputGrad.DoUnaryOpOf(1, sliceOutputGrad, 1, opBackward);
}
else
{
// If gradient can be computed from output rather than input, then that's better for mem sharing (and faster in most cases).
// Not possible for Cos().
auto sliceValue = (opType == BinaryWithOutputGradient) ? ValueTensorFor(rank, fr) : // using input or output value
Input(0)->ValueTensorFor(rank, fr);
sliceInputGrad.DoBinaryOpOf(1, sliceOutputGrad, sliceValue, 1, opBackward);
}
#pragma warning( pop )
}

virtual void /*ComputationNodeBase::*/ Validate(bool isFinalValidationPass) override
@@ -70,11 +89,11 @@ class UnaryElementWiseWithOpCodeNodeBase : public ComputationNode<ElemType>, pub

virtual bool OutputUsedInComputingInputNodesGradients() const override
{
return gradientFromOutput;
return opType == BinaryWithOutputGradient;
}
virtual bool InputUsedInComputingInputNodesGradients(size_t /*childIndex*/) const override
{
return !gradientFromOutput;
return opType == BinaryWithInputGradient;
}
};
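
Since opType is a template parameter, the if/else in BackpropTo above is resolved per instantiation at compile time; the pragma only silences MSVC's constant-conditional warning (C4127). A stripped-down sketch of the same dispatch idea, using scalar floats in place of tensor slices:

```cpp
#include <cmath>

enum GradientOperationType { UnaryGradient, BinaryWithInputGradient, BinaryWithOutputGradient };

// Each instantiation compiles down to a single branch-free expression.
template <GradientOperationType opType>
float Backprop(float outputGrad, float input, float output)
{
    if (opType == UnaryGradient)
        return -outputGrad;                   // e.g. Negate: needs no value at all
    else if (opType == BinaryWithOutputGradient)
        return outputGrad / (2 * output);     // e.g. Sqrt: gradient from the output
    else
        return outputGrad * -std::sin(input); // e.g. Cosine: gradient needs the input
}
```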

@@ -91,33 +110,35 @@ class UnaryElementWiseWithOpCodeNodeBase : public ComputationNode<ElemType>, pub
// -----------------------------------------------------------------------

#pragma push_macro("DeclareUnaryElementWiseWithOpCodeNode")
#define DeclareUnaryElementWiseWithOpCodeNode(Name, Forward, Backward, gradientFromOutput) \
template <class ElemType> \
class Name##Node : public UnaryElementWiseWithOpCodeNodeBase<ElemType, op##Forward, op##Backward, gradientFromOutput> \
{ \
typedef UnaryElementWiseWithOpCodeNodeBase<ElemType, op##Forward, op##Backward, gradientFromOutput> Base; \
UnaryElementWiseWithOpCodeNodeBaseMembers; \
static const std::wstring TypeName() \
{ \
return L## #Name; \
} \
\
public: \
DeclareConstructorFromConfigWithNumInputs(Name##Node); \
Name##Node(DEVICEID_TYPE deviceId, const wstring& Name) \
: Base(deviceId, Name) \
{ \
} \
}

// Name Forward and Backward opcodes Gradient from output?
DeclareUnaryElementWiseWithOpCodeNode(Sigmoid, Sigmoid, ElementwiseProductWithSigmoidDerivativeFromOutput, true);
DeclareUnaryElementWiseWithOpCodeNode(Tanh, Tanh, ElementwiseProductWithTanhDerivativeFromOutput, true);
DeclareUnaryElementWiseWithOpCodeNode(RectifiedLinear, LinearRectifier, ElementwiseProductWithLinearRectifierDerivativeFromOutput, true);
DeclareUnaryElementWiseWithOpCodeNode(Log, Log, ElementwiseProductWithLogDerivativeFromOutput, true);
DeclareUnaryElementWiseWithOpCodeNode(Exp, Exp, ElementwiseProduct, true);
DeclareUnaryElementWiseWithOpCodeNode(Cosine, Cosine, ElementwiseProductWithCosDerivative, false);
DeclareUnaryElementWiseWithOpCodeNode(Abs, Abs, ElementwiseProductWithAbsDerivative, false);
#define DeclareUnaryElementWiseWithOpCodeNode(Name, Forward, Backward, opType) \
template <class ElemType> \
class Name##Node : public UnaryElementWiseWithOpCodeNodeBase<ElemType, op##Forward, op##Backward, opType> \
{ \
typedef UnaryElementWiseWithOpCodeNodeBase<ElemType, op##Forward, op##Backward, opType> Base; \
UnaryElementWiseWithOpCodeNodeBaseMembers; \
static const std::wstring TypeName() \
{ \
return L## #Name; \
} \
\
public: \
DeclareConstructorFromConfigWithNumInputs(Name##Node); \
Name##Node(DEVICEID_TYPE deviceId, const wstring& Name) \
: Base(deviceId, Name) \
{ \
} \
}

// Name Forward and Backward opcodes Gradient optype
DeclareUnaryElementWiseWithOpCodeNode(Sigmoid, Sigmoid, ElementwiseProductWithSigmoidDerivativeFromOutput, BinaryWithOutputGradient);
DeclareUnaryElementWiseWithOpCodeNode(Tanh, Tanh, ElementwiseProductWithTanhDerivativeFromOutput, BinaryWithOutputGradient);
DeclareUnaryElementWiseWithOpCodeNode(RectifiedLinear, LinearRectifier, ElementwiseProductWithLinearRectifierDerivativeFromOutput, BinaryWithOutputGradient);
DeclareUnaryElementWiseWithOpCodeNode(Log, Log, ElementwiseProductWithLogDerivativeFromOutput, BinaryWithOutputGradient);
DeclareUnaryElementWiseWithOpCodeNode(Exp, Exp, ElementwiseProduct, BinaryWithOutputGradient);
DeclareUnaryElementWiseWithOpCodeNode(Cosine, Cosine, ElementwiseProductWithCosDerivative, BinaryWithInputGradient);
DeclareUnaryElementWiseWithOpCodeNode(Abs, Abs, ElementwiseProductWithAbsDerivative, BinaryWithInputGradient);
DeclareUnaryElementWiseWithOpCodeNode(Negate, Negate, Negate, UnaryGradient);
DeclareUnaryElementWiseWithOpCodeNode(Sqrt, Sqrt, ElementwiseProductWithSqrtDerivative, BinaryWithOutputGradient);

#pragma pop_macro("DeclareUnaryElementWiseWithOpCodeNode")
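
For reference, this is roughly what the new Sqrt line expands to once the macro is applied (a sketch, with the UnaryElementWiseWithOpCodeNodeBaseMembers and DeclareConstructorFromConfigWithNumInputs boilerplate elided):

```cpp
template <class ElemType>
class SqrtNode : public UnaryElementWiseWithOpCodeNodeBase<ElemType, opSqrt, opElementwiseProductWithSqrtDerivative, BinaryWithOutputGradient>
{
    typedef UnaryElementWiseWithOpCodeNodeBase<ElemType, opSqrt, opElementwiseProductWithSqrtDerivative, BinaryWithOutputGradient> Base;
    static const std::wstring TypeName() { return L"Sqrt"; }

public:
    SqrtNode(DEVICEID_TYPE deviceId, const wstring& name)
        : Base(deviceId, name)
    {
    }
};
```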

2 changes: 2 additions & 0 deletions Source/Math/CommonMatrix.h
@@ -112,6 +112,7 @@ enum ElementWiseOperator
opElementwiseProductWithLogDerivativeFromOutput,
opElementwiseProductWithCosDerivative,
opElementwiseProductWithAbsDerivative,
opElementwiseProductWithSqrtDerivative,
opSqrOfDifference,
// binary ops for indexing
// opIndex,
@@ -163,6 +164,7 @@ enum ElementWiseOperator
Macro(ElementwiseProductWithLogDerivativeFromOutput); \
Macro(ElementwiseProductWithCosDerivative); \
Macro(ElementwiseProductWithAbsDerivative); \
Macro(ElementwiseProductWithSqrtDerivative); \
Macro(SqrOfDifference); \
//Macro(Index);
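
The enum above is extended by hand, while the Macro(...) list is an X-macro hook that other files use to generate per-operation code; the two must stay in sync, which is why the commit adds ElementwiseProductWithSqrtDerivative to both. A self-contained sketch of the X-macro side (all names here are illustrative, not the codebase's):

```cpp
#include <string>

// One list drives both the enum and any per-op code.
#define FOR_ALL_DEMO_OPS(Macro)                 \
    Macro(ElementwiseProductWithAbsDerivative)  \
    Macro(ElementwiseProductWithSqrtDerivative)

enum DemoOp
{
#define DECLARE_ENUM(name) opDemo##name,
    FOR_ALL_DEMO_OPS(DECLARE_ENUM)
#undef DECLARE_ENUM
};

inline std::string DemoOpName(DemoOp op)
{
    switch (op)
    {
#define NAME_CASE(name) case opDemo##name: return #name;
        FOR_ALL_DEMO_OPS(NAME_CASE)
#undef NAME_CASE
    }
    return "unknown";
}
```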

1 change: 1 addition & 0 deletions Source/Math/TensorOps.h
@@ -236,6 +236,7 @@ DefBinaryOp(ElementwiseProductWithLinearRectifierDerivativeFromOutput, b > 0 ? a
DefBinaryOp(ElementwiseProductWithLogDerivativeFromOutput, a* exp_(-b));
DefBinaryOp(ElementwiseProductWithCosDerivative, a * -sin_(b)); // note: b = input for cos()
DefBinaryOp(ElementwiseProductWithAbsDerivative, a * Sgn(b)); // note: b = input for abs()
DefBinaryOp(ElementwiseProductWithSqrtDerivative, a / (2 * b)); // b = output; d/dx sqrt(x) = 1/(2 * sqrt(x)) --> note this is the same as ElementwiseQuotient with a constant; if more show up like this we should add more template params
DefBinaryOp(SqrOfDifference, Sqr(a - b));
//DefBinaryOp(Index, IndexElement(a, b, i)); // note: this one uses the third argument

Expand Down
