Batch norm: added epsilon as parameter, enabled run mean/invstddev.
Alexey Kamenev committed Feb 12, 2016
1 parent 1379f58 commit 9ad6147
Showing 6 changed files with 46 additions and 22 deletions.
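
The change threads an explicit epsilon through the batch-normalization training path and starts populating the running mean / inverse-standard-deviation buffers that inference uses later. As a reminder of the math being wired up, the sketch below shows the scalar form of the training-time forward pass. It is an illustration only (a hypothetical helper, not CNTK code), and the exponential-average update is written in its general form even though the CNTK kernels in this commit only accept expAvgFactor == 1:

    #include <cmath>
    #include <vector>

    // Hypothetical per-element sketch; CNTK operates on whole tensors / feature maps.
    void BatchNormTrainSketch(const std::vector<double>& x, double scale, double bias,
                              double expAvgFactor, double epsilon,
                              double& runMean, double& runInvStdDev, std::vector<double>& y)
    {
        double mean = 0, m2 = 0;
        for (double v : x) mean += v;
        mean /= x.size();
        for (double v : x) m2 += (v - mean) * (v - mean);
        // Epsilon keeps the inverse std dev finite for (near-)constant inputs.
        double invStdDev = 1.0 / std::sqrt(m2 / x.size() + epsilon);
        // Running statistics; with expAvgFactor == 1 this is a plain overwrite.
        runMean = expAvgFactor * mean + (1 - expAvgFactor) * runMean;
        runInvStdDev = expAvgFactor * invStdDev + (1 - expAvgFactor) * runInvStdDev;
        y.resize(x.size());
        for (size_t i = 0; i < x.size(); i++)
            y[i] = scale * (x[i] - mean) * invStdDev + bias;
    }
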
Source/ComputationNetworkLib/TrainingNodes.h (1 addition, 1 deletion)
@@ -1762,7 +1762,7 @@ class BatchNormalizationNode : public ComputationNode<ElemType>, public NumInput
m_saveInvStdDev->Resize(runMean.GetNumRows(), runMean.GetNumCols());

m_convEng->NormalizeBatch(*m_inT, sliceInputValue, *m_scaleBiasT, scale, bias, m_spatial, expAvgFactor, runMean, runInvStdDev,
sliceOutputValue, *m_saveMean, *m_saveInvStdDev);
sliceOutputValue, 1e-5, *m_saveMean, *m_saveInvStdDev);

m_mbCount++;
}
Source/Math/ConvolutionEngine.cpp (2 additions, 1 deletion)
@@ -284,7 +284,7 @@ class DefaultConvolutionEngine : public ConvolutionEngine<ElemType>
}

void NormalizeBatch(const Tensor4D& inT, const Mat& in, const Tensor4D& scaleBiasT, const Mat& scale, const Mat& bias,
bool spatial, double expAvgFactor, Mat& runMean, Mat& runInvStdDev, Mat& out, Mat& saveMean, Mat& saveInvStdDev) override
bool spatial, double expAvgFactor, Mat& runMean, Mat& runInvStdDev, Mat& out, double epsilon, Mat& saveMean, Mat& saveInvStdDev) override
{
UNUSED(inT);
UNUSED(in);
@@ -296,6 +296,7 @@ class DefaultConvolutionEngine : public ConvolutionEngine<ElemType>
UNUSED(expAvgFactor);
UNUSED(runMean);
UNUSED(runInvStdDev);
UNUSED(epsilon);
UNUSED(saveMean);
UNUSED(saveInvStdDev);
RuntimeError("Not yet implemented.");
Source/Math/ConvolutionEngine.h (2 additions, 1 deletion)
@@ -251,7 +251,8 @@ class MATH_API ConvolutionEngine
virtual void BackwardBias(const Tensor4D& srcGradT, const Mat& srcGrad, const Tensor4D& biasT, Mat& biasGrad) = 0;

virtual void NormalizeBatch(const Tensor4D& inT, const Mat& in, const Tensor4D& scaleBiasT, const Mat& scale, const Mat& bias,
bool spatial, double expAvgFactor, Mat& runMean, Mat& runInvStdDev, Mat& out, Mat& saveMean, Mat& saveInvStdDev) = 0;
bool spatial, double expAvgFactor, Mat& runMean, Mat& runInvStdDev, Mat& out,
double epsilon, Mat& saveMean, Mat& saveInvStdDev) = 0;

virtual void NormalizeBatchInference(const Tensor4D& inT, const Mat& in, const Tensor4D& scaleBiasT, const Mat& scale, const Mat& bias,
bool spatial, const Mat& runMean, const Mat& runInvStdDev, Mat& out) = 0;
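
Note that the interface carries two pairs of statistics: saveMean / saveInvStdDev hold the current minibatch statistics (reused by the backward pass), while runMean / runInvStdDev accumulate the values that NormalizeBatchInference consumes. Assuming the same normalization convention as in the training path, inference reduces to the following (hypothetical scalar helper, not part of the engine):

    // Inference-time normalization: the stored running statistics stand in for batch statistics.
    inline double NormalizeInferenceSketch(double x, double scale, double bias,
                                           double runMean, double runInvStdDev)
    {
        return scale * (x - runMean) * runInvStdDev + bias;
    }
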
Source/Math/CuDnnConvolutionEngine.cu (9 additions, 3 deletions)
@@ -365,7 +365,8 @@ public:
}

void NormalizeBatch(const Tensor4D& inT, const Mat& in, const Tensor4D& scaleBiasT, const Mat& scale, const Mat& bias,
bool spatial, double expAvgFactor, Mat& runMean, Mat& runInvStdDev, Mat& out, Mat& saveMean, Mat& saveInvStdDev) override
bool spatial, double expAvgFactor, Mat& runMean, Mat& runInvStdDev, Mat& out,
double epsilon, Mat& saveMean, Mat& saveInvStdDev) override
{
const size_t crowIn = inT.w() * inT.h() * inT.c();
if (spatial)
@@ -398,14 +399,19 @@ public:
if (m_bnImpl == BatchNormImpl::CuDnn)
{
cudnnBatchNormMode_t mode = spatial ? CUDNN_BATCHNORM_SPATIAL : CUDNN_BATCHNORM_PER_ACTIVATION;
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
CUDNN_CALL(cudnnBatchNormalizationForwardTraining(m_cudnn, mode, &C::One, &C::Zero, t(inT), ptr(in), t(inT), ptr(out),
t(scaleBiasT), ptr(scale), ptr(bias), expAvgFactor, ptr(runMean), ptr(runInvStdDev),
CUDNN_BN_MIN_EPSILON, ptr(saveMean), ptr(saveInvStdDev)));
epsilon, ptr(saveMean), ptr(saveInvStdDev)));
}
else if (m_bnImpl == BatchNormImpl::Cntk)
{
// No support for exp averaging for now.
assert(expAvgFactor == 1);
epsilon = std::max(epsilon, 1e-9);
CUDA_CALL(BatchNormalizationForwardTraining(inT, spatial, ptr(in), ptr(out), ptr(scale), ptr(bias),
CUDNN_BN_MIN_EPSILON, ptr(saveMean), ptr(saveInvStdDev), m_stream));
ptr(runMean), ptr(runInvStdDev), epsilon,
ptr(saveMean), ptr(saveInvStdDev), m_stream));
}
else
RuntimeError("Provided batch norm implementation (%d) is not supported.", m_bnImpl);
Source/Math/CuDnnConvolutionEngine.cuh (24 additions, 11 deletions)
@@ -159,7 +159,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// for computing batch mean and variance (here inverse standard deviation) with one pass over the data.
// It uses the algorithm described in: http://i.stanford.edu/pub/cstr/reports/cs/tr/79/773/CS-TR-79-773.pdf
template <int BlockDimX, int BlockDimY, int U, typename T>
__global__ void kComputeBatchMeanAndInvStdDev(int vectorSize, int batchSize, const T* x, double epsilon, T* xMean, T* xInvStdDev)
__global__ void kComputeBatchMeanAndInvStdDev(int vectorSize, int batchSize, const T* x, T* runMean, T* runInvStdDev,
double epsilon, T* xMean, T* xInvStdDev)
{
static_assert(BlockDimX * U == CUB_PTX_WARP_THREADS, "BlockDimX * U must be equal to warp size (32).");
static_assert((BlockDimX * BlockDimY % CUB_PTX_WARP_THREADS) == 0, "Block size must be a multiple of warp size (32).");
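
The kernel relies on the single-pass mean/variance recurrences from the referenced Chan, Golub and LeVeque report: each sample updates the running mean and the accumulated sum of squared deviations (m2) in one sweep, and partial (count, mean, m2) triples from different threads can later be merged via m2_ab = m2_a + m2_b + delta^2 * n_a * n_b / (n_a + n_b). A minimal single-threaded sketch of the per-sample update, finalized with the same rsqrt(variance + epsilon) step as the kernel (hypothetical helper, assumes n > 0):

    #include <cmath>
    #include <cstddef>

    // One-pass (Welford-style) mean and inverse std dev.
    template <typename T>
    void MeanAndInvStdDevSketch(const T* x, size_t n, double epsilon, T& mean, T& invStdDev)
    {
        double m = 0, m2 = 0;
        for (size_t i = 0; i < n; i++)
        {
            double delta = x[i] - m;
            m += delta / (i + 1);        // running mean
            m2 += delta * (x[i] - m);    // running sum of squared deviations
        }
        mean = static_cast<T>(m);
        invStdDev = static_cast<T>(1.0 / std::sqrt(m2 / n + epsilon));
    }
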
@@ -275,21 +276,24 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}
size_t idxDstBase = (blockIdx.x * BlockDimX + threadIdx.x) * U;
StoreValues<U>(mean, xMean + idxDstBase);
StoreValues<U>(mean, runMean + idxDstBase);
Operations::RSqrt<T> rsqrtOp;
#pragma unroll
for (int k = 0; k < U; k++)
{
m2[k] = rsqrtOp(static_cast<T>(m2[k] / batchSize + epsilon));
}
StoreValues<U>(m2, xInvStdDev + idxDstBase);
StoreValues<U>(m2, runInvStdDev + idxDstBase);
}
}

// This kernel is very similar to kComputeBatchMeanAndInvStdDev except it reduces not just over N (mini-batch)
// but also W and H dimensions.
// REVIEW alexeyk: is it possible to combine this and previous kernel into a single kernel without hurting performance/readability much?
template <int BlockDimX, int BlockDimY, int U, typename T>
__global__ void kComputeSpatialBatchMeanAndInvStdDev(int vectorSize, int spatialSize, int batchSize, const T* x, double epsilon, T* xMean, T* xInvStdDev)
__global__ void kComputeSpatialBatchMeanAndInvStdDev(int vectorSize, int spatialSize, int batchSize, const T* x, T* runMean, T* runInvStdDev,
double epsilon, T* xMean, T* xInvStdDev)
{
static_assert(BlockDimX * U == CUB_PTX_WARP_THREADS, "BlockDimX * U must be equal to warp size (32).");
static_assert((BlockDimX * BlockDimY % CUB_PTX_WARP_THREADS) == 0, "Block size must be a multiple of warp size (32).");
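
In the spatial case the statistics are per feature map: the reduction covers the N (minibatch), W and H dimensions, so runMean and runInvStdDev end up with one entry per channel. The following is a plain CPU reference of that reduction, assuming the layout implied by the Call wrappers further down (each sample contributes vectorSize = C * W * H values, with spatialSize = W * H contiguous values per channel); it uses a naive sum / sum-of-squares pass, which is numerically cruder than the kernel's one-pass method:

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Hypothetical per-channel mean / inverse std dev over N, H, W.
    template <typename T>
    void SpatialStatsSketch(const T* x, size_t vectorSize, size_t spatialSize, size_t batchSize,
                            double epsilon, std::vector<T>& mean, std::vector<T>& invStdDev)
    {
        size_t cmap = vectorSize / spatialSize; // number of feature maps (channels)
        mean.assign(cmap, T(0));
        invStdDev.assign(cmap, T(0));
        for (size_t c = 0; c < cmap; c++)
        {
            double sum = 0, sumSq = 0;
            for (size_t n = 0; n < batchSize; n++)
                for (size_t s = 0; s < spatialSize; s++)
                {
                    double v = x[n * vectorSize + c * spatialSize + s];
                    sum += v;
                    sumSq += v * v;
                }
            double cnt = static_cast<double>(batchSize * spatialSize);
            double m = sum / cnt;
            mean[c] = static_cast<T>(m);
            invStdDev[c] = static_cast<T>(1.0 / std::sqrt(sumSq / cnt - m * m + epsilon));
        }
    }
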
@@ -421,16 +425,19 @@ namespace Microsoft { namespace MSR { namespace CNTK {
}

xMean[blockIdx.x] = mean[0];
runMean[blockIdx.x] = mean[0];
Operations::RSqrt<T> rsqrtOp;
xInvStdDev[blockIdx.x] = rsqrtOp(static_cast<T>(m2[0] / (batchSize * spatialSize) + epsilon));
m2[0] = rsqrtOp(static_cast<T>(m2[0] / (batchSize * spatialSize) + epsilon));
xInvStdDev[blockIdx.x] = m2[0];
runInvStdDev[blockIdx.x] = m2[0];
}
}

template <int U>
struct ComputeBatchMeanAndInvStdDev
{
template <typename T>
static void Call(size_t vectorSize, size_t batchSize, const T* x, double epsilon, T* xMean, T* xInvStdDev, cudaStream_t stream)
static void Call(size_t vectorSize, size_t batchSize, const T* x, T* runMean, T* runInvStdDev, double epsilon, T* xMean, T* xInvStdDev, cudaStream_t stream)
{
assert((vectorSize % U) == 0);

@@ -440,15 +447,16 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// Create grid with only one block in y(batch)-dimension as kernel uses striding.
auto gdim = dim3(static_cast<unsigned int>(RoundUpToMultiple(vectorSize, BlockDimX * U)));
kComputeBatchMeanAndInvStdDev<BlockDimX, BlockDimY, U><<<gdim, bdim, 0, stream>>>(
static_cast<int>(vectorSize), static_cast<int>(batchSize), x, epsilon, xMean, xInvStdDev);
static_cast<int>(vectorSize), static_cast<int>(batchSize),
x, runMean, runInvStdDev, epsilon, xMean, xInvStdDev);
}
};

template <int U>
struct ComputeSpatialBatchMeanAndInvStdDev
{
template <typename T>
static void Call(size_t vectorSize, size_t spatialSize, size_t batchSize, const T* x, double epsilon,
static void Call(size_t vectorSize, size_t spatialSize, size_t batchSize, const T* x, T* runMean, T* runInvStdDev, double epsilon,
T* xMean, T* xInvStdDev, cudaStream_t stream)
{
assert((vectorSize % spatialSize) == 0);
@@ -461,7 +469,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// Each thread block processes a single whole feature map independently (i.e. reduces over W, H and N dimensions).
auto gdim = dim3(static_cast<unsigned int>(vectorSize / spatialSize));
kComputeSpatialBatchMeanAndInvStdDev<BlockDimX, BlockDimY, U><<<gdim, bdim, 0, stream>>>(
static_cast<int>(vectorSize), static_cast<int>(spatialSize), static_cast<int>(batchSize), x, epsilon, xMean, xInvStdDev);
static_cast<int>(vectorSize), static_cast<int>(spatialSize), static_cast<int>(batchSize),
    x, runMean, runInvStdDev, epsilon, xMean, xInvStdDev);
}
};

@@ -576,13 +585,16 @@ namespace Microsoft { namespace MSR { namespace CNTK {

template <typename T>
cudaError_t BatchNormalizationForwardTraining(const Tensor4D& t, bool spatial, const T* x, T* y,
const T* bnScale, const T* bnBias, double epsilon, T* saveMean, T* saveInvStdDev, cudaStream_t stream)
const T* bnScale, const T* bnBias, T* runMean, T* runInvStdDev,
double epsilon, T* saveMean, T* saveInvStdDev, cudaStream_t stream)
{
assert(nullptr != x);
assert(nullptr != y);
assert(nullptr != bnScale);
assert(nullptr != bnBias);
assert(std::isfinite(epsilon) && epsilon > 0);
assert(nullptr != runMean);
assert(nullptr != runInvStdDev);
assert(nullptr != saveMean);
assert(nullptr != saveInvStdDev);

@@ -594,16 +606,17 @@ namespace Microsoft { namespace MSR { namespace CNTK {

if (spatial)
{
Call<ComputeSpatialBatchMeanAndInvStdDev, T>(spatialSize, vectorSize, spatialSize, batchSize, x, epsilon,
saveMean, saveInvStdDev, stream);
Call<ComputeSpatialBatchMeanAndInvStdDev, T>(spatialSize, vectorSize, spatialSize, batchSize, x,
runMean, runInvStdDev, epsilon, saveMean, saveInvStdDev, stream);
cudaError_t err = GetLastCudaError();
if (cudaSuccess != err)
return err;

}
else
{
Call<ComputeBatchMeanAndInvStdDev, T>(vectorSize, vectorSize, batchSize, x, epsilon, saveMean, saveInvStdDev, stream);
Call<ComputeBatchMeanAndInvStdDev, T>(vectorSize, vectorSize, batchSize, x,
runMean, runInvStdDev, epsilon, saveMean, saveInvStdDev, stream);
cudaError_t err = GetLastCudaError();
if (cudaSuccess != err)
return err;
Tests/UnitTests/MathTests/ConvolutionEngineTests.cpp (8 additions, 5 deletions)
@@ -681,6 +681,7 @@ BOOST_AUTO_TEST_CASE(BatchNormalizationForwardTrain)
auto& t = *std::move(std::get<0>(cfg));
bool spatial = std::get<1>(cfg);
double expAvg = 1;
double eps = 1e-5; // CUDNN_BN_MIN_EPSILON

size_t crow = t.w() * t.h() * t.c();
size_t ccol = t.n();
@@ -719,13 +720,13 @@ BOOST_AUTO_TEST_CASE(BatchNormalizationForwardTrain)
CudaTimer time1;
time1.Start();
engCntk->NormalizeBatch(t, in, *scaleBiasT, scale, bias, spatial, expAvg, runMean, runInvStdDev,
out, saveMean, saveInvStdDev);
out, eps, saveMean, saveInvStdDev);
time1.Stop();

CudaTimer time2;
time2.Start();
engCudnn->NormalizeBatch(t, in, *scaleBiasT, scale, bias, spatial, expAvg, runMeanExp, runInvStdDevExp,
outExp, saveMeanExp, saveInvStdDevExp);
outExp, eps, saveMeanExp, saveInvStdDevExp);
time2.Stop();

std::stringstream tmsg;
@@ -744,10 +745,12 @@ BOOST_AUTO_TEST_CASE(BatchNormalizationForwardTrain)
// REVIEW alexeyk: add cases for testing numerical stability.

BOOST_REQUIRE_MESSAGE(!runMean.HasNan("runMean"), "runMean" << msgNan);
//BOOST_REQUIRE_MESSAGE(runMean.IsEqualTo(runMeanExp, absErr), "runMean" << msg);
BOOST_REQUIRE_MESSAGE(CheckEqual(runMean, runMeanExp, emsg, relErr, absErr), "runMean" << msg << ". " << emsg);
BOOST_REQUIRE_MESSAGE(CountNans(runMeanBuf) == crowScaleBias * 2, "runMean" << msgNotNan);

BOOST_REQUIRE_MESSAGE(!runInvStdDev.HasNan("runInvStdDev"), "runInvStdDev" << msgNan);
//BOOST_REQUIRE_MESSAGE(runInvStdDev.IsEqualTo(runInvStdDevExp, absErr), "runInvStdDev" << msg);
BOOST_REQUIRE_MESSAGE(CheckEqual(runInvStdDev, runInvStdDevExp, emsg, relErr, absErr), "runInvStdDev" << msg << ". " << emsg);
BOOST_REQUIRE_MESSAGE(CountNans(runInvStdDevBuf) == crowScaleBias * 2, "runInvStdDev" << msgNotNan);

BOOST_REQUIRE_MESSAGE(!saveMean.HasNan("saveMean"), "saveMean" << msgNan);
BOOST_REQUIRE_MESSAGE(CheckEqual(saveMean, saveMeanExp, emsg, relErr, absErr), "saveMean" << msg << ". " << emsg);
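
The CheckEqual calls compare the CNTK and cuDNN outputs with a combined relative/absolute tolerance, and the CountNans checks appear to verify that the NaN-filled guard area around each output buffer stays untouched. CheckEqual itself is defined elsewhere in the test file; elementwise it presumably behaves along these lines (hypothetical sketch, not the actual helper):

    #include <cmath>

    // Elementwise tolerance test combining relative and absolute error.
    inline bool NearlyEqual(double actual, double expected, double relErr, double absErr)
    {
        return std::fabs(actual - expected) <= absErr + relErr * std::fabs(expected);
    }
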
Expand Down Expand Up @@ -874,7 +877,7 @@ BOOST_AUTO_TEST_CASE(BatchNormalizationBackward)
if (crow >= 32 && ccol >= 32)
{
// Use conservative estimates.
int speedup = 1;
float speedup = 1.3f;
BOOST_REQUIRE_MESSAGE(speedup * elapsedCntk < elapsedCudnn,
"CNTK implementation (" << elapsedCntk << "ms) must be faster than cuDNN (" << elapsedCudnn << "ms) by at least " << speedup << "x, what's changed? " << tmsg.str());
}