Skip to content

Commit

Permalink
Added GEMM convo engine basic implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexey Kamenev committed Apr 21, 2016
1 parent fb50278 commit dcb6654
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 68 deletions.
7 changes: 7 additions & 0 deletions Source/Math/CPUMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4135,6 +4135,13 @@ void CPUMatrix<ElemType>::ConvolutionBackwardKernel(const CPUMatrix<ElemType>& i
}
}

template <class ElemType>
void CPUMatrix<ElemType>::UnrollConvolutionInput(const CPUMatrix<ElemType>& kernel, const CPUMatrix<int>& mpRowCol, const CPUMatrix<int>& mpRowIwht,
const CPUMatrix<int>& mpRowRun, const CPUMatrix<int>& runs, CPUMatrix<ElemType>& output) const
{
UNUSED(kernel); UNUSED(mpRowCol); UNUSED(mpRowIwht); UNUSED(mpRowRun); UNUSED(runs); UNUSED(output);
}

template <class ElemType>
void CPUMatrix<ElemType>::MaxPoolingForward(const CPUMatrix<int>& mpRowCol, const CPUMatrix<int>& mpRowIndices, const CPUMatrix<int>& indices, CPUMatrix<ElemType>& output) const
{
Expand Down
3 changes: 3 additions & 0 deletions Source/Math/CPUMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,9 @@ class MATH_API CPUMatrix : public BaseMatrix<ElemType>
void ConvolutionBackwardKernel(const CPUMatrix<ElemType>& in, const CPUMatrix<int>& mpRowCol, const CPUMatrix<int>& mpRowIwht,
const CPUMatrix<int>& mpRowRun, const CPUMatrix<int>& runs, CPUMatrix<ElemType>& kernelGrad) const;

void UnrollConvolutionInput(const CPUMatrix<ElemType>& kernel, const CPUMatrix<int>& mpRowCol, const CPUMatrix<int>& mpRowIwht,
const CPUMatrix<int>& mpRowRun, const CPUMatrix<int>& runs, CPUMatrix<ElemType>& output) const;

void MaxPoolingForward(const CPUMatrix<int>& mpRowCol, const CPUMatrix<int>& mpRowIndices, const CPUMatrix<int>& indices, CPUMatrix<ElemType>& output) const;
void MaxPoolingBackward(const CPUMatrix<ElemType>& out, const CPUMatrix<ElemType>& in,
const CPUMatrix<int>& mpRowCol, const CPUMatrix<int>& mpRowIndices, const CPUMatrix<int>& indices,
Expand Down
100 changes: 98 additions & 2 deletions Source/Math/ConvolutionEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,13 +210,13 @@ class ReferenceConvolutionEngine : public ConvolutionEngine<ElemType>
InvalidArgument("Pooling type %d is not supported.", (int)m_poolKind);
}

private:
protected:
static bool IsGpu(DEVICEID_TYPE deviceId)
{
return deviceId >= 0;
}

private:
protected:
using IntMatPtr = std::unique_ptr<Matrix<int>>;

Matrix<int> m_mpRowCol;
Expand Down Expand Up @@ -511,6 +511,96 @@ class LegacyConvolutionEngine : public ConvolutionEngine<ElemType>
bool m_gpuSparse1D;
};

