
Commit fd28c55

reviewed uses of atomicAdd();
temporarily disallowed Scatter() reductions

1 parent 80461f4 · commit fd28c55

5 files changed: +73 −35 lines

Source/Math/Convolution.cuh  (+1 −1)

@@ -269,4 +269,4 @@ __global__ void kAveragePoolingBackward(int batchSize, const int* mpRowCol, cons
     }
 }
 
-} } }
+}}}

Source/Math/GPUMatrix.cu  (+46 −19)

@@ -885,11 +885,11 @@ __global__ void _doGatherColumnsOf(ElemType* us, size_t usStride, const ElemType
     CUDA_LONG jOut = id / usStride;    // col index into 'us' and 'idx'
 
     auto jInF = idx[jOut * idxStride]; // this is the column we need to get
-    if (jInF < 0)                      // negative index means gap
+    if (isnan(jInF) || jInF < 0)       // negative index means gap
         return;
     size_t jIn = (size_t)jInF;
-    if (jIn >= aCols)
-        return; // actually a failure
+    //if (jIn >= aCols)
+    //    return; // actually a failure
 
     const ElemType& ra = a[ i + jIn * aStride ];
     ElemType& rus = us[id/*i + jOut * usStride*/];
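For readers following the kernel above: ignoring the alpha/beta scaling, the gather it implements can be written as a plain host-side loop. This is a minimal reference sketch (illustrative names, not CNTK code), mirroring the new NaN/negative gap handling:

    // Reference for the gather semantics: us[:,jOut] = a[:,idx[jOut]], column-major storage.
    // A NaN or negative entry in idx marks a gap column, which is simply left untouched.
    #include <cmath>
    #include <cstddef>
    #include <vector>

    template <class ElemType>
    void GatherColumnsReference(std::vector<ElemType>& us, size_t rows, size_t usCols,
                                const std::vector<ElemType>& idx, // one source index per output column
                                const std::vector<ElemType>& a, size_t aCols)
    {
        for (size_t jOut = 0; jOut < usCols; jOut++)
        {
            ElemType jInF = idx[jOut];
            if (std::isnan(jInF) || jInF < 0) // gap: nothing to copy
                continue;
            size_t jIn = (size_t)jInF;
            if (jIn >= aCols)                 // out-of-range index (the kernel no longer guards this)
                continue;
            for (size_t i = 0; i < rows; i++)
                us[i + jOut * rows] = a[i + jIn * rows];
        }
    }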
@@ -928,6 +928,21 @@ GPUMatrix<ElemType>& GPUMatrix<ElemType>::DoGatherColumnsOf(ElemType beta, const
     return *this;
 }
 
+// little helper for debugging
+template <class ElemType>
+static void Peek(const GPUMatrix<ElemType>& m, const char* which)
+{
+    size_t rows = m.GetNumRows();
+    size_t cols = m.GetNumCols();
+    ElemType buf[10000] = { 0 };
+    size_t n = min(rows * cols, _countof(buf));
+    CUDA_CALL(cudaMemcpy(buf, m.Data(), sizeof(ElemType) * n, cudaMemcpyDeviceToHost));
+    UNUSED(which); UNUSED(rows); UNUSED(cols); sin(1.0f); // set breakpoint here
+    //CUDA_CALL(cudaMemcpy(const_cast<ElemType*>(m.Data()), buf, sizeof(ElemType) * n, cudaMemcpyHostToDevice));
+}
+
+#undef ALLOW_ATOMIC_SCATTER // allow to disable this, until we know atomicAdd() works properly here
+
 template <class ElemType>
 __global__ void _doScatterColumnsOf(ElemType* us, size_t usStride, size_t usCols, const ElemType* idx, size_t idxStride, const ElemType* a, size_t aStride, const ElemType alpha, CUDA_LONG numElements)
 {
@@ -941,34 +956,25 @@ __global__ void _doScatterColumnsOf(ElemType* us, size_t usStride, size_t usCols
     CUDA_LONG jIn = id / aStride;      // col index into 'a' and 'idx'
 
     auto jOutF = idx[jIn * idxStride]; // this is the column we copy/add into
-    if (jOutF < 0)                     // negative index means gap
+    if (isnan(jOutF) || jOutF < 0)     // negative index means gap
         return;
     size_t jOut = (size_t)jOutF;
-    if (jOut >= usCols)
-        return; // actually a failure --TODO: This should not be necessary. Why is it?
+    //if (jOut >= usCols)
+    //    return; // actually a failure --TODO: This should not be necessary. Why is it?
 
     const ElemType& ra = a[id/*i + jIn * aStride*/];
     ElemType& rus = us[ i + jOut * usStride ];
 
     ElemType res = ra * alpha;
+#ifdef ALLOW_ATOMIC_SCATTER
     if (res != 0) // avoid memory conflict if e.g. an entire column has no gradient
         atomicAdd(&rus, res); // rus += res;
+#else
+    rus += res;
+#endif
     // Note: atomicAdd() is supposed to be fast in case of no conflict (the simple case of Scatter())
 }
 
-// little helper for debugging
-template <class ElemType>
-static void Peek(const GPUMatrix<ElemType>& m, const char* which)
-{
-    size_t rows = m.GetNumRows();
-    size_t cols = m.GetNumCols();
-    ElemType buf[10000] = { 0 };
-    size_t n = min(rows * cols, _countof(buf));
-    CUDA_CALL(cudaMemcpy(buf, m.Data(), sizeof(ElemType) * n, cudaMemcpyDeviceToHost));
-    UNUSED(which); UNUSED(rows); UNUSED(cols); sin(1.0f); // set breakpoint here
-    //CUDA_CALL(cudaMemcpy(const_cast<ElemType*>(m.Data()), buf, sizeof(ElemType) * n, cudaMemcpyHostToDevice));
-}
-
 // *this[:,idx[j]] = a[:,j] * alpha + *this[:,idx[j]] * beta
 template <class ElemType>
 GPUMatrix<ElemType>& GPUMatrix<ElemType>::DoScatterColumnsOf(ElemType beta, const GPUMatrix<ElemType>& idx, const GPUMatrix<ElemType>& a, ElemType alpha)
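Why Scatter needs atomicAdd at all: several source columns j may name the same target column idx[j], so distinct GPU threads can accumulate into the same element of us concurrently, and a plain read-modify-write would lose updates. A serial host-side sketch of the intended accumulation (illustrative only, not CNTK code):

    // Reference for the scatter-add: us[:,idx[j]] += alpha * a[:,j], column-major storage.
    // If idx[j1] == idx[j2], both contributions must sum into the same target column; on
    // the GPU that is only safe with atomicAdd(). The plain 'rus += res' path above is
    // therefore valid only when every target column occurs at most once in idx -- which
    // is exactly what the check added further below verifies.
    #include <cmath>
    #include <cstddef>
    #include <vector>

    template <class ElemType>
    void ScatterColumnsAddReference(std::vector<ElemType>& us, size_t rows, size_t usCols,
                                    const std::vector<ElemType>& idx, // one target index per source column
                                    const std::vector<ElemType>& a, ElemType alpha)
    {
        for (size_t jIn = 0; jIn < idx.size(); jIn++)
        {
            ElemType jOutF = idx[jIn];
            if (std::isnan(jOutF) || jOutF < 0) // gap column: contributes nothing
                continue;
            size_t jOut = (size_t)jOutF;
            if (jOut >= usCols)
                continue;
            for (size_t i = 0; i < rows; i++)
                us[i + jOut * rows] += alpha * a[i + jIn * rows]; // serial, so += cannot race here
        }
    }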
@@ -986,6 +992,27 @@ GPUMatrix<ElemType>& GPUMatrix<ElemType>::DoScatterColumnsOf(ElemType beta, cons
 
     auto& us = *this;
 
+#ifndef ALLOW_ATOMIC_SCATTER // verify that atomicAdd is not needed --this is not efficient
+    {
+        vector<ElemType> buf(idx.GetNumRows() * idx.GetNumCols()); // idx(,)are the column(s) we copy/add into
+        CUDA_CALL(cudaMemcpy(buf.data(), idx.Data(), sizeof(ElemType) * buf.size(), cudaMemcpyDeviceToHost));
+        vector<bool> writtenTo(GetNumCols(), false); // remember whether an output column is in fact a target
+        for (size_t i = 0; i < buf.size(); i++)
+        {
+            auto colF = buf[i];
+            if (isnan(colF) || colF < 0)
+                continue;
+            size_t col = (size_t)colF;
+            if (col >= GetNumCols())
+                LogicError("DoScatterColumnsOf: Index value out of bounds.");
+            if (writtenTo[col])
+                LogicError("DoScatterColumnsOf: #ifndef ALLOW_ATOMIC_SCATTER then columns must be unique. Column idx(%d,%d)=%d is used twice.", (int)(i % idx.GetNumCols()), (int)(i / idx.GetNumCols()), (int)col);
+            else
+                writtenTo[col] = true;
+        }
+    }
+#endif
+
     // pre-scale with beta upfront
     // Scatter may add more than one source column to the same target, so we must pre-scale with beta, and then just keep adding.
     Scale(beta, us); // if beta is 0, then this will be a memset()

Source/Math/GPUMatrixCUDAKernels.cuh  (+7 −1)

@@ -43,6 +43,7 @@
 #define IDX2C(i, j, ld) (((j) * (ld)) + (i)) // 0 based indexing
 
 // CUDA atomicAdd() only exists for 'float'. This is the 'double' version.
+// TODO: This may need to be guarded by CUDA version; newer devices may support this.
 static __inline__ __device__ double atomicAdd(double* address, double val)
 {
     unsigned long long int* address_as_ull = (unsigned long long int*) address;
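Only the first line of this overload's body is visible in the hunk; it presumably continues with the standard compare-and-swap retry loop from the CUDA C Programming Guide, roughly as sketched below (illustrative name, not a verbatim copy of the file). The TODO above refers to the fact that newer GPUs (sm_60 and later) provide a native double-precision atomicAdd().

    // Emulate atomicAdd on a double by retrying a 64-bit atomicCAS until no other
    // thread has modified the value in between.
    static __inline__ __device__ double atomicAddDoubleSketch(double* address, double val)
    {
        unsigned long long int* address_as_ull = (unsigned long long int*) address;
        unsigned long long int old = *address_as_ull, assumed;
        do
        {
            assumed = old;
            old = atomicCAS(address_as_ull, assumed,
                            __double_as_longlong(val + __longlong_as_double(assumed)));
        } while (assumed != old); // another thread intervened; retry with the new value
        return __longlong_as_double(old); // return the previous value, like the built-in
    }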
@@ -3152,7 +3153,8 @@ __global__ void _scaleSparseBlockAndAddToDense(
     rhs[IDX2C(row, col, numRows)] += alpha * lhsValues[index];
 }
 
-// compute predictions in cross entory node
+#if 0
+// compute predictions in cross entropy node
 template <class ElemType>
 __global__ void _computePrediction(
     int nv,
@@ -3335,6 +3337,7 @@ __global__ void _computeGradientOfInput(
 
     atomicAdd(&grd[IDX2C(h, j, numrows)], sum);
 }
+#endif
 
 template <class ElemType>
 __global__ void computeNCEForwardProp(
@@ -3713,6 +3716,8 @@ __global__ void _assignNceDerivativeNew(
         atomicAdd(&c[wid], -er);
     }
 }
+
+#if 0
 // compute gradients of weights in cross entropy node
 template <class ElemType>
 __global__ void _computeGradientOfWeight(
@@ -3774,6 +3779,7 @@ __global__ void _computeGradientOfWeight(
         blockIds[ii] = i;
     }
 }
+#endif
 
 // used in clipping gradients
 template <class ElemType>

Source/Math/GPUTensor.cu  (+12 −4)

@@ -393,6 +393,8 @@ struct TensorOpElement<ElemType, N, M, K, /*parallelReduce=*/false, /*k=*/-1>
     }
 };
 
+#define ALLOW_ATOMIC_REDUCTION // undefine to disable use of atomicAdd() below, for testing it
+
 // specialization for k = -1 terminates the template recursion, and computes reductions in parallel
 template <class ElemType, C_size_t N, C_int M, C_int K>
 struct TensorOpElement<ElemType, N, M, K, /*parallelReduce=*/true, /*k=*/-1>
@@ -403,8 +405,8 @@ struct TensorOpElement<ElemType, N, M, K, /*parallelReduce=*/true, /*k=*/-1>
                                const FixedArray<C_unsigned_int, K>& /*regularOpStrides*/, const FixedMatrix<C_int, N, K>& /*regularStrides*/,
                                const FixedArray<C_unsigned_int, M>& reducingOpDims, const FixedMatrix<C_int, N, M>& reducingStrides, CUDA_LONG reductionBegin, CUDA_LONG reductionChunkSize)
     {
-        CUDA_LONG reductionBlock = blockIdx.z;  // block index --larger reductions are split into blocks
-        CUDA_LONG reductionBlocks = gridDim.z;  // number of blocks
+        CUDA_LONG reductionBlock = blockIdx.z;  // reduction-block index --larger reductions are split into blocks
+        CUDA_LONG reductionBlocks = gridDim.z;  // number of reduction blocks. If >1 we need atomicAdd
         CUDA_LONG tid = threadIdx.x;            // thread index
         CUDA_LONG tids = blockDim.x;            // out of how many threads --note: last block is partial
 
@@ -427,7 +429,7 @@
         }
 
         // reduce --cf https://docs.nvidia.com/cuda/samples/6_Advanced/reduction/doc/reduction.pdf
-        __shared__ ReduceElemType accumulators[GridDim::maxThreadsPerBlock /*tids*/];
+        __shared__ ReduceElemType volatile accumulators[GridDim::maxThreadsPerBlock /*tids*/];
         accumulators[tid] = sum;
         __syncthreads();
         static_assert(GridDim::maxThreadsPerBlock <= 512, "GridDim::maxThreadsPerBlock too large, need to add manually unrolled steps");
@@ -450,8 +452,12 @@
         auto* pout = pointers[pointers.size() - 1];
         if (reductionBlocks > 1) // multiple blocks: need to use atomicAdd()
         {
+#ifdef ALLOW_ATOMIC_REDUCTION
             // in this case, outer calling code must pass beta = 1
             atomicAdd(pout, val);
+#else
+            *pout = 1000000.0f; // something that can't be missed? How to crash it?
+#endif
         }
         else
         {
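For context, these specializations follow the usual two-level CUDA reduction: each block reduces its chunk of the data through shared memory, and when the reduction is split over several blocks (here along gridDim.z), each block folds its partial sum into the output with atomicAdd(), which is why the caller must pass beta = 1. A minimal self-contained sketch of that pattern (not CNTK's templated code; it uses gridDim.x and a fixed block size for simplicity):

    #define THREADS 256 // must be a power of two and match the launch configuration

    // Each block reduces its chunk of 'in' in shared memory; block partial sums are
    // then accumulated into *out with atomicAdd(). Because blocks add onto *out,
    // the caller must have pre-initialized/pre-scaled it (the "beta = 1" convention).
    __global__ void reduceSumSketch(const float* in, size_t n, float* out)
    {
        __shared__ float acc[THREADS];
        size_t chunk = (n + gridDim.x - 1) / gridDim.x; // elements per block
        size_t begin = (size_t)blockIdx.x * chunk;
        size_t end   = begin + chunk < n ? begin + chunk : n;

        float sum = 0;
        for (size_t i = begin + threadIdx.x; i < end; i += blockDim.x)
            sum += in[i];                               // thread-local partial sum
        acc[threadIdx.x] = sum;
        __syncthreads();

        for (unsigned int stride = blockDim.x / 2; stride > 0; stride >>= 1)
        {
            if (threadIdx.x < stride)
                acc[threadIdx.x] += acc[threadIdx.x + stride]; // shared-memory tree reduction
            __syncthreads();
        }

        if (threadIdx.x == 0)
            atomicAdd(out, acc[0]); // multiple blocks: accumulate rather than overwrite
    }
    // launch e.g. as: reduceSumSketch<<<numBlocks, THREADS>>>(d_in, n, d_out);
    // with *d_out already holding beta * previousValue (0 for a plain sum).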
@@ -560,8 +566,9 @@ static void LaunchTensorOpWithReduction(ElemType beta, array<ElemType*, N> point
     C_size_t reductionDim = 1; // number of elements to reduce over
     for (C_size_t k = 0; k < reducingOpDimVector.size(); k++)
         reductionDim *= (C_size_t) reducingOpDimVector[k];
-    let& props = GridDim::GetDeviceProps();
     GridDim grid(NN);
+#ifdef ALLOW_ATOMIC_REDUCTION // temporarily disabled to ensure it is not causing the non-reproducability
+    let& props = GridDim::GetDeviceProps();
     if (reductionDim > 1 && grid.m_blocksPerGrid < props.multiProcessorCount) // TODO: <= multiProcessorCount?
     {
         // we are reducing and are underutilizing the multiprocs we have: get more parallelism by doing reduction in parallel
@@ -603,6 +610,7 @@
         }
     }
     else
+#endif
     {
         // we got enough elements to generate: do one element per thread, and reduction inside
         _launchTensorOp<ElemType, N, M, K><<<grid.m_blocksPerGrid, grid.m_threadsPerBlock, 0, t_stream>>>(beta, pointers, alpha, op, regularOpStrides, regularStrides, grid.m_N, reducingOpDims, reducingStrides);

Source/Math/latticefunctionskernels.h  (+7 −10)

@@ -16,11 +16,7 @@
 #include "latticestorage.h"
 #include <limits>
 
-namespace msra { namespace cuda {
-
-class passtextureref;
-}
-}
+namespace msra { namespace cuda { class passtextureref; } }
 
 #ifdef CPUONLY
 #define __kernel_emulation__
@@ -34,7 +30,8 @@ using namespace std;
 #define __device__
 #endif
 #define CUDART_MIN_DENORM_F numeric_limits<float>::denorm_min()
-#define atomicAdd(address, value) (*(address) += (value)) // don't forget to #undef (#praga pop_macro)! Otherwise CUDA might compile with this...
+// renamed to x- so we make sure to not accidentally use these; rename back if ever needed again
+#define xatomicAdd(address, value) (*(address) += (value)) // don't forget to #undef (#praga pop_macro)! Otherwise CUDA might compile with this...
 #define atomicCAS(address, compare, val) \
     *address; \
     *address = *address == compare ? val : *address;
@@ -47,8 +44,8 @@ using namespace std;
 #if __CUDA_ARCH__ < 200
 //#warning Sequence training not supported on 1.x CUDA machines.
 #define force_crash() (*((int *) -1) = 0) // TODO: this does not in fact seem to crash it...
-#define atomicAdd(a, v) (force_crash(), *(a) = v) // force a crash if used with 1.x devices
-#define atomicCAS(address, compare, val) (*(address) = compare + val, *((int *) -1) = 0)
+#define xatomicAdd(a, v) (force_crash(), *(a) = v) // force a crash if used with 1.x devices
+#define xatomicCAS(address, compare, val) (*(address) = compare + val, *((int *) -1) = 0)
 #define __double_as_longlong(in) (force_crash(), in)
 #define __longlong_as_double(in) (force_crash(), in)
 #define __float_as_int(in) (force_crash(), in)
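These macros exist because the header can also be compiled in CPU-emulation mode (CPUONLY / __kernel_emulation__), where CUDA intrinsics such as atomicAdd() are unavailable and are replaced by plain expressions; the matching #pragma pop_macro calls in the last hunk undo the substitution so real CUDA compilation is unaffected, and this commit renames the atomicAdd replacements to xatomicAdd so nothing can silently fall back to the non-atomic version. A compressed, illustrative sketch of the push/define/pop pattern (not a copy of the header):

    #pragma push_macro("atomicAdd")         // save any existing definition
    #ifdef __kernel_emulation__
    #define atomicAdd(address, value) (*(address) += (value)) // single-threaded emulation only
    #endif

    // ... kernel-style code that calls atomicAdd(&x, 1.0f) compiles in both modes ...

    #pragma pop_macro("atomicAdd")          // restore, so real CUDA code is unaffected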
@@ -956,8 +953,8 @@ struct latticefunctionskernels
         }
     }
 };
-};
-};
+
+}};
 
 #pragma pop_macro("atomicCAS")
 #pragma pop_macro("atomicAdd")
