Skip to content

Commit

Permalink
renamed reduction kernels that expect a hard-coded number of threads to reflect that number in their names
Browse files Browse the repository at this point in the history
  • Loading branch information
frankseide committed Jul 27, 2016
1 parent 9dbf806 commit 8f71698
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 55 deletions.
8 changes: 4 additions & 4 deletions Source/Math/CPUMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5788,14 +5788,14 @@ void CPUMatrix<ElemType>::RCRFBackwardCompute(const CPUMatrix<ElemType>& alpha,
#pragma omp parallel for
for (int k = 0; k < iNumLab; k++)
{
_rcrfBackwardCompute(t, k, alpha, beta, pair_scores);
_rcrfBackwardCompute1024Threads(t, k, alpha, beta, pair_scores);
}
}
};

/// the kernel function for RCRF backward computation
template <class ElemType>
void CPUMatrix<ElemType>::_rcrfBackwardCompute(size_t t, size_t k, const CPUMatrix<ElemType>& alpha,
void CPUMatrix<ElemType>::_rcrfBackwardCompute1024Threads(size_t t, size_t k, const CPUMatrix<ElemType>& alpha,
CPUMatrix<ElemType>& beta,
const CPUMatrix<ElemType>& pair_scores)
{
Expand Down Expand Up @@ -5859,7 +5859,7 @@ void CPUMatrix<ElemType>::RCRFTransGrdCompute(const CPUMatrix<ElemType>& lbls,
#pragma omp parallel for
for (int i = 0; i < iNumLab; i++)
{
_rcrfTransGrdCompute(i, lbls, alpha, beta, pair_scores, grd, tPos);
_rcrfTransGrdCompute1024Threads(i, lbls, alpha, beta, pair_scores, grd, tPos);
}

// transition score
Expand Down Expand Up @@ -5891,7 +5891,7 @@ void CPUMatrix<ElemType>::RCRFTransGrdCompute(const CPUMatrix<ElemType>& lbls,
};

template <class ElemType>
void CPUMatrix<ElemType>::_rcrfTransGrdCompute(size_t i,
void CPUMatrix<ElemType>::_rcrfTransGrdCompute1024Threads(size_t i,
const CPUMatrix<ElemType>& lbls,
const CPUMatrix<ElemType>& alpha,
const CPUMatrix<ElemType>& beta,
Expand Down
4 changes: 2 additions & 2 deletions Source/Math/CPUMatrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ class MATH_API CPUMatrix : public BaseMatrix<ElemType>
static void RCRFBackwardCompute(const CPUMatrix<ElemType>& alpha, CPUMatrix<ElemType>& beta,
const CPUMatrix<ElemType>& lbls,
const CPUMatrix<ElemType>& pair_scores);
static void _rcrfBackwardCompute(size_t t, size_t k, const CPUMatrix<ElemType>& alpha,
static void _rcrfBackwardCompute1024Threads(size_t t, size_t k, const CPUMatrix<ElemType>& alpha,
CPUMatrix<ElemType>& beta,
const CPUMatrix<ElemType>& pair_scores);

Expand All @@ -496,7 +496,7 @@ class MATH_API CPUMatrix : public BaseMatrix<ElemType>
const CPUMatrix<ElemType>& pair_scores,
CPUMatrix<ElemType>& grd);

static void _rcrfTransGrdCompute(size_t i,
static void _rcrfTransGrdCompute1024Threads(size_t i,
const CPUMatrix<ElemType>& lbls,
const CPUMatrix<ElemType>& alpha,
const CPUMatrix<ElemType>& beta,
Expand Down
64 changes: 42 additions & 22 deletions Source/Math/GPUMatrix.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1949,7 +1949,8 @@ void GPUMatrix<ElemType>::AssignNoiseContrastiveEstimation(const GPUMatrix<ElemT
while (p / 2 > width)
p = p / 2;

_computeNceOutput<ElemType><<<GetNumElements() / 2, p>>>(
// note: kernel has hard-coded dimension of 512
_computeNceOutput512Threads<ElemType> << <GetNumElements() / 2, p >> >(
Data(),
sampleCount,
m_numRows / 2,
Expand All @@ -1963,7 +1964,8 @@ void GPUMatrix<ElemType>::AssignNoiseContrastiveEstimation(const GPUMatrix<ElemT
while (p / 2 > GetNumElements() / 2)
p = p / 2;
// summing up objective must be done in one block
_assignNoiseContrastiveEstimation<ElemType><<<1, p>>>(
// note: kernel has hard-coded dimension of 512
_assignNoiseContrastiveEstimation512Threads<ElemType> << <1, p >> >(
Data(),
sampleCount,
m_numRows / 2,
Expand Down Expand Up @@ -2008,7 +2010,8 @@ void GPUMatrix<ElemType>::AssignSoftmaxSum(const GPUMatrix<ElemType>& a, GPUMatr
while (p / 2 > width)
p = p / 2;

_assignSoftmaxSum<ElemType><<<1, p>>>(
// note: kernel has hard-coded dimension of 512
_assignSoftmaxSum512Threads<ElemType> << <1, p >> >(
my_a.Data(),
width,
Data(),
Expand Down Expand Up @@ -2084,7 +2087,8 @@ GPUMatrix<ElemType>& GPUMatrix<ElemType>::AssignLogSoftmaxOf(const GPUMatrix<Ele
CUDA_LONG N = (CUDA_LONG) GetNumCols();
CUDA_LONG M = (CUDA_LONG) GetNumRows();
SyncGuard syncGuard;
_assignColumnwiseLogSoftmaxOf<<<N, 512, 0, t_stream>>>(a.Data(), Data(), N, M);
// note: kernel has hard-coded dimension of 512
_assignColumnwiseLogSoftmaxOf512Threads<<<N, 512, 0, t_stream>>>(a.Data(), Data(), N, M);
}
else
{
Expand All @@ -2110,7 +2114,8 @@ GPUMatrix<ElemType>& GPUMatrix<ElemType>::AssignHardmaxOf(const GPUMatrix<ElemTy
CUDA_LONG N = (CUDA_LONG) GetNumCols();
CUDA_LONG M = (CUDA_LONG) GetNumRows();
SyncGuard syncGuard;
_assignColumnwiseHardmaxOf<<<N, 512, 0, t_stream>>>(a.Data(), Data(), N, M);
// note: kernel has hard-coded dimension of 512
_assignColumnwiseHardmaxOf512Threads << <N, 512, 0, t_stream >> >(a.Data(), Data(), N, M);
}
else
{
Expand Down Expand Up @@ -2262,7 +2267,8 @@ ElemType GPUMatrix<ElemType>::SumOfElements() const
ElemType h_sum;

// WARNING: THIS kernel is not the most efficient way!
_reductionSum<ElemType><<<1, 1024, 0, t_stream>>>(Data(), d_sum, (CUDA_LONG) GetNumElements());
// note: kernel has hard-coded dimension of 1024
_reductionSum1024Threads<ElemType> << <1, 1024, 0, t_stream >> >(Data(), d_sum, (CUDA_LONG)GetNumElements());
CUDA_CALL(cudaMemcpy(&h_sum, d_sum, sizeof(ElemType), cudaMemcpyDeviceToHost));
TracingGPUMemoryAllocator::Free<ElemType>(GetComputeDeviceId(), d_sum);
return h_sum;
Expand All @@ -2279,7 +2285,8 @@ GPUMatrix<ElemType>& GPUMatrix<ElemType>::AssignSumOfElements(const GPUMatrix<El
PrepareDevice();
SyncGuard syncGuard;
// WARNING: THIS kernel is not the most efficient way!
_reductionSumAndAssign<ElemType><<<1, 1024>>>(Data(), a.Data(), (CUDA_LONG) a.GetNumElements(), (CUDA_LONG) GetNumElements());
// note: kernel has hard-coded dimension of 1024
_reductionSumAndAssign1024Threads<ElemType> << <1, 1024 >> >(Data(), a.Data(), (CUDA_LONG)a.GetNumElements(), (CUDA_LONG)GetNumElements());
return (*this);
}

Expand All @@ -2291,7 +2298,8 @@ DeviceBoundNumber<ElemType> GPUMatrix<ElemType>::Sum_AsDeviceBoundNum() const
ElemType* d_sum = TracingGPUMemoryAllocator::Allocate<ElemType>(GetComputeDeviceId(), 1);

// WARNING: THIS kernel is not the most efficient way!
_reductionSum<ElemType><<<1, 1024, 0, t_stream>>>(Data(), d_sum, (CUDA_LONG) GetNumElements());
// note: kernel has hard-coded dimension of 1024
_reductionSum1024Threads<ElemType> << <1, 1024, 0, t_stream >> >(Data(), d_sum, (CUDA_LONG)GetNumElements());
DeviceBoundNumber<ElemType> result;
result.ShallowCopyFrom(d_sum, GetComputeDeviceId());
return result;
Expand Down Expand Up @@ -2593,7 +2601,8 @@ ElemType GPUMatrix<ElemType>::FrobeniusNorm() const

ElemType h_sum = 0;
// WARNING: THIS kernel is not the most efficient way!
_reductionSum2<ElemType><<<1, 1024, 0, t_stream>>>(Data(), d_sum, (CUDA_LONG) GetNumElements(), true);
// note: kernel has hard-coded dimension of 1024
_reductionSum21024Threads<ElemType> << <1, 1024, 0, t_stream >> >(Data(), d_sum, (CUDA_LONG)GetNumElements(), true);
CUDA_CALL(cudaMemcpy(&h_sum, d_sum, sizeof(ElemType), cudaMemcpyDeviceToHost));
TracingGPUMemoryAllocator::Free<ElemType>(GetComputeDeviceId(), d_sum);

Expand All @@ -2610,7 +2619,8 @@ GPUMatrix<ElemType>& GPUMatrix<ElemType>::AssignFrobeniusNormOf(const GPUMatrix<

PrepareDevice();
// WARNING: THIS kernel is not the most efficient way!
_reductionSum2<ElemType><<<1, 1024, 0, t_stream>>>(a.Data(), Data(), (CUDA_LONG) a.GetNumElements(), true);
// note: kernel has hard-coded dimension of 1024
_reductionSum21024Threads<ElemType> << <1, 1024, 0, t_stream >> >(a.Data(), Data(), (CUDA_LONG)a.GetNumElements(), true);

return *this;
}
Expand All @@ -2619,13 +2629,14 @@ template <class ElemType>
ElemType GPUMatrix<ElemType>::MatrixNormInf() const
{
if (IsEmpty())
LogicError("MatrixNorm1: Matrix is empty.");
LogicError("MatrixNormInf: Matrix is empty.");

ElemType* d_maxAbs = TracingGPUMemoryAllocator::Allocate<ElemType>(GetComputeDeviceId(), 1);

ElemType h_maxAbs = 0;
// WARNING: THIS kernel is not the most efficient way!
_reductionMatrixNormInf<ElemType><<<1, 1024, 0, t_stream>>>(Data(), d_maxAbs, (CUDA_LONG) GetNumElements());
// note: kernel has hard-coded dimension of 1024
_reductionMatrixNormInf1024Threads<ElemType> << <1, 1024, 0, t_stream >> >(Data(), d_maxAbs, (CUDA_LONG)GetNumElements());
CUDA_CALL(cudaMemcpy(&h_maxAbs, d_maxAbs, sizeof(ElemType), cudaMemcpyDeviceToHost));
TracingGPUMemoryAllocator::Free<ElemType>(GetComputeDeviceId(), d_maxAbs);
return h_maxAbs;
Expand All @@ -2648,7 +2659,8 @@ ElemType GPUMatrix<ElemType>::MatrixNorm0() const
ElemType* d_nz = TracingGPUMemoryAllocator::Allocate<ElemType>(GetComputeDeviceId(), 1);
ElemType h_nz = 0;
// WARNING: THIS kernel is not the most efficient way!
_reductionMatrixNorm0<ElemType><<<1, 1024, 0, t_stream>>>(Data(), d_nz, (CUDA_LONG) GetNumElements());
// note: kernel has hard-coded dimension of 1024
_reductionMatrixNorm01024Threads<ElemType> << <1, 1024, 0, t_stream >> >(Data(), d_nz, (CUDA_LONG)GetNumElements());
CUDA_CALL(cudaMemcpy(&h_nz, d_nz, sizeof(ElemType), cudaMemcpyDeviceToHost));
TracingGPUMemoryAllocator::Free<ElemType>(GetComputeDeviceId(), d_nz);
return h_nz;
Expand Down Expand Up @@ -2705,7 +2717,8 @@ void GPUMatrix<ElemType>::VectorMax(GPUMatrix<ElemType>& maxIndexes, GPUMatrix<E
maxIndexes.RequireSize(1, n);

int blocksPerGrid = n; // we'll have 1 block processing 1 column
_vectorMaxMinReduce<ElemType, true><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(us.Data(), maxIndexes.Data(), maxValues.Data(), m, n);
// note: kernel has hard-coded dimension of 512
_vectorMaxMinReduce512Threads<ElemType, true><<<blocksPerGrid, 512, 0, t_stream>>>(us.Data(), maxIndexes.Data(), maxValues.Data(), m, n);

/*int blocksPerGrid=(int)ceil(1.0*n/GridDim::maxThreadsPerBlock);
_vectorMax<ElemType><<<blocksPerGrid,GridDim::maxThreadsPerBlock,0,t_stream>>>(us.Data(),maxIndexes.Data(),maxValues.Data(),m,n,isColWise);*/
Expand Down Expand Up @@ -2831,7 +2844,8 @@ void GPUMatrix<ElemType>::VectorMin(GPUMatrix<ElemType>& minIndexes, GPUMatrix<E
minIndexes.RequireSize(1, n);

int blocksPerGrid = n; // we'll have 1 block processing 1 column
_vectorMaxMinReduce<ElemType, false><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0, t_stream>>>(us.Data(), minIndexes.Data(), minValues.Data(), m, n);
// note: kernel has hard-coded dimension of 512
_vectorMaxMinReduce512Threads<ElemType, false> << <blocksPerGrid, 512, 0, t_stream >> >(us.Data(), minIndexes.Data(), minValues.Data(), m, n);

/*
int blocksPerGrid=(int)ceil(1.0*n/GridDim::maxThreadsPerBlock);
Expand Down Expand Up @@ -2861,8 +2875,9 @@ GPUMatrix<ElemType>& GPUMatrix<ElemType>::AssignNumOfDiff(const GPUMatrix<ElemTy
if (!searchInCol)
{
// int blocksPerGrid=(int)ceil(1.0*a.GetNumElements()/GridDim::maxThreadsPerBlock);
// _assignNumOfDiff<ElemType><<<blocksPerGrid,GridDim::maxThreadsPerBlock,0,t_stream>>>(a.Data(), b.Data(), Data(), a.GetNumElements());
_assignNumOfDiff<ElemType><<<1, 1024, 0, t_stream>>>(a.Data(), b.Data(), Data(), (CUDA_LONG) a.GetNumElements());
// _assignNumOfDiff1024Threads<ElemType><<<blocksPerGrid,GridDim::maxThreadsPerBlock,0,t_stream>>>(a.Data(), b.Data(), Data(), a.GetNumElements());
// note: kernel has hard-coded dimension of 1024
_assignNumOfDiff1024Threads<ElemType> << <1, 1024, 0, t_stream >> >(a.Data(), b.Data(), Data(), (CUDA_LONG)a.GetNumElements());
}
else
{
Expand Down Expand Up @@ -4047,7 +4062,8 @@ ElemType GPUMatrix<ElemType>::GetLearnRateForBlock_Helper(const GPUMatrix<ElemTy
}
// d_res[0] should now contain inner product of matrices
// Compute squared Frobenius norms (squared sums of elements)
_lrHelper<ElemType><<<1, 512, 0, t_stream>>>(Gradients.Data(), SmoothedGradients.Data(), (CUDA_LONG) Gradients.GetNumElements(), d_res);
// note: kernel has hard-coded dimension of 512
_lrHelper512Threads<ElemType> << <1, 512, 0, t_stream >> >(Gradients.Data(), SmoothedGradients.Data(), (CUDA_LONG)Gradients.GetNumElements(), d_res);
ElemType res;
CUDA_CALL(cudaMemcpy(&res, d_res, sizeof(ElemType), cudaMemcpyDeviceToHost));
TracingGPUMemoryAllocator::Free<ElemType>(Gradients.GetComputeDeviceId(), d_res);
Expand Down Expand Up @@ -4276,10 +4292,12 @@ void GPUMatrix<ElemType>::RCRFBackwardCompute(
for (int t = iNumPos - 1; t >= 0; t--)
{
szMemSize = sizeof(ElemType) * iNumLab;
_rcrfBackwardComputeZeta<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, szMemSize>>>(t, iNumPos, alpha.Data(), d_zeta, pair_scores.Data(), iNumLab, shift);
// note: kernel has hard-coded dimension of 1024
_rcrfBackwardComputeZeta1024Threads<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock, szMemSize >> >(t, iNumPos, alpha.Data(), d_zeta, pair_scores.Data(), iNumLab, shift);
szMemSize = iNumLab * 3;
szMemSize *= sizeof(ElemType);
_rcrfBackwardCompute<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, szMemSize>>>(t, iNumPos, alpha.Data(), beta.Data(),
// note: kernel has hard-coded dimension of 1024
_rcrfBackwardCompute1024Threads<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock, szMemSize >> >(t, iNumPos, alpha.Data(), beta.Data(),
d_zeta, pair_scores.Data(), iNumLab, shift);
}
/*
Expand Down Expand Up @@ -4317,10 +4335,12 @@ void GPUMatrix<ElemType>::RCRFTransGrdCompute(const GPUMatrix<ElemType>& lbls,
for (int t = 0; t < iNumPos; t++)
{
szMemSize = sizeof(ElemType) * iNumLab;
_rcrfTransGrdComputeZeta<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, szMemSize>>>(t - 1, iNumPos, alpha.Data(), d_zeta, pair_scores.Data(), iNumLab, startLbl, shift);
// note: kernel has hard-coded dimension of 1024
_rcrfTransGrdComputeZeta<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock, szMemSize >> >(t - 1, iNumPos, alpha.Data(), d_zeta, pair_scores.Data(), iNumLab, startLbl, shift);
szMemSize = iNumLab * 3;
szMemSize *= sizeof(ElemType);
_rcrfTransGrdCompute<ElemType><<<blocksPerGrid, GridDim::maxThreadsPerBlock, szMemSize>>>(t, startLbl, alpha.Data(), beta.Data(),
// note: kernel has hard-coded dimension of 1024
_rcrfTransGrdCompute1024Threads<ElemType> << <blocksPerGrid, GridDim::maxThreadsPerBlock, szMemSize >> >(t, startLbl, alpha.Data(), beta.Data(),
d_zeta, pair_scores.Data(), lbls.Data(), grd.Data(), iNumPos, iNumLab, shift);
}
TracingGPUMemoryAllocator::Free<ElemType>(alpha.GetComputeDeviceId(), d_zeta);
Expand Down
Loading

0 comments on commit 8f71698

Please sign in to comment.