//------------------------------------------------------------------
// GEMM convolution engine implementation.
// This engine supports arbitrary convolution configuration with full
// sharing and implemented using unroll + GEMM technique
// (High performance convolutional neural networks for document processing; Chellapilla, Puri, Simard)
// Uses reference engine for pooling operations.
//------------------------------------------------------------------
template <class ElemType>
class GemmConvolutionEngine : public ReferenceConvolutionEngine<ElemType>
{
public:
using Base = ReferenceConvolutionEngine<ElemType>;
using typename Base::Mat;

public:
GemmConvolutionEngine(ConvolveGeometryPtr geometry, DEVICEID_TYPE deviceId, ImageLayoutKind imageLayout, size_t maxTempMemSizeInSamples, PoolKind poolKind)
: Base(geometry, deviceId, imageLayout, maxTempMemSizeInSamples, poolKind)
{
}

protected:
using typename Base::IntMatPtr;

using Base::m_geometry;
using Base::m_deviceId;
using Base::m_imageLayout;
using Base::m_maxTempMemSizeInSamples;

using Base::m_mpRowCol;
using Base::m_mpRowIwht;
using Base::m_mpRowRun;
using Base::m_runs;

void EnsureCompatible() override
{
if (m_imageLayout != ImageLayoutKind::CHW)
RuntimeError("GEMM convolution engine supports only CHW/cudnn layout.");
if (IsGpu(m_deviceId))
RuntimeError("GEMM convolution engine currently supports only CPU device.");
}

void ForwardCore(const Mat& in, const Mat& kernel, Mat& out, Mat& workspace) override
{
size_t batchSize = in.GetNumCols();
size_t subBatchSize = m_maxTempMemSizeInSamples == 0 ? batchSize : min(batchSize, m_maxTempMemSizeInSamples);

size_t mapCount = m_geometry->GetMapCount(m_geometry->InputShape().GetRank() - 1);
size_t mapOutSize = m_geometry->OutputShape().GetNumElements() / mapCount;
size_t unrollRows = mapOutSize * subBatchSize;
size_t unrollCols = m_geometry->KernelShape().GetNumElements();
// Reserve space for unrolled inputs and, if needed, intermediate outputs.
// Intermediate outputs will be transposed to final outputs after GEMM operation.
// Transpose is not required if subBatchSize == 1.
workspace.Resize(unrollRows, unrollCols + (subBatchSize > 1 ? mapCount : 0));

for (size_t start = 0; start < batchSize; start += subBatchSize)
{
size_t curBatchSize = min(subBatchSize, batchSize - start);
auto inputSlice = in.ColumnSlice(start, curBatchSize);
auto unrolledInput = workspace.ColumnSlice(0, unrollCols);
// Need to reshape (soft transpose) as matrices are column-major.
unrolledInput.Reshape(unrollCols, unrollRows);
inputSlice.UnrollConvolutionInput(kernel, m_mpRowCol, *m_mpRowIwht, *m_mpRowRun, *m_runs, unrolledInput);

auto outputSlice = out.ColumnSlice(start, curBatchSize);
outputSlice.Reshape(unrollRows, mapCount);
Mat::Multiply(unrolledInput, true, kernel, true, outputSlice);
}
}

void BackwardDataCore(const Mat& srcGrad, const Mat& kernel, Mat& grad, Mat& workspace) override
{
UNUSED(srcGrad); UNUSED(kernel); UNUSED(grad); UNUSED(workspace);
//srcGrad.ConvolutionBackwardData(kernel, m_mpRowCol, *m_mpRowIwht, *m_mpRowRun, *m_runs, grad);
}

void BackwardKernelCore(const Mat& srcGrad, const Mat& in, Mat& kernelGrad, bool /*allowReuse*/, Mat& workspace) override
{
UNUSED(srcGrad); UNUSED(in); UNUSED(kernelGrad); UNUSED(workspace);
//srcGrad.ConvolutionBackwardKernel(in, m_mpRowCol, *m_mpRowIwht, *m_mpRowRun, *m_runs, kernelGrad);
}

public:
static bool IsSupported(DEVICEID_TYPE deviceId, ConvolveGeometryPtr geometry)
{
return deviceId < 0 &&
find(begin(geometry->Sharing()), end(geometry->Sharing()), false) == end(geometry->Sharing());
}
};

template <class ElemType>
std::unique_ptr<ConvolutionEngine<ElemType>> ConvolutionEngine<ElemType>::Create(ConvolveGeometryPtr geometry, DEVICEID_TYPE deviceId,
ImageLayoutKind imageLayout, size_t maxTempMemSizeInSamples, PoolKind poolKind,
Expand Down Expand Up @@ -539,6 +629,12 @@ std::unique_ptr<ConvolutionEngine<ElemType>> ConvolutionEngine<ElemType>::Create
return CuDnnConvolutionEngineFactory<ElemType>::Create(geometry, deviceId, imageLayout, maxTempMemSizeInSamples, poolKind);
}

