Skip to content

Commit

Permalink
Added convolutional implementation for batch norm.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexey Kamenev committed Feb 12, 2016
1 parent 5846fae commit 1379f58
Show file tree
Hide file tree
Showing 3 changed files with 418 additions and 123 deletions.
20 changes: 6 additions & 14 deletions Source/Math/CuDnnConvolutionEngine.cu
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ public:
assert(inT.n() == in.GetNumCols());
assert(saveMean.GetNumElements() >= runMean.GetNumElements());
assert(saveInvStdDev.GetNumElements() >= runInvStdDev.GetNumElements());
UNUSED(crowIn);

if (m_bnImpl == BatchNormImpl::CuDnn)
{
Expand All @@ -403,13 +404,8 @@ public:
}
else if (m_bnImpl == BatchNormImpl::Cntk)
{
if (spatial)
assert(false);
else
{
CUDA_CALL(BatchNormalizationForwardTraining(crowIn, inT.n(), ptr(in), ptr(out), ptr(scale), ptr(bias),
CUDNN_BN_MIN_EPSILON, ptr(saveMean), ptr(saveInvStdDev), m_stream));
}
CUDA_CALL(BatchNormalizationForwardTraining(inT, spatial, ptr(in), ptr(out), ptr(scale), ptr(bias),
CUDNN_BN_MIN_EPSILON, ptr(saveMean), ptr(saveInvStdDev), m_stream));
}
else
RuntimeError("Provided batch norm implementation (%d) is not supported.", m_bnImpl);
Expand Down Expand Up @@ -474,6 +470,7 @@ public:
assert(scaleGrad.GetNumCols() == scale.GetNumCols());
assert(biasGrad.GetNumRows() == scale.GetNumRows());
assert(biasGrad.GetNumCols() == scale.GetNumCols());
UNUSED(crowIn);

if (m_bnImpl == BatchNormImpl::CuDnn)
{
Expand All @@ -483,13 +480,8 @@ public:
}
else if (m_bnImpl == BatchNormImpl::Cntk)
{
if (spatial)
assert(false);
else
{
CUDA_CALL(BatchNormalizationBackward(crowIn, inT.n(), ptr(in), ptr(srcGrad), ptr(grad), ptr(scale), ptr(scaleGrad), ptr(biasGrad),
ptr(saveMean), ptr(saveInvStdDev), m_stream));
}
CUDA_CALL(BatchNormalizationBackward(inT, spatial, ptr(in), ptr(srcGrad), ptr(grad), ptr(scale), ptr(scaleGrad), ptr(biasGrad),
ptr(saveMean), ptr(saveInvStdDev), m_stream));
}
else
RuntimeError("Provided batch norm implementation (%d) is not supported.", m_bnImpl);
Expand Down
Loading

0 comments on commit 1379f58

Please sign in to comment.