Merge branch 'master' of https://github.com/Microsoft/CNTK into amitaga/evalBugFix
amitaga committed Feb 18, 2016
2 parents 0899cee + 20d461c commit 3a8d7be
Showing 1 changed file with 38 additions and 30 deletions.
68 changes: 38 additions & 30 deletions Source/Math/CuDnnConvolutionEngine.cu
@@ -268,13 +268,10 @@ public:
     using typename Base::ConvDesc;
 
     CuDnnConvolutionEngine(size_t maxTempMemSizeInSamples, BatchNormImpl bnImpl)
-        : m_maxTempMemSizeInSamples(maxTempMemSizeInSamples), m_bnImpl(bnImpl), m_stream(GetStream()), m_cudnn(nullptr), m_curMBSize(0)
+        : m_maxTempMemSizeInSamples(maxTempMemSizeInSamples), m_bnImpl(bnImpl), m_stream(GetStream()), m_cudnn(nullptr)
     {
         CUDNN_CALL(cudnnCreate(&m_cudnn));
         CUDNN_CALL(cudnnSetStream(m_cudnn, m_stream));
-        m_fwdAlgo.status = CUDNN_STATUS_NOT_INITIALIZED;
-        m_backDataAlgo.status = CUDNN_STATUS_NOT_INITIALIZED;
-        m_backFiltAlgo.status = CUDNN_STATUS_NOT_INITIALIZED;
     }
 
     ~CuDnnConvolutionEngine()
@@ -299,11 +296,11 @@ public:
 
         // Find best algo and allocate temp buffer, if needed.
         FindBestForwardAlgo(t(inT), f(filterT), cd(convDesc), t(outT));
-        if (m_fwdAlgo.memory > 0)
-            workspace.Resize((m_fwdAlgo.memory + sizeof(ElemType) - 1) / sizeof(ElemType), 1);
+        if (m_fwdAlgo.Algo.memory > 0)
+            workspace.Resize((m_fwdAlgo.Algo.memory + sizeof(ElemType) - 1) / sizeof(ElemType), 1);
         // Perform forward convolution operation.
-        CUDNN_CALL(cudnnConvolutionForward(m_cudnn, &C::One, t(inT), ptr(in), f(filterT), ptr(filter), cd(convDesc), m_fwdAlgo.algo,
-                                           ptr(workspace), m_fwdAlgo.memory, &C::Zero, t(outT), ptr(out)));
+        CUDNN_CALL(cudnnConvolutionForward(m_cudnn, &C::One, t(inT), ptr(in), f(filterT), ptr(filter), cd(convDesc), m_fwdAlgo.Algo.algo,
+                                           ptr(workspace), m_fwdAlgo.Algo.memory, &C::Zero, t(outT), ptr(out)));
     }
 
     void BackwardData(const Tensor4D& srcGradT, const Mat& srcGrad, const Filter& filterT, const Mat& filter, const ConvDesc& convDesc,
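Note on the workspace sizing seen in this hunk: cuDNN reports the required workspace in bytes, while the CNTK workspace matrix is allocated in ElemType elements, so the resize rounds the byte count up to a whole number of elements. A minimal sketch of that conversion (not part of the commit; the helper name is hypothetical):

    #include <cstddef>

    // Round a byte count up to a whole number of ElemType elements, so the
    // allocated buffer is never smaller than what cuDNN asked for.
    template <typename ElemType>
    std::size_t ToElementCount(std::size_t bytes)
    {
        // Ceiling division: e.g. 10 bytes with sizeof(float) == 4 gives 3 elements (12 bytes).
        return (bytes + sizeof(ElemType) - 1) / sizeof(ElemType);
    }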
@@ -320,11 +317,11 @@ public:
 
         // Find best algo and allocate temp buffer, if needed.
         FindBestBackwardDataAlgo(f(filterT), t(srcGradT), cd(convDesc), t(gradT));
-        if (m_backDataAlgo.memory > 0)
-            workspace.Resize((m_backDataAlgo.memory + sizeof(ElemType) - 1) / sizeof(ElemType), 1);
+        if (m_backDataAlgo.Algo.memory > 0)
+            workspace.Resize((m_backDataAlgo.Algo.memory + sizeof(ElemType) - 1) / sizeof(ElemType), 1);
         // Compute gradients with respect to the output tensor (data).
-        CUDNN_CALL(cudnnConvolutionBackwardData(m_cudnn, &C::One, f(filterT), ptr(filter), t(srcGradT), ptr(srcGrad), cd(convDesc), m_backDataAlgo.algo,
-                                                ptr(workspace), m_backDataAlgo.memory, &C::One, t(gradT), ptr(grad)));
+        CUDNN_CALL(cudnnConvolutionBackwardData(m_cudnn, &C::One, f(filterT), ptr(filter), t(srcGradT), ptr(srcGrad), cd(convDesc), m_backDataAlgo.Algo.algo,
+                                                ptr(workspace), m_backDataAlgo.Algo.memory, &C::One, t(gradT), ptr(grad)));
     }
 
     void BackwardFilter(const Tensor4D& srcGradT, const Mat& srcGrad, const Tensor4D& inT, const Mat& in, const ConvDesc& convDesc,
@@ -341,11 +338,11 @@ public:
 
         // Find best algo and allocate temp buffer, if needed.
        FindBestBackwardFilterAlgo(t(inT), t(srcGradT), cd(convDesc), f(filterT));
-        if (m_backFiltAlgo.memory > 0)
-            workspace.Resize((m_backFiltAlgo.memory + sizeof(ElemType) - 1) / sizeof(ElemType), 1);
+        if (m_backFiltAlgo.Algo.memory > 0)
+            workspace.Resize((m_backFiltAlgo.Algo.memory + sizeof(ElemType) - 1) / sizeof(ElemType), 1);
         // Compute gradients with respect to the output tensor (data).
-        CUDNN_CALL(cudnnConvolutionBackwardFilter(m_cudnn, &C::One, t(inT), ptr(in), t(srcGradT), ptr(srcGrad), cd(convDesc), m_backFiltAlgo.algo,
-                                                  ptr(workspace), m_backFiltAlgo.memory, &C::One, f(filterT), ptr(filter)));
+        CUDNN_CALL(cudnnConvolutionBackwardFilter(m_cudnn, &C::One, t(inT), ptr(in), t(srcGradT), ptr(srcGrad), cd(convDesc), m_backFiltAlgo.Algo.algo,
+                                                  ptr(workspace), m_backFiltAlgo.Algo.memory, &C::One, f(filterT), ptr(filter)));
     }
 
     void AddBias(const Tensor4D& outT, const Mat& out, const Tensor4D& biasT, const Mat& bias, Mat& dst) override
@@ -525,7 +522,7 @@ private:
         // Need to re-run auto-tuner in case batch size has been changed.
         // We assume no other dimensions of tensors can change so we don't check it.
         // REVIEW alexeyk: is this a safe assumption? Can convolution configuration change in runtime?
-        if (m_fwdAlgo.status == CUDNN_STATUS_SUCCESS && inT.n() == m_curMBSize && outT.n() == m_curMBSize)
+        if (m_fwdAlgo.Algo.status == CUDNN_STATUS_SUCCESS && inT.n() == m_fwdAlgo.CurMBSize && outT.n() == m_fwdAlgo.CurMBSize)
             return;
         const int MaxAlgoCount = 10;
         int calgo = 0;
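The FindBest*Algo methods in these hunks follow a simple caching pattern: the expensive cuDNN algorithm search is re-run only when there is no valid cached result or the mini-batch size has changed; after this change the cached mini-batch size lives next to the cached algorithm instead of in a single shared m_curMBSize. A minimal sketch of that guard (an assumption for illustration, not the commit's code; names like EnsureTuned are hypothetical and the real code calls cudnnFindConvolution*Algorithm):

    #include <cstddef>

    // Per-operation cache: a tuned result plus the mini-batch size it was tuned for.
    template <typename AlgoPerf>
    struct CachedAlgo
    {
        bool valid = false;        // stands in for Algo.status == CUDNN_STATUS_SUCCESS
        std::size_t curMBSize = 0; // mini-batch size of the last auto-tuning run
        AlgoPerf algo{};
    };

    // Re-run the search only when the cache is empty or the batch size changed.
    template <typename AlgoPerf, typename SearchFn>
    void EnsureTuned(CachedAlgo<AlgoPerf>& cache, std::size_t mbSize, SearchFn runSearch)
    {
        if (cache.valid && cache.curMBSize == mbSize)
            return;                      // cached algorithm still matches this batch size
        cache.algo = runSearch(mbSize);  // expensive auto-tuning step
        cache.curMBSize = mbSize;
        cache.valid = true;
    }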
@@ -540,13 +537,13 @@ private:
         });
         if (res == algoPerf + calgo)
             RuntimeError("cuDNN could not find suitable algorithm for cudnnConvolutionForward.");
-        m_curMBSize = inT.n();
-        m_fwdAlgo = *res;
+        m_fwdAlgo.CurMBSize = inT.n();
+        m_fwdAlgo.Algo = *res;
     }
 
     void FindBestBackwardDataAlgo(const CuDnnFilter& filtT, const CuDnnTensor4D& srcGradT, const CuDnnConvolutionDescriptor& convDesc, const CuDnnTensor4D& gradT)
     {
-        if (m_backDataAlgo.status == CUDNN_STATUS_SUCCESS && srcGradT.n() == m_curMBSize && gradT.n() == m_curMBSize)
+        if (m_backDataAlgo.Algo.status == CUDNN_STATUS_SUCCESS && srcGradT.n() == m_backDataAlgo.CurMBSize && gradT.n() == m_backDataAlgo.CurMBSize)
             return;
         const int MaxAlgoCount = 10;
         int calgo = 0;
@@ -561,13 +558,13 @@ private:
         });
         if (res == algoPerf + calgo)
             RuntimeError("cuDNN could not find suitable algorithm for cudnnConvolutionBackwardData.");
-        m_curMBSize = srcGradT.n();
-        m_backDataAlgo = *res;
+        m_backDataAlgo.CurMBSize = srcGradT.n();
+        m_backDataAlgo.Algo = *res;
     }
 
     void FindBestBackwardFilterAlgo(const CuDnnTensor4D& inT, const CuDnnTensor4D& srcGradT, const CuDnnConvolutionDescriptor& convDesc, const CuDnnFilter& filtT)
     {
-        if (m_backFiltAlgo.status == CUDNN_STATUS_SUCCESS && inT.n() == m_curMBSize && srcGradT.n() == m_curMBSize)
+        if (m_backFiltAlgo.Algo.status == CUDNN_STATUS_SUCCESS && inT.n() == m_backFiltAlgo.CurMBSize && srcGradT.n() == m_backFiltAlgo.CurMBSize)
             return;
         const int MaxAlgoCount = 10;
         int calgo = 0;
@@ -582,23 +579,34 @@ private:
         });
         if (res == algoPerf + calgo)
             RuntimeError("cuDNN could not find suitable algorithm for cudnnConvolutionBackwardFilter.");
-        m_curMBSize = inT.n();
-        m_backFiltAlgo = *res;
+        m_backFiltAlgo.CurMBSize = inT.n();
+        m_backFiltAlgo.Algo = *res;
     }
 
 private:
+    template <typename T>
+    struct ConvAlgoInfo
+    {
+        ConvAlgoInfo()
+            : CurMBSize(0)
+        {
+            Algo.status = CUDNN_STATUS_NOT_INITIALIZED;
+        }
+        // Current mini-batch size, needed for re-computing statistics in auto-tuner.
+        size_t CurMBSize;
+        T Algo;
+    };
+
     using C = Consts<ElemType>;
 
     // REVIEW alexeyk: currently limit is set once in ctor though in CNTK it can be, theoretically, changed in runtime.
     size_t m_maxTempMemSizeInSamples;
     BatchNormImpl m_bnImpl;
     cudnnHandle_t m_cudnn;
     cudaStream_t m_stream;
-    // Current mini-batch size, needed for re-computing statistics in auto-tuner.
-    size_t m_curMBSize;
-    cudnnConvolutionFwdAlgoPerf_t m_fwdAlgo;
-    cudnnConvolutionBwdDataAlgoPerf_t m_backDataAlgo;
-    cudnnConvolutionBwdFilterAlgoPerf_t m_backFiltAlgo;
+    ConvAlgoInfo<cudnnConvolutionFwdAlgoPerf_t> m_fwdAlgo;
+    ConvAlgoInfo<cudnnConvolutionBwdDataAlgoPerf_t> m_backDataAlgo;
+    ConvAlgoInfo<cudnnConvolutionBwdFilterAlgoPerf_t> m_backFiltAlgo;
 };
 
 template <class ElemType>
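The new ConvAlgoInfo<T> member gives each of the three convolution passes (forward, backward-data, backward-filter) its own cached algorithm and its own mini-batch size, where previously all three shared a single m_curMBSize. An illustrative, self-contained sketch of that independence using stand-in types (FakeAlgoPerf and ConvAlgoInfoSketch are hypothetical, not cuDNN or CNTK types):

    #include <cstddef>
    #include <cassert>

    struct FakeAlgoPerf           // stand-in for cudnnConvolution*AlgoPerf_t
    {
        int status = -1;          // -1 plays the role of CUDNN_STATUS_NOT_INITIALIZED
        std::size_t memory = 0;   // workspace bytes required by the chosen algorithm
    };

    template <typename T>
    struct ConvAlgoInfoSketch
    {
        std::size_t CurMBSize = 0; // mini-batch size the cached Algo was tuned for
        T Algo{};
    };

    int main()
    {
        ConvAlgoInfoSketch<FakeAlgoPerf> fwd, backData;
        fwd.Algo.status = 0;       // forward pass tuned for batch size 64
        fwd.CurMBSize = 64;
        // backData keeps its own, still-uninitialized state, so tuning the forward
        // pass cannot make the backward-data pass look "already tuned".
        assert(backData.CurMBSize == 0 && backData.Algo.status == -1);
        return 0;
    }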
