From c691fc6dc711814a06107d4a9b763f34bff5afca Mon Sep 17 00:00:00 2001
From: Christian Sarofeen
Date: Sun, 2 Jul 2017 21:39:40 -0700
Subject: [PATCH] Add a nonContigDim reduction kernel to improve latency for
 small tensors. (#768)

---
 THCReduce.cuh | 132 +++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 126 insertions(+), 6 deletions(-)

diff --git a/THCReduce.cuh b/THCReduce.cuh
index 067d796a76be62..b7df49b9c13d88 100644
--- a/THCReduce.cuh
+++ b/THCReduce.cuh
@@ -20,6 +20,99 @@ __device__ __forceinline__ IndexType getReduceNoncontigDimSliceIndex() {
   return getLinearBlockId<IndexType>() * THC_NONCONTIG_REDUCE_BLOCK_SIZE + threadIdx.x;
 }
 
+// Kernel that handles an entire reduction of a slice of a tensor per each block
+template <typename ModifyOp,
+          typename ReduceOp,
+          typename T,
+          typename IndexType,
+          int ADims, int BDims>
+#if __CUDA_ARCH__ >= 350
+__launch_bounds__(32 * 16, 4)
+#endif
+__global__ void
+kernelReduceNoncontigDim_shared(TensorInfo<T, IndexType> out,
+                                TensorInfo<T, IndexType> in,
+                                IndexType reductionStride,
+                                IndexType reductionSize,
+                                IndexType totalSlices,
+                                T init,
+                                ModifyOp modifyOp,
+                                ReduceOp reduceOp) {
+
+  IndexType sliceIndex  = blockIdx.x * blockDim.x + threadIdx.x;
+  IndexType sliceStride = gridDim.x * blockDim.x;
+
+  __shared__ T local_reduce[THC_NONCONTIG_REDUCE_BLOCK_SIZE];
+  T* shmem = &local_reduce[threadIdx.x + threadIdx.y * blockDim.x];
+  T load_reg[4];
+  T local_reg;
+
+  for(;sliceIndex<totalSlices; sliceIndex+=sliceStride){
+    local_reg = init;
+
+    const IndexType outOffset =
+      IndexToOffset<T, IndexType, ADims>::get(sliceIndex, out);
+    const IndexType inOffset =
+      IndexToOffset<T, IndexType, BDims>::get(sliceIndex, in);
+
+    //Unroll this loop
+    //for(IndexType i=threadIdx.y; i<reductionSize; i+=blockDim.y){
+    //  local_reg = reduceOp(local_reg, modifyOp(in.data[inOffset + i * reductionStride]));
+    //}
+    for(IndexType i=threadIdx.y; i<reductionSize; i+=blockDim.y * 4){
+      if(i + blockDim.y * 3 < reductionSize){
+        load_reg[0] = modifyOp(in.data[inOffset + (i                 ) * reductionStride]);
+        load_reg[1] = modifyOp(in.data[inOffset + (i + blockDim.y    ) * reductionStride]);
+        load_reg[2] = modifyOp(in.data[inOffset + (i + blockDim.y * 2) * reductionStride]);
+        load_reg[3] = modifyOp(in.data[inOffset + (i + blockDim.y * 3) * reductionStride]);
+        local_reg = reduceOp(local_reg, load_reg[0]);
+        local_reg = reduceOp(local_reg, load_reg[1]);
+        local_reg = reduceOp(local_reg, load_reg[2]);
+        local_reg = reduceOp(local_reg, load_reg[3]);
+      }else if(i + blockDim.y * 2 < reductionSize){
+        load_reg[0] = modifyOp(in.data[inOffset + (i                 ) * reductionStride]);
+        load_reg[1] = modifyOp(in.data[inOffset + (i + blockDim.y    ) * reductionStride]);
+        load_reg[2] = modifyOp(in.data[inOffset + (i + blockDim.y * 2) * reductionStride]);
+        local_reg = reduceOp(local_reg, load_reg[0]);
+        local_reg = reduceOp(local_reg, load_reg[1]);
+        local_reg = reduceOp(local_reg, load_reg[2]);
+      }else if(i + blockDim.y < reductionSize){
+        load_reg[0] = modifyOp(in.data[inOffset + (i             ) * reductionStride]);
+        load_reg[1] = modifyOp(in.data[inOffset + (i + blockDim.y) * reductionStride]);
+        local_reg = reduceOp(local_reg, load_reg[0]);
+        local_reg = reduceOp(local_reg, load_reg[1]);
+      }else{
+        local_reg = reduceOp(local_reg, modifyOp(in.data[inOffset + i * reductionStride]));
+      }
+    }
+
+    *shmem = local_reg;
+    int dimy = blockDim.y;
+    while(dimy > 1){
+      __syncthreads();
+      if( threadIdx.y == 0 && (dimy%2 != 0) ){
+        *shmem = reduceOp(*shmem, *(shmem + (dimy-1) * blockDim.x) );
+      }
+      if(threadIdx.y < dimy/2){
+        *shmem = reduceOp(*shmem, *(shmem + (dimy/2)*blockDim.x) );
+      }
+      dimy /= 2;
+    }
+    if(threadIdx.y == 0)
+      out.data[outOffset] = *shmem;
+  }
+}
+
 // Kernel that handles an entire reduction of a slice of a tensor per each thread
 template <typename ModifyOp,
           typename ReduceOp,
@@ -212,8 +305,26 @@ bool THC_reduceDim(THCState* state,
     }
 
-    block = getNoncontigReduceBlock();
-  }
+    block = getNoncontigReduceBlock();
+
+    if(outElements <= 4096){
+      // The one-thread-per-slice kernel leaves most of the device idle
+      // when there are only a few slices. Keep blockDim.x parallelizing
+      // across slices, and trade x threads for y threads that cooperate
+      // on each slice's reduction; block.x * block.y stays equal to
+      // THC_NONCONTIG_REDUCE_BLOCK_SIZE.
+      long ydim = THCCeilDiv(reductionSize, 16L);
+
+      while(block.x > 32 && ydim > 1){
+        block.x /= 2;
+        block.y *= 2;
+        ydim /= 2;
+      }
+
+      THC_getGridFromTiles(THCCeilDiv(outElements, (long)block.x), grid);
+
+    }
+  }
 
   // Resize out to correspond to the reduced size
   THLongStorage* sizes = TensorUtils<TensorType>::newSizeOf(state, in);
   THLongStorage_set(sizes, dim, 1);
@@ -231,12 +342,21 @@ bool THC_reduceDim(THCState* state,
         outInfo, inInfo, reductionSize, \
         (TYPE) outElements, init, modifyOp, reduceOp); \
   } else { \
-    kernelReduceNoncontigDim<ModifyOp, ReduceOp, typename TensorUtils<TensorType>::DataType, \
-      TYPE, OUT, IN> \
-      <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
-        outInfo, inInfo, reductionStride, reductionSize, \
+    if(block.y == 1){ \
+      kernelReduceNoncontigDim<ModifyOp, ReduceOp, typename TensorUtils<TensorType>::DataType, \
+        TYPE, OUT, IN> \
+        <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
+          outInfo, inInfo, reductionStride, reductionSize, \
         (TYPE) outElements, init, modifyOp, reduceOp); \
+    }else{ \
+      kernelReduceNoncontigDim_shared<ModifyOp, ReduceOp, typename TensorUtils<TensorType>::DataType, \
+        TYPE, OUT, IN> \
+        <<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
+          outInfo, inInfo, reductionStride, reductionSize, \
+          (TYPE) outElements, init, modifyOp, reduceOp); \
+    } \
   } \
 
 #define HANDLE_IN_CASE(TYPE, OUT, IN) \
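
Note: for readers who do not want to trace the macro-heavy diff, the sketch below isolates the two ideas the new kernel combines: each slice gets a column of blockDim.y threads that accumulate a strided slab of the slice in registers, and a shared-memory tree then folds the blockDim.y partial results, with the odd element folded into lane y == 0 first. This is a minimal standalone illustration with a hard-coded sum reduction and hypothetical names (sumSlices, REDUCE_BLOCK_SIZE), not code from the patch.

#include <cstdio>
#include <cuda_runtime.h>

#define REDUCE_BLOCK_SIZE 512  // block.x * block.y, like THC_NONCONTIG_REDUCE_BLOCK_SIZE

// One slice per blockDim.x-column; blockDim.y threads cooperate per slice.
__global__ void sumSlices(float* out, const float* in,
                          int reductionStride, int reductionSize,
                          int totalSlices) {
  int slice = blockIdx.x * blockDim.x + threadIdx.x;

  __shared__ float partial[REDUCE_BLOCK_SIZE];
  float* shmem = &partial[threadIdx.x + threadIdx.y * blockDim.x];

  // Each y-thread accumulates a strided subset of its slice in a register.
  float acc = 0.0f;
  if (slice < totalSlices) {
    for (int i = threadIdx.y; i < reductionSize; i += blockDim.y)
      acc += in[slice + i * reductionStride];
  }
  *shmem = acc;

  // Tree-reduce the blockDim.y partials; fold the odd element into y == 0.
  int dimy = blockDim.y;
  while (dimy > 1) {
    __syncthreads();
    if (threadIdx.y == 0 && (dimy % 2 != 0))
      *shmem += *(shmem + (dimy - 1) * blockDim.x);
    if (threadIdx.y < dimy / 2)
      *shmem += *(shmem + (dimy / 2) * blockDim.x);
    dimy /= 2;
  }
  if (threadIdx.y == 0 && slice < totalSlices)
    out[slice] = *shmem;
}

int main() {
  // 64 slices of 256 elements in slice-major layout: element i of slice s
  // lives at in[s + i * 64], i.e. reductionStride == totalSlices here.
  const int slices = 64, rsize = 256;
  float *in, *out;
  cudaMallocManaged(&in, slices * rsize * sizeof(float));
  cudaMallocManaged(&out, slices * sizeof(float));
  for (int i = 0; i < slices * rsize; ++i) in[i] = 1.0f;

  dim3 block(32, 16);  // 32 slices per block, 16-way reduction per slice
  dim3 grid((slices + block.x - 1) / block.x);
  sumSlices<<<grid, block>>>(out, in, slices, rsize, slices);
  cudaDeviceSynchronize();

  printf("out[0] = %f (expect %d)\n", out[0], rsize);
  cudaFree(in); cudaFree(out);
  return 0;
}

One deliberate difference from the patch: the sketch guards only the loads and the final store with slice < totalSlices, so every thread in the block still reaches the __syncthreads() calls inside the tree loop, keeping the barrier non-divergent even in the last, partially-filled block.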
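The dispatch in the last hunk keys off block.y: the setup loop starts from the flat 512-thread block returned by getNoncontigReduceBlock() and repeatedly trades x threads for y threads, so block.y stays 1 exactly when no intra-slice parallelism was requested and the original one-thread-per-slice kernel remains in use. The exact threshold and the seed for ydim in the setup hunk may differ from what is shown above; the host-side sketch below only illustrates the shape trade, assuming ydim is derived as ceil(reductionSize / 16) and that block.x is floored at one warp.

#include <cstdio>

// Hypothetical stand-in for the dim3 block-shape adjustment: shrink the
// slice dimension (x) and grow the reduction dimension (y), keeping
// x * y == 512, until the requested reduction parallelism is covered.
int main() {
  long reductionSize = 4096;
  unsigned x = 512, y = 1;                // flat getNoncontigReduceBlock() shape
  long ydim = (reductionSize + 15) / 16;  // assumed heuristic, see above

  while (x > 32 && ydim > 1) {
    x /= 2; y *= 2; ydim /= 2;
  }
  printf("block = (%u, %u), threads = %u\n", x, y, x * y);  // (32, 16), 512
  return 0;
}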