if (isEnabled(ConvolutionEngineKind::Gemm) && GemmConvolutionEngine<ElemType>::IsSupported(deviceId, geometry))
{
fprintf(stderr, "\nUsing GEMM convolution engine for geometry: %s.\n", engStr.c_str());
return std::make_unique<GemmConvolutionEngine<ElemType>>(geometry, deviceId, imageLayout, maxTempMemSizeInSamples, poolKind);
}

if (!isEnabled(ConvolutionEngineKind::Reference))
RuntimeError("Reference convolution is disabled and no other engine supports such configuratin (or disabled).");
fprintf(stderr, "\nUsing reference convolution engine for geometry: %s.\n", engStr.c_str());
Expand Down
9 changes: 5 additions & 4 deletions Source/Math/ConvolutionEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ namespace Microsoft { namespace MSR { namespace CNTK {
enum class ConvolutionEngineKind
{
None = 0,
Reference = 1,
CuDnn = 1 << 1,
Legacy = 1 << 2,
Reference = 1, // Reference, lookup-based implementation. Very slow but works for any convo configuration.
CuDnn = 1 << 1, // cuDNN, works only for 2D/3D convos with full sharing.
Legacy = 1 << 2, // Legacy, for backwards compatibility. REVIEW alexeyk: implement sparse version and remove Legacy altogether.
Gemm = 1 << 3, // Uses convolution unrolling+GEMM technique. Works only for convos with full sharing.

All = Reference | CuDnn | Legacy
All = Reference | CuDnn | Legacy | Gemm
};

enum class PoolKind
Expand Down
20 changes: 20 additions & 0 deletions Source/Math/Matrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4030,6 +4030,26 @@ void Matrix<ElemType>::ConvolutionBackwardKernel(const Matrix<ElemType>& in, con
NOT_IMPLEMENTED);
}

template <class ElemType>
void Matrix<ElemType>::UnrollConvolutionInput(const Matrix<ElemType>& kernel, const Matrix<int>& mpRowCol, const Matrix<int>& mpRowIwht,
const Matrix<int>& mpRowRun, const Matrix<int>& runs, Matrix<ElemType>& output) const
{
assert(mpRowCol.GetNumCols() == 1);
assert(mpRowIwht.GetNumCols() == 1);
assert(mpRowRun.GetNumCols() == 1);
assert(runs.GetNumCols() == 1);

DecideAndMoveToRightDevice(*this, output);

DISPATCH_MATRIX_ON_FLAG(this,
this,
m_CPUMatrix->UnrollConvolutionInput(*(kernel.m_CPUMatrix), *(mpRowCol.m_CPUMatrix), *(mpRowIwht.m_CPUMatrix),
*(mpRowRun.m_CPUMatrix), *(runs.m_CPUMatrix), *(output.m_CPUMatrix)),
NOT_IMPLEMENTED,
NOT_IMPLEMENTED,
NOT_IMPLEMENTED);
}

template <class ElemType>
void Matrix<ElemType>::MaxPoolingForward(const Matrix<int>& mpRowCol, const Matrix<int>& mpRowIndices, const Matrix<int>& indices, Matrix<ElemType>& output) const
{
Expand Down
3 changes: 3 additions & 0 deletions Source/Math/Matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,9 @@ class MATH_API Matrix : public MatrixBase
void ConvolutionBackwardKernel(const Matrix<ElemType>& in, const Matrix<int>& mpRowCol, const Matrix<int>& mpRowIwht,
const Matrix<int>& mpRowRun, const Matrix<int>& runs, Matrix<ElemType>& kernelGrad) const;

void UnrollConvolutionInput(const Matrix<ElemType>& kernel, const Matrix<int>& mpRowCol, const Matrix<int>& mpRowIwht,
const Matrix<int>& mpRowRun, const Matrix<int>& runs, Matrix<ElemType>& output) const;

void MaxPoolingForward(const Matrix<int>& mpRowCol, const Matrix<int>& mpRowIndices, const Matrix<int>& indices, Matrix<ElemType>& output) const;
void MaxPoolingBackward(const Matrix<ElemType>& out, const Matrix<ElemType>& in,
const Matrix<int>& mpRowCol, const Matrix<int>& mpRowIndices, const Matrix<int>& indices,
Expand Down
131 changes: 69 additions & 62 deletions Tests/UnitTests/MathTests/ConvolutionEngineTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,30 +51,30 @@ std::vector<ConvolveGeometryPtr> GenerateConvTestConfigs()
{
std::vector<ConvolveGeometryPtr> res;
// REVIEW alexeyk: add test cases with even dimensions of a kernel. There are some corner cases which cuDNN does not support (which essentially require negative padding).
for (size_t kW : {1, 3})
{
for (size_t kH : {1, 3})
{
for (size_t inW : {kW, 2 * kW, 2 * kW - 1})
{
for (size_t inC : {1, 3})
{
for (size_t mapCount : {1, 5})
{
for (size_t stride : {1, min((int)kW, min((int)kH, 2))})
{
// Note: must use sharing=false in channel dimension otherwise geometry will not be cuDNN compatible but cuDNN won't fail.
res.push_back(std::make_shared<ConvolveGeometry>(TensorShape(inW, max(kH, inW) + 1, inC),
TensorShape(kW, kH, inC), TensorShape(mapCount), TensorShape(stride, stride, inC),
ConvolveGeometry::BoolVec{true},
ConvolveGeometry::BoolVec{(kW & 1) != 0, (kH & 1) != 0, false},
TensorShape(0), TensorShape(0)));
}
}
}
}
}
}
//for (size_t kW : {1, 3})
//{
// for (size_t kH : {1, 3})
// {
// for (size_t inW : {kW, 2 * kW, 2 * kW - 1})
// {
// for (size_t inC : {1, 3})
// {
// for (size_t mapCount : {1, 5})
// {
// for (size_t stride : {1, min((int)kW, min((int)kH, 2))})
// {
// // Note: must use sharing=false in channel dimension otherwise geometry will not be cuDNN compatible but cuDNN won't fail.
// res.push_back(std::make_shared<ConvolveGeometry>(TensorShape(inW, max(kH, inW) + 1, inC),
// TensorShape(kW, kH, inC), TensorShape(mapCount), TensorShape(stride, stride, inC),
// ConvolveGeometry::BoolVec{true},
// ConvolveGeometry::BoolVec{(kW & 1) != 0, (kH & 1) != 0, false},
// TensorShape(0), TensorShape(0)));
// }
// }
// }
// }
// }
//}
// For debugging.
res.push_back(std::make_shared<ConvolveGeometry>(TensorShape(3, 3, 1),
TensorShape(3, 3, 1), TensorShape(2), TensorShape(1, 1, 1),
Expand Down Expand Up @@ -152,51 +152,58 @@ BOOST_AUTO_TEST_CASE(ConvolutionForward)
};

int baseDeviceId = 0;
auto engKind = ConvolutionEngineKind::Reference;
for (int deviceId : {-1, 0})
//for (auto engKind : {ConvolutionEngineKind::Reference, ConvolutionEngineKind::Gemm})
for (auto engKind : {ConvolutionEngineKind::Gemm})
{
for (const auto& g : GenerateConvTestConfigs())
for (int deviceId : {-1, 0})
{
auto baseEng = ConvEng::Create(g, baseDeviceId, ImageLayoutKind::CHW, 0, PoolKind::None, ConvolutionEngineKind::CuDnn);
auto testEng = ConvEng::Create(g, deviceId, ImageLayoutKind::CHW, 0, PoolKind::None, engKind);
// REVIEW alexeyk: Unroll engine supports CPU only for now.
if (engKind == ConvolutionEngineKind::Gemm && deviceId >= 0)
continue;

size_t n = batchSizeG(rng);
vec buf;
buf.resize(g->InputShape().GetNumElements() * n);
std::generate(begin(buf), end(buf), [&] { return nd(rng); });
SingleMatrix in(g->InputShape().GetNumElements(), n, buf.data(), deviceId, matrixFlagNormal);
SingleMatrix inB(g->InputShape().GetNumElements(), n, buf.data(), baseDeviceId, matrixFlagNormal);

size_t mapCount = g->GetMapCount(g->InputShape().GetRank() - 1);
buf.resize(g->KernelShape().GetNumElements() * mapCount);
std::generate(begin(buf), end(buf), [&] { return nd(rng); });
SingleMatrix kernel(mapCount, g->KernelShape().GetNumElements(), buf.data(), deviceId, matrixFlagNormal);
SingleMatrix kernelB(mapCount, g->KernelShape().GetNumElements(), buf.data(), baseDeviceId, matrixFlagNormal);
for (const auto& g : GenerateConvTestConfigs())
{
auto baseEng = ConvEng::Create(g, baseDeviceId, ImageLayoutKind::CHW, 0, PoolKind::None, ConvolutionEngineKind::CuDnn);
auto testEng = ConvEng::Create(g, deviceId, ImageLayoutKind::CHW, 0, PoolKind::None, engKind);

size_t n = batchSizeG(rng);
vec buf;
buf.resize(g->InputShape().GetNumElements() * n);
std::generate(begin(buf), end(buf), [&] { return nd(rng); });
SingleMatrix in(g->InputShape().GetNumElements(), n, buf.data(), deviceId, matrixFlagNormal);
SingleMatrix inB(g->InputShape().GetNumElements(), n, buf.data(), baseDeviceId, matrixFlagNormal);

size_t crowOut = g->OutputShape().GetNumElements();
SingleMatrix outBuf(deviceId);
SingleMatrix out = initMat(outBuf, crowOut, n, buf);
SingleMatrix outB(out.DeepClone(), baseDeviceId);
size_t mapCount = g->GetMapCount(g->InputShape().GetRank() - 1);
buf.resize(g->KernelShape().GetNumElements() * mapCount);
std::generate(begin(buf), end(buf), [&] { return nd(rng); });
SingleMatrix kernel(mapCount, g->KernelShape().GetNumElements(), buf.data(), deviceId, matrixFlagNormal);
SingleMatrix kernelB(mapCount, g->KernelShape().GetNumElements(), buf.data(), baseDeviceId, matrixFlagNormal);

SingleMatrix workspace(deviceId);
SingleMatrix workspaceB(baseDeviceId);

testEng->Forward(in, kernel, out, workspace);
baseEng->Forward(inB, kernelB, outB, workspaceB);

std::stringstream tmsg;
tmsg << "Geometry: " << (std::string)(*g) << ", Batch: " << n << ", Device: " << deviceId;
std::string msg = " are not equal, " + tmsg.str();
std::string msgNan = " has NaNs, " + tmsg.str();
std::string msgNotNan = " has buffer overflow/underflow, " + tmsg.str();
size_t crowOut = g->OutputShape().GetNumElements();
SingleMatrix outBuf(deviceId);
SingleMatrix out = initMat(outBuf, crowOut, n, buf);
SingleMatrix outB(out.DeepClone(), baseDeviceId);

float relErr = Err<float>::Rel;
float absErr = Err<float>::Abs;
std::string emsg;
SingleMatrix workspace(deviceId);
SingleMatrix workspaceB(baseDeviceId);

testEng->Forward(in, kernel, out, workspace);
baseEng->Forward(inB, kernelB, outB, workspaceB);

BOOST_REQUIRE_MESSAGE(!out.HasNan("out"), "out" << msgNan);
BOOST_REQUIRE_MESSAGE(CheckEqual(out, outB, emsg, relErr * 4, absErr * 8), "out" << msg << ". " << emsg);
BOOST_REQUIRE_MESSAGE(CountNans(outBuf) == crowOut * 2 * n, "out" << msgNotNan);
std::stringstream tmsg;
tmsg << "Geometry: " << (std::string)(*g) << ", Batch: " << n << ", Device: " << deviceId;
std::string msg = " are not equal, " + tmsg.str();
std::string msgNan = " has NaNs, " + tmsg.str();
std::string msgNotNan = " has buffer overflow/underflow, " + tmsg.str();

float relErr = Err<float>::Rel;
float absErr = Err<float>::Abs;
std::string emsg;

BOOST_REQUIRE_MESSAGE(!out.HasNan("out"), "out" << msgNan);
BOOST_REQUIRE_MESSAGE(CheckEqual(out, outB, emsg, relErr * 4, absErr * 8), "out" << msg << ". " << emsg);
BOOST_REQUIRE_MESSAGE(CountNans(outBuf) == crowOut * 2 * n, "out" << msgNotNan);
}
}
}
}
Expand Down

0 comments on commit dcb6654

Please sign in to comment.