Commit
Fix a strange issue in NCEDerivative function. I do not know why the implementation is wrong, but I use a new and correct one.
zhaoyukoon committed Jul 10, 2015
1 parent 4684e13 commit 6168015
Showing 3 changed files with 85 additions and 8 deletions.
6 changes: 4 additions & 2 deletions Math/Math/GPUMatrix.cu
@@ -1926,7 +1926,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         my_b.GetArray(),
         tmp.GetArray(),
         c.GetArray());
-
+
     if (do_sync) CUDA_CALL(cudaEventRecord(done));
     if (do_sync) CUDA_CALL(cudaEventSynchronize(done));
     if (do_sync) CUDA_CALL(cudaEventDestroy(done));
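The `if (do_sync)` trio after each launch is the event-based host-synchronization pattern this file uses: record an event on the stream behind the kernel, block the host until it fires, then release it. A minimal self-contained sketch of the same pattern (the no-op kernel and names are illustrative, not from the CNTK sources):

    #include <cuda_runtime.h>

    __global__ void noop() {}

    int main()
    {
        cudaEvent_t done;
        cudaEventCreate(&done);       // counterpart of the destroy call below
        noop<<<1, 1>>>();
        cudaEventRecord(done);        // enqueue the event behind the kernel launch
        cudaEventSynchronize(done);   // host blocks until the kernel has completed
        cudaEventDestroy(done);
        return 0;
    }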
@@ -1943,7 +1943,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     int p = 512;
     int width = a.GetNumRows();
     while (p / 2 > width) p = p / 2;
-    _assignNceDerivative<ElemType> << <this->GetNumElements() / 2, p >> >(
+
+    _assignNceDerivativeNew<ElemType> << < (tmp.GetNumElements() + p - 1) / p, p >> >(
         GetArray(),
         tmp.GetNumCols(),
         m_numRows / 2,
@@ -1953,6 +1954,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
         tmp.GetArray(),
         c.GetArray(),
         inputIndex);
+
     if (do_sync) CUDA_CALL(cudaEventRecord(done));
     if (do_sync) CUDA_CALL(cudaEventSynchronize(done));
     if (do_sync) CUDA_CALL(cudaEventDestroy(done));
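One visible difference in the fix above is the launch geometry: the old call launched `GetNumElements() / 2` blocks of `p` threads, a count unrelated to the number of work items in `tmp`, while the new call uses ceiling division so that every element of `tmp` gets exactly one range-guarded thread. A minimal sketch of that sizing pattern (illustrative kernel and data, not from the CNTK sources):

    #include <cuda_runtime.h>

    __global__ void scaleKernel(float* data, float s, int total)
    {
        int n = threadIdx.x + blockDim.x * blockIdx.x; // logical single index, as in the kernel above
        if (n < total)                                 // guard: the last block may overshoot
            data[n] *= s;
    }

    int main()
    {
        const int total = 1000;
        float* d = 0;
        cudaMalloc(&d, total * sizeof(float));
        cudaMemset(d, 0, total * sizeof(float));
        const int p = 512;                      // threads per block
        const int blocks = (total + p - 1) / p; // ceiling division: 2 blocks cover 1000 items
        scaleKernel<<<blocks, p>>>(d, 2.0f, total);
        cudaDeviceSynchronize();
        cudaFree(d);
        return 0;
    }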
81 changes: 80 additions & 1 deletion Math/Math/GPUMatrixCUDAKernels.cu
@@ -3117,11 +3117,90 @@ __global__ void _assignNceDerivative(
         }
         else //bias vector
         {
-            c[wid] -= er;
+            //ElemType val = -er;
+            atomicAdd(&c[wid], -er);
+            //c[wid] -= er;
         }
     }
 }
+template<class ElemType>
+__global__ void _assignNceDerivativeNew(
+    const ElemType* val,
+    int numRows,
+    int sampleCount,
+    const ElemType* a,
+    int width, // number of columns in a
+    const ElemType* b,
+    const ElemType* tmp,
+    ElemType* c,
+    size_t inputIndex)
+{
+    // val and col are a CSR-format sparse matrix for the labels
+    // val is an array containing log_Pn(w). To differentiate positive and negative samples,
+    // we store log_Pn(w) as-is for positive samples and -log_Pn(w) for negative samples
+    // col is an array containing the indices of the word samples
+    // a is a matrix in column-major format containing the output of the hidden layer
+    // b is the weight matrix of the output layer
+    // tmp is a matrix of precalculated errors
+    // c is the output matrix that stores the calculated gradients
+
+    // assume a 1-dimensional thread array
+    int tx = threadIdx.x;  // thread index in thread block (0-indexed)
+    int bx = blockIdx.x;   // block index (0-indexed)
+    int bdim = blockDim.x; // number of threads in a thread block
+
+    // logical single index for this thread
+    int n = tx + bdim * bx;
+
+    int batchId = n / sampleCount;
+    int total = numRows * sampleCount;
+    // is this thread in range for the update?
+    if (n < total)
+    {
+        int wid = (int)val[2 * n];
+        ElemType er = tmp[n];
+        //c[n] = a[n] + b[n]; // this thread does one addition
+        if (inputIndex == 1)
+        {
+            for (int j = 0; j < width; j++)
+            {
+                ElemType val = -er * b[IDX2C(j, wid, width)];
+                atomicAdd(&c[IDX2C(j, batchId, width)], val);
+            }
+        }
+        else if (inputIndex == 2)
+        {
+            for (int j = 0; j < width; j++)
+            {
+                ElemType val = -er * a[IDX2C(j, batchId, width)];
+                atomicAdd(&c[IDX2C(j, wid, width)], val);
+            }
+        }
+        else
+            atomicAdd(&c[wid], -er);
+    }
+    /*
+    int loadPerBlock = (total + gridDim.x - 1) / gridDim.x;
+    // find out the items this block is responsible for
+    int start = loadPerBlock * blockIdx.x;
+    int end = min(total, loadPerBlock * (blockIdx.x + 1));
+    for (int i = start; i < end; i++)
+    {
+        int wid = (int)val[2 * i];
+        ElemType er = tmp[i]; // precalculated error for this output node
+        if (inputIndex == 3) //bias vector
+        {
+            //ElemType val = -er;
+            atomicAdd(&c[wid], -er);
+            //c[wid] -= er;
+        }
+    }
+    */
+}
 // compute gradients of weights in cross entropy node
 template<class ElemType>
 __global__ void _computeGradientOfWeight(
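Both the one-line change in `_assignNceDerivative` and the scatter loops in `_assignNceDerivativeNew` replace plain `-=` updates on `c` with `atomicAdd`. Two samples in a minibatch can carry the same word id `wid`, so two threads may write the same elements of `c`; a plain read-modify-write loses updates under that collision, while `atomicAdd` serializes per address. A minimal sketch in the style of the inputIndex == 2 branch (illustrative names; `IDX2C` is assumed to be the usual cuBLAS-style column-major macro, whose definition is not shown in this diff):

    #include <cuda_runtime.h>

    // Assumed column-major index macro: element (row i, column j), leading dimension ld.
    #define IDX2C(i, j, ld) (((j) * (ld)) + (i))

    // One thread per sample; samples that share the same wid collide on column wid of c,
    // so the accumulation must be atomic.
    __global__ void scatterGrad(float* c, const int* wid, const float* er,
                                const float* a, int width, int total)
    {
        int n = threadIdx.x + blockDim.x * blockIdx.x;
        if (n < total)
        {
            for (int j = 0; j < width; j++)
                atomicAdd(&c[IDX2C(j, wid[n], width)], -er[n] * a[IDX2C(j, n, width)]);
        }
    }

    int main()
    {
        const int total = 4, width = 3, vocab = 2;
        int hWid[total] = { 0, 1, 1, 0 };              // deliberate collisions on word ids
        float hEr[total] = { 0.1f, 0.2f, 0.3f, 0.4f };
        float hA[width * total] = { 0 };
        int* dWid; float *dEr, *dA, *dC;
        cudaMalloc(&dWid, sizeof(hWid));
        cudaMalloc(&dEr, sizeof(hEr));
        cudaMalloc(&dA, sizeof(hA));
        cudaMalloc(&dC, width * vocab * sizeof(float));
        cudaMemcpy(dWid, hWid, sizeof(hWid), cudaMemcpyHostToDevice);
        cudaMemcpy(dEr, hEr, sizeof(hEr), cudaMemcpyHostToDevice);
        cudaMemcpy(dA, hA, sizeof(hA), cudaMemcpyHostToDevice);
        cudaMemset(dC, 0, width * vocab * sizeof(float));
        scatterGrad<<<1, 32>>>(dC, dWid, dEr, dA, width, total);
        cudaDeviceSynchronize();
        cudaFree(dWid); cudaFree(dEr); cudaFree(dA); cudaFree(dC);
        return 0;
    }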
6 changes: 1 addition & 5 deletions Math/Math/Matrix.cpp
@@ -11,7 +11,6 @@
 #include <assert.h>
 #include <math.h>
 #include "GPUWatcher.h" // bring in this class as well so that it gets exported from this DLL
-#include <iostream>
 
 #ifndef CPUONLY
 #pragma comment (lib, "CNTKMathCUDA.lib") // built by CNTKMathCUDA project
@@ -3716,11 +3715,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {
     if (a.IsEmpty() || b.IsEmpty() || c.IsEmpty())
         throw std::logic_error("AssignNoiseContrastiveEstimation: one of the input matrices is empty.");
 
-    if (a.GetDeviceId() != b.GetDeviceId() || b.GetDeviceId() != c.GetDeviceId() || c.GetDeviceId() != this->GetDeviceId())
-    {
-        std::cerr << a.GetDeviceId() << " " << b.GetDeviceId() << " " << c.GetDeviceId() << " " << this->GetDeviceId() << std::endl;
+    if (a.GetDeviceId() != b.GetDeviceId() || b.GetDeviceId() != c.GetDeviceId() || c.GetDeviceId() != this->GetDeviceId())
         NOT_IMPLEMENTED;
-    }
 
     this->Resize(1, 1);
 
