Added batch norm per-activation backprop implementation.
Alexey Kamenev committed Feb 12, 2016
1 parent 4e9483b commit 21575f3
Showing 3 changed files with 475 additions and 49 deletions.
28 changes: 22 additions & 6 deletions Source/Math/CuDnnConvolutionEngine.cu
@@ -401,16 +401,18 @@ public:
                 t(scaleBiasT), ptr(scale), ptr(bias), expAvgFactor, ptr(runMean), ptr(runInvStdDev),
                 CUDNN_BN_MIN_EPSILON, ptr(saveMean), ptr(saveInvStdDev)));
         }
-        else
+        else if (m_bnImpl == BatchNormImpl::Cntk)
         {
             if (spatial)
                 assert(false);
             else
             {
                 CUDA_CALL(BatchNormalizationForwardTraining(crowIn, inT.n(), ptr(in), ptr(out), ptr(scale), ptr(bias),
-                    CUDNN_BN_MIN_EPSILON, ptr(saveMean), ptr(saveInvStdDev), m_stream));
+                                                            CUDNN_BN_MIN_EPSILON, ptr(saveMean), ptr(saveInvStdDev), m_stream));
             }
         }
+        else
+            RuntimeError("Provided batch norm implementation (%d) is not supported.", m_bnImpl);
     }
 
     void NormalizeBatchInference(const Tensor4D& inT, const Mat& in, const Tensor4D& scaleBiasT, const Mat& scale, const Mat& bias,
@@ -472,11 +474,25 @@ public:
         assert(scaleGrad.GetNumCols() == scale.GetNumCols());
         assert(biasGrad.GetNumRows() == scale.GetNumRows());
         assert(biasGrad.GetNumCols() == scale.GetNumCols());
-        UNUSED(crowIn);
 
-        cudnnBatchNormMode_t mode = spatial ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION;
-        CUDNN_CALL(cudnnBatchNormalizationBackward(m_cudnn, mode, &C::One, &C::One, t(inT), ptr(in), t(inT), ptr(srcGrad), t(inT), ptr(grad),
-            t(scaleBiasT), ptr(scale), ptr(scaleGrad), ptr(biasGrad), CUDNN_BN_MIN_EPSILON, ptr(saveMean), ptr(saveInvStdDev)));
+        if (m_bnImpl == BatchNormImpl::CuDnn)
+        {
+            cudnnBatchNormMode_t mode = spatial ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION;
+            CUDNN_CALL(cudnnBatchNormalizationBackward(m_cudnn, mode, &C::One, &C::One, t(inT), ptr(in), t(inT), ptr(srcGrad), t(inT), ptr(grad),
+                t(scaleBiasT), ptr(scale), ptr(scaleGrad), ptr(biasGrad), CUDNN_BN_MIN_EPSILON, ptr(saveMean), ptr(saveInvStdDev)));
+        }
+        else if (m_bnImpl == BatchNormImpl::Cntk)
+        {
+            if (spatial)
+                assert(false);
+            else
+            {
+                CUDA_CALL(BatchNormalizationBackward(crowIn, inT.n(), ptr(in), ptr(srcGrad), ptr(grad), ptr(scale), ptr(scaleGrad), ptr(biasGrad),
+                    ptr(saveMean), ptr(saveInvStdDev), m_stream));
+            }
+        }
+        else
+            RuntimeError("Provided batch norm implementation (%d) is not supported.", m_bnImpl);
     }
 
 private:
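For the non-spatial path, the diff dispatches to CNTK's own BatchNormalizationBackward kernel, whose body is not part of this commit. As a rough illustration of the per-activation batch norm gradients such a kernel computes (dBias[i] = sum_j dy, dScale[i] = sum_j dy * xHat, dx = scale * invStd * (dy - (xHat * dScale + dBias) / n)), here is a minimal, hypothetical CUDA sketch. The kernel name, the column-major layout with one column per sample, and the one-thread-per-row mapping are assumptions for illustration only, not CNTK's actual implementation.

// Hypothetical sketch only: per-activation (non-spatial) batch norm backprop.
// Layout assumption: column-major crow x batchSize matrices, one column per sample,
// matching how crowIn and inT.n() are passed above. Not the actual CNTK kernel.
#include <cuda_runtime.h>

__global__ void BatchNormBackwardPerActivationSketch(
    int crow, int batchSize,
    const float* x,              // forward input
    const float* dy,             // incoming gradient
    float* dx,                   // gradient w.r.t. input (output)
    const float* scale,          // per-activation scale (gamma)
    float* dScale,               // gradient w.r.t. scale (output)
    float* dBias,                // gradient w.r.t. bias (output)
    const float* savedMean,      // per-activation mean saved by the forward pass
    const float* savedInvStdDev) // per-activation 1/stddev saved by the forward pass
{
    int i = blockIdx.x * blockDim.x + threadIdx.x; // one thread per activation (row)
    if (i >= crow)
        return;

    float mean = savedMean[i];
    float invStd = savedInvStdDev[i];

    // First pass over the batch: accumulate dScale and dBias.
    float ds = 0;
    float db = 0;
    for (int j = 0; j < batchSize; j++)
    {
        float xhat = (x[i + j * crow] - mean) * invStd;
        float g = dy[i + j * crow];
        ds += g * xhat;
        db += g;
    }
    dScale[i] = ds;
    dBias[i] = db;

    // Second pass: gradient w.r.t. the input.
    float invN = 1.0f / batchSize;
    for (int j = 0; j < batchSize; j++)
    {
        float xhat = (x[i + j * crow] - mean) * invStd;
        dx[i + j * crow] = scale[i] * invStd * (dy[i + j * crow] - (xhat * ds + db) * invN);
    }
}

// Example launch (illustrative): one thread per activation.
// BatchNormBackwardPerActivationSketch<<<(crow + 255) / 256, 256, 0, stream>>>(
//     crow, n, x, dy, dx, scale, dScale, dBias, savedMean, savedInvStdDev);

A production kernel would reduce over the batch with shared memory or warp shuffles and fuse the two passes; the sketch favors clarity over performance.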
