Addressed review comments.
Alexey Kamenev committed Apr 21, 2016
1 parent c065733 commit c3a0f56
Showing 1 changed file with 8 additions and 12 deletions: Source/Math/ConvolutionEngine.cpp
@@ -549,9 +549,9 @@ class GemmConvolutionEngine : public ReferenceConvolutionEngine<ElemType>
    void EnsureCompatible() override
    {
        if (m_imageLayout != ImageLayoutKind::CHW)
-           RuntimeError("GEMM convolution engine supports only CHW/cudnn layout.");
+           LogicError("GEMM convolution engine supports only CHW/cudnn layout.");
        if (IsGpu(m_deviceId))
-           RuntimeError("GEMM convolution engine currently supports only CPU device.");
+           LogicError("GEMM convolution engine currently supports only CPU device.");
    }
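Aside on the change above: replacing RuntimeError with LogicError reclassifies these failures as precondition violations (the caller asked for a layout or device the engine was never meant to support) rather than environmental errors. A minimal C++ sketch of that split; the helpers below are hypothetical stand-ins for illustration, not CNTK's actual declarations.

#include <stdexcept>
#include <string>

// Hypothetical stand-ins for CNTK-style error helpers, shown only to
// illustrate the logic_error vs runtime_error distinction.
[[noreturn]] inline void LogicError(const std::string& msg)
{
    // Misuse of the API: a precondition the caller/configuration violated.
    throw std::logic_error(msg);
}

[[noreturn]] inline void RuntimeError(const std::string& msg)
{
    // Failures outside the program's control: I/O, allocation, devices.
    throw std::runtime_error(msg);
}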

    // A note on notation used in the documentation for the next 3 functions:
@@ -566,8 +566,8 @@ class GemmConvolutionEngine : public ReferenceConvolutionEngine<ElemType>
    //
    // The forward method consists of 3 parts:
    // 1. Unrolling convolution input (in) into a matrix: [WHC x N] -> [XYC x NW'H']
-   //    Using this format allows to perform convolution for the whole minibatch as a single GEMM
-   //    which is not possible with NCHW format. Alternatively, NHWC format (used in legacy engine) could be used
+   //    Using this format allows performing convolution for the whole minibatch as a single GEMM operation,
+   //    which is not possible with the WHCN format. Alternatively, the CWHN format (used in the legacy engine) could be used,
    //    but this would require both unrolling the input and transforming the weight matrix.
    // 2. Performing matrix multiplication of unrolled input with weight matrix:
    //    [XYC x NW'H']^T * [XYC x K] -> [NW'H' x K]
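To make steps 1 and 2 concrete, here is a minimal single-channel, stride-1, no-padding C++ sketch of the unroll-then-GEMM scheme the comment describes. im2col and the hand-rolled matrix product are illustrative stand-ins for the engine's unrolling kernel and the real GEMM call; none of these names come from the CNTK source.

#include <cstdio>
#include <vector>

// Unroll a single-channel H x W input so that each column of the result
// holds one kY x kX receptive field (stride 1, no padding). Rows index
// kernel elements, columns index output positions.
static std::vector<float> im2col(const std::vector<float>& in,
                                 int H, int W, int kY, int kX)
{
    const int outH = H - kY + 1, outW = W - kX + 1;
    std::vector<float> col(kY * kX * outH * outW);
    for (int oy = 0; oy < outH; oy++)
        for (int ox = 0; ox < outW; ox++)
            for (int ky = 0; ky < kY; ky++)
                for (int kx = 0; kx < kX; kx++)
                    col[(ky * kX + kx) * (outH * outW) + (oy * outW + ox)] =
                        in[(oy + ky) * W + (ox + kx)];
    return col;
}

int main()
{
    // 3x3 input, one 2x2 kernel -> 2x2 output.
    const int H = 3, W = 3, kY = 2, kX = 2;
    const int outH = H - kY + 1, outW = W - kX + 1, N = outH * outW;
    const std::vector<float> in  = {1, 2, 3, 4, 5, 6, 7, 8, 9};
    const std::vector<float> ker = {1, 0, 0, 1}; // adds in[y][x] and in[y+1][x+1]

    const std::vector<float> col = im2col(in, H, W, kY, kX);

    // The whole convolution is now one matrix product:
    // [1 x kYkX] * [kYkX x N] -> [1 x N] (a single GEMM covers all positions).
    std::vector<float> out(N, 0.0f);
    for (int i = 0; i < N; i++)
        for (int k = 0; k < kY * kX; k++)
            out[i] += ker[k] * col[k * N + i];

    for (int i = 0; i < N; i++)
        printf("%g ", out[i]); // prints: 6 8 12 14
    printf("\n");
    return 0;
}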
@@ -634,13 +634,10 @@ class GemmConvolutionEngine : public ReferenceConvolutionEngine<ElemType>
    }

    // The backward data method works by representing this operation as a "reverse" convolution
-   // in case kernel's last dimension is equal to input dimension. Gradients (grad) become
+   // in case the kernel's last dimension is equal to the input dimension. The gradients matrix (grad) becomes
    // an output of such reverse convolution.
-   // In this case, kernel matrix will have dimensions/layout of:
-   // [C x HWK] (row-major notation) and can be GEMM-ed with appropriately unrolled output (srcGrad in this case).
-   // Each row of the unrolled output will be of size/layout HWK.
-   // It consists of 4 steps:
-   // 1. Transpose and reshape kernel weights: [XYC x K]^T -> [KXY x C]
+   // There are 4 steps:
+   // 1. Transpose and reshape kernel weights: [XYC x K]^T -> [K x XYC] -> [KXY x C]
    // 2. Unroll convolution output (here source gradients, srcGrad):
    //    [W'H'K' x N] -> [KXY x NWH]
    // 3. Perform matrix multiplication of unrolled srcGrad with transposed weights:
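A toy C++ sketch of step 1's shape bookkeeping, on made-up sizes (X=2, Y=1, C=3, K=2 are illustrative). It mimics the column-major storage CNTK's Matrix uses but is otherwise not CNTK code; the point is that after the explicit transpose to [K x XYC], the reshape to [KXY x C] is a free reinterpretation of the same buffer, leaving one input channel per column.

#include <cstdio>
#include <vector>

int main()
{
    // Illustrative sizes only: X*Y kernel positions, C input channels, K kernels.
    const int X = 2, Y = 1, C = 3, K = 2;
    const int XYC = X * Y * C, KXY = K * X * Y;

    // kern: [XYC x K], column-major; column k holds kernel k's weights,
    // with x varying fastest and c slowest inside the column. weight = 10*k + i.
    std::vector<float> kern(XYC * K);
    for (int k = 0; k < K; k++)
        for (int i = 0; i < XYC; i++)
            kern[k * XYC + i] = 10.0f * k + i;

    // Step 1a, transpose: [XYC x K] -> [K x XYC]; element (k, i) goes to i*K + k.
    std::vector<float> kernTran(K * XYC);
    for (int k = 0; k < K; k++)
        for (int i = 0; i < XYC; i++)
            kernTran[i * K + k] = kern[k * XYC + i];

    // Step 1b, reshape [K x XYC] -> [KXY x C]: free in column-major storage.
    // Reading the same buffer in columns of length K*X*Y groups all weights
    // of one input channel (across kernels and positions) into one column.
    for (int c = 0; c < C; c++)
    {
        printf("channel %d column:", c);
        for (int r = 0; r < KXY; r++)
            printf(" %g", kernTran[c * KXY + r]);
        printf("\n"); // e.g. channel 0 column: 0 10 1 11
    }
    return 0;
}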
@@ -669,7 +666,6 @@ class GemmConvolutionEngine : public ReferenceConvolutionEngine<ElemType>
        size_t mapInSize = inT.GetNumElements() / mapInCount;

        size_t unrollRows = mapInSize * subBatchSize;
-       // Original kernel matrix in KCHW [K x CHW] format will be transposed to CHWK [C x HWK].
        size_t unrollCols = kernel.GetNumElements() / mapInCount;

        // Reserve space for:
@@ -684,7 +680,7 @@ class GemmConvolutionEngine : public ReferenceConvolutionEngine<ElemType>
        auto kern = kernel.ColumnSlice(0, kernel.GetNumCols());
        // cudnn layout uses row-major kernel weight matrix.
        kern.Reshape(kernel.GetNumCols(), kernel.GetNumRows());
-       // Now transpose and reshape to [C x HWK] (row-major).
+       // Now transpose and reshape to [KXY x C].
        auto kernTran = workspace.ColumnSlice(0, kernCols);
        // Reshape to transpose shape, AssignTransposeOf requires that.
        kernTran.Reshape(kern.GetNumCols(), kern.GetNumRows());
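The Reshape immediately before AssignTransposeOf can look redundant; per the comment in the hunk above, AssignTransposeOf requires the destination to already carry the transposed dimensions. Below is a minimal sketch of that contract, with Mat and a free AssignTransposeOf function as illustrative stand-ins rather than CNTK's Matrix API.

#include <cassert>
#include <vector>

// Toy column-major matrix view over shared storage (like a ColumnSlice).
struct Mat
{
    int rows, cols;
    std::vector<float>* buf;
    void Reshape(int r, int c) { assert(r * c == rows * cols); rows = r; cols = c; }
};

// Requires dst to already have the transposed shape, which is why the code
// above calls Reshape on the destination slice first.
static void AssignTransposeOf(Mat& dst, const Mat& src)
{
    assert(dst.rows == src.cols && dst.cols == src.rows);
    for (int j = 0; j < src.cols; j++)
        for (int i = 0; i < src.rows; i++)
            (*dst.buf)[i * dst.rows + j] = (*src.buf)[j * src.rows + i];
}

int main()
{
    std::vector<float> kbuf = {1, 2, 3, 4, 5, 6}; // [3 x 2], column-major
    std::vector<float> wbuf(kbuf.size());         // workspace storage
    Mat kern{3, 2, &kbuf};
    Mat kernTran{1, 6, &wbuf};              // slice starts with an arbitrary shape
    kernTran.Reshape(kern.cols, kern.rows); // must become [2 x 3] before assigning
    AssignTransposeOf(kernTran, kern);      // wbuf now holds {1, 4, 2, 5, 3, 6}
    return 0;
}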