Commit

Addressed code review feedback.
Alexey Kamenev committed Feb 12, 2016
1 parent e12427d commit 03e8cc5
Showing 6 changed files with 306 additions and 163 deletions.
2 changes: 1 addition & 1 deletion Source/CNTK/BrainScript/ExperimentalNetworkBuilder.cpp
@@ -66,7 +66,7 @@ wstring computationNodes = // TODO: use actual TypeName() here? would first need
L"ColumnwiseCrossProduct = KhatriRaoProduct // deprecated \n" // TODO: should it be deprecated? It is described as easier to understand in the CNTKBook.
L"ClassificationError = ErrorPrediction \n"
L"Delay = PastValue \n" // TODO: should it allow negative offsets and an if test here?
L"BatchNormalization(input, scale, bias, runMean, runInvStdDev, eval, spatial, expAvgFactor, epsilon, useCntkEngine, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'BatchNormalization' ; inputs = (input : scale : bias : runMean : runInvStdDev) /*plus the function args*/ ]\n"
L"BatchNormalization(input, scale, bias, runMean, runInvStdDev, eval, spatial, expAvgFactor = 1.0, epsilon = 0.00001, useCntkEngine = true, imageLayout='CHW', tag='') = new ComputationNode [ operation = 'BatchNormalization' ; inputs = (input : scale : bias : runMean : runInvStdDev) /*plus the function args*/ ]\n"
// standard nodes. We use macros to define these strings.
#define UnaryStandardNode(Op, a) L## #Op L"(" L## #a L", tag='') = new ComputationNode [ operation = '" L## #Op L"' ; inputs = " L## #a L" /*plus the function args*/ ]\n"
#define BinaryStandardNode(Op, a, b) L## #Op L"(" L## #a L", " L## #b L", tag='') = new ComputationNode [ operation = '" L## #Op L"' ; inputs = (" L## #a L" : " L## #b L") /*plus the function args*/ ]\n"
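With defaults now supplied for expAvgFactor, epsilon, and useCntkEngine, a BrainScript config only has to pass the required arguments. A hypothetical invocation (the node names conv, sc, b, m, and isd are illustrative, not from this commit):

    bn = BatchNormalization(conv, sc, b, m, isd, false /*eval*/, true /*spatial*/)

This call picks up expAvgFactor = 1.0, epsilon = 0.00001, and useCntkEngine = true from the new defaults.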
16 changes: 13 additions & 3 deletions Source/Math/CuDnnConvolutionEngine.cu
@@ -456,9 +456,19 @@ public:
     UNUSED(crowIn); // crowIn used only in asserts.
 #endif
 
-    cudnnBatchNormMode_t mode = spatial ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION;
-    CUDNN_CALL(cudnnBatchNormalizationForwardInference(m_cudnn, mode, &C::One, &C::Zero, t(inT), ptr(in), t(inT), ptr(out),
-                                                       t(scaleBiasT), ptr(scale), ptr(bias), ptr(runMean), ptr(runInvStdDev), CUDNN_BN_MIN_EPSILON));
+    if (m_bnImpl == BatchNormImpl::CuDnn)
+    {
+        cudnnBatchNormMode_t mode = spatial ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION;
+        CUDNN_CALL(cudnnBatchNormalizationForwardInference(m_cudnn, mode, &C::One, &C::Zero, t(inT), ptr(in), t(inT), ptr(out),
+                                                           t(scaleBiasT), ptr(scale), ptr(bias), ptr(runMean), ptr(runInvStdDev), CUDNN_BN_MIN_EPSILON));
+    }
+    else if (m_bnImpl == BatchNormImpl::Cntk)
+    {
+        CUDA_CALL(BatchNormalizationForwardInference(inT, spatial, ptr(in), ptr(out), ptr(scale), ptr(bias),
+                                                     ptr(runMean), ptr(runInvStdDev), m_stream));
+    }
+    else
+        RuntimeError("Provided batch norm implementation (%d) is not supported.", (int)m_bnImpl);
 }
 
 void BackwardNormalizeBatch(const Tensor4D& inT, const Mat& in, const Mat& srcGrad, Mat& grad,
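The new branch dispatches on m_bnImpl, a BatchNormImpl selector whose declaration is not part of this diff. A minimal sketch of what the two handled cases imply; the enum's actual location and any members beyond these two are assumptions:

    // Sketch only: engine selector implied by the branch above.
    // The real declaration lives elsewhere in Source/Math.
    enum class BatchNormImpl
    {
        CuDnn, // delegate to cudnnBatchNormalizationForwardInference
        Cntk   // use CNTK's own BatchNormalizationForwardInference kernel
    };

Presumably the useCntkEngine = true default added in ExperimentalNetworkBuilder.cpp maps to BatchNormImpl::Cntk, making CNTK's own kernel the out-of-the-box path while cuDNN remains available as an explicit opt-in.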