bias backward kernel that will use all available threads
ngc92 committed May 2, 2024
1 parent 6a3d63f commit 6e48501
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions dev/cuda/matmul_backward_bias.cu
@@ -166,6 +166,19 @@ __global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int
    }
}

__global__ void matmul_backward_bias_kernel5(float* dbias, const float* dout, int B, int T, int OC) {
    int oc = blockIdx.x * blockDim.x + threadIdx.x;
    if (oc >= OC) return;
    float sum = 0.0f;
    // grid-wide loop for maximum parallelism
    for (int i = blockIdx.y; i < B * T; i += gridDim.y) {
        sum += dout[i * OC + oc];
    }
    // and atomically add everything together. atomics within one block are conflict-free!
    atomicAdd(dbias + oc, sum);
}


// ----------------------------------------------------------------------------
// kernel launcher

@@ -202,6 +215,14 @@ void matmul_backward_bias4(float* dinp, float* dweight, float* dbias,
    matmul_backward_bias_kernel4<<<grid_size, block_size, block_size * sizeof(float)>>>(dbias, dout, B, T, OC);
}

void matmul_backward_bias5(float* dinp, float* dweight, float* dbias,
                           float* dout, float* inp, float* weight, float* ones,
                           int B, int T, int C, int OC, int block_size) {
    const int grid_size_x = ceil_div(OC, block_size);
    const int grid_size_y = max(1, cuda_threads_per_SM * cuda_num_SMs / block_size);
    matmul_backward_bias_kernel5<<<dim3(grid_size_x, grid_size_y), dim3(block_size)>>>(dbias, dout, B, T, OC);
}

void matmul_backward_bias(int kernel_num,
                          float* dinp, float* dweight, float* dbias,
                          float* dout, float* inp, float* weight, float* ones,
@@ -219,6 +240,9 @@ void matmul_backward_bias(int kernel_num,
        case 4:
            matmul_backward_bias4(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size);
            break;
        case 5:
            matmul_backward_bias5(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
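
To try the new kernel, a plausible invocation (the exact compile flags are an assumption; the authoritative command lives in this file's header comment):

// build, then select kernel 5 via the first command-line argument:
//   nvcc -O3 matmul_backward_bias.cu -o matmul_backward_bias
//   ./matmul_backward_bias 5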
@@ -228,20 +252,13 @@
// ----------------------------------------------------------------------------

int main(int argc, char **argv) {
-    srand(0);
+    setup_main();

    int B = 8;
    int T = 1024;
    int C = 768;
    int OC = 768 * 4; // expansion of 4, e.g. in the MLP

-    // set up the device
-    int deviceIdx = 0;
-    cudaCheck(cudaSetDevice(deviceIdx));
-    cudaDeviceProp deviceProp;
-    cudaGetDeviceProperties(&deviceProp, deviceIdx);
-    printf("Device %d: %s\n", deviceIdx, deviceProp.name);

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
@@ -280,7 +297,7 @@ int main(int argc, char **argv) {
    // memset the bias to zero
    cudaCheck(cudaMemset(d_dbias, 0, OC * sizeof(float)));
    // calculate the GPU version
-    matmul_backward_bias(kernel_num, NULL, NULL, d_dbias, d_dout, NULL, NULL, NULL, B, T, C, OC, 128);
+    matmul_backward_bias(kernel_num, NULL, NULL, d_dbias, d_dout, NULL, NULL, NULL, B, T, C, OC, block_size);
    // compare
    printf("Checking correctness...\n");
    validate_result(d_dbias, dbias, "dbias", OC, 5e-3f);
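
Note that with kernel 5 the order in which the per-block partial sums hit dbias via atomicAdd is nondeterministic, so the result can differ from the CPU reference in the low-order bits from run to run; a small tolerance such as the 5e-3f above absorbs this floating-point reordering.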
