Update residual_forward.cu
JaneIllario committed May 2, 2024
1 parent 4dd1ab4 commit 4ffcf5b
Showing 1 changed file with 74 additions and 14 deletions.
dev/cuda/residual_forward.cu: 88 changes (74 additions & 14 deletions)
@@ -6,13 +6,30 @@ nvcc -O3 --use_fast_math residual_forward.cu -o residual_forward
version 1 is naive port from CPU code to kernel
./residual_forward 1
version 2 packs input into 128 bit memory reads
./residual_forward 2
*/
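// note on version 2 (editor's sizing note): one 128-bit load moves 16 bytes, i.e. 8 bf16/fp16
// elements or 4 fp32 elements per x128 access. the packed kernel is gated on ENABLE_BF16 below;
// one way to enable it without editing the #define is to pass it at compile time, e.g.
// nvcc -O3 --use_fast_math -DENABLE_BF16 residual_forward.cu -o residual_forward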

#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>
#include "common.h"

// uncomment to turn on bf16 as the default, done up here for now
//#define ENABLE_BF16

#if defined(ENABLE_BF16)
typedef __nv_bfloat16 floatX;
typedef __nv_bfloat16 floatN;
#elif defined(ENABLE_FP16)
typedef half floatX;
typedef half floatN;
#else
typedef float floatX;
typedef float floatN;
#endif

typedef Packed128<floatX> x128;
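// Packed128 comes from common.h: x128::size is the number of floatX elements that fit in one
// 128-bit (16-byte) access, elements are indexed with operator[], and load128cs/store128 move a
// whole x128 at once. Roughly (an illustrative sketch only, not the actual common.h definition):
//
//   template<class ElementType>
//   struct Packed128 {
//       static constexpr size_t size = sizeof(int4) / sizeof(ElementType); // elements per 128 bits
//       ElementType payload[size];
//       __device__ ElementType& operator[](int index) { return payload[index]; }
//   };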
// ----------------------------------------------------------------------------
// CPU code reference lol

@@ -26,33 +43,60 @@ void residual_forward_cpu(float* out, const float* inp1, const float* inp2, int
// GPU kernels

// elementwise ops are nice and ez
-__global__ void residual_forward_kernel(float* out, const float* inp1, const float* inp2, int N) {
+__global__ void residual_forward_kernel1(float* out, const float* inp1, const float* inp2, int N) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx < N) {
out[idx] = inp1[idx] + inp2[idx];
}
}

__global__ void residual_forward_kernel2(floatX* out, const floatX* inp1, const floatX* inp2, int N) {
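// each thread handles x128::size consecutive elements via single 128-bit loads/stores;
// load128cs presumably applies a streaming cache hint since inp1/inp2 are read only once,
// and the kernel assumes N is a multiple of x128::size (there is no tail handling).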
int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
if (idx < N) {
x128 packed_out;
x128 packed_inp1 = load128cs(inp1 + idx);
x128 packed_inp2 = load128cs(inp2 + idx);
for (int k = 0; k < packed_inp1.size; ++k)
{
packed_out[k] = (floatX)((float)packed_inp1[k] + (float)packed_inp2[k]);
}
store128(out + idx, packed_out);
}
}

// ----------------------------------------------------------------------------
// kernel launcher

void residual_forward1(float* out, const float* inp1, const float* inp2, int N, const int block_size) {
const int grid_size = ceil_div(N, block_size);
-residual_forward_kernel<<<grid_size, block_size>>>(out, inp1, inp2, N);
+residual_forward_kernel1<<<grid_size, block_size>>>(out, inp1, inp2, N);
cudaCheck(cudaGetLastError());
}

void residual_forward2(floatX* out, const floatX* inp1, const floatX* inp2, int N, const int block_size) {
const int grid_size = ceil_div(N, (int)(block_size * x128::size)); // each thread covers x128::size elements
residual_forward_kernel2<<<grid_size, block_size>>>(out, inp1, inp2, N);
cudaCheck(cudaGetLastError());
}

// kernel version dispatch
void residual_forward(int kernel_num,
-float* out,
-const float* inp1,
-const float* inp2,
+floatX* out,
+const floatX* inp1,
+const floatX* inp2,
int N,
int block_size) {
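// note: kernel 1 operates on plain float buffers, so it is only compiled when floatX == float;
// kernel 2 is the packed variant and is only wired up for the BF16 build here.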
switch (kernel_num) {
#if !defined(ENABLE_BF16) && !defined(ENABLE_FP16)
case 1:
residual_forward1(out, inp1, inp2, N, block_size);
break;
#endif
#if defined(ENABLE_BF16)
case 2:
residual_forward2(out, inp1, inp2, N, block_size);
break;
#endif
default:
printf("Invalid kernel number\n");
exit(1);
Expand All @@ -76,15 +120,24 @@ int main(int argc, char **argv) {
float* inp1 = make_random_float(B * T * C);
float* inp2 = make_random_float(B * T * C);

// create X host memory of random numbers
floatX* inp1X = (floatX*)malloc(B * T * C * sizeof(floatX));
floatX* inp2X = (floatX*)malloc(B * T * C * sizeof(floatX));

for (int i = 0; i < B * T * C; i++) {
inp1X[i] = (floatX)inp1[i];
inp2X[i] = (floatX)inp2[i];
}
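// note: rounding the fp32 inputs down to bf16/fp16 loses precision, which is why the
// reduced-precision check below uses a looser tolerance.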

// move to GPU
-float* d_out;
-float* d_inp1;
-float* d_inp2;
-cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(float)));
-cudaCheck(cudaMalloc(&d_inp1, B * T * C * sizeof(float)));
-cudaCheck(cudaMalloc(&d_inp2, B * T * C * sizeof(float)));
-cudaCheck(cudaMemcpy(d_inp1, inp1, B * T * C * sizeof(float), cudaMemcpyHostToDevice));
-cudaCheck(cudaMemcpy(d_inp2, inp2, B * T * C * sizeof(float), cudaMemcpyHostToDevice));
+floatX* d_out;
+floatX* d_inp1;
+floatX* d_inp2;
+cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(floatX)));
+cudaCheck(cudaMalloc(&d_inp1, B * T * C * sizeof(floatX)));
+cudaCheck(cudaMalloc(&d_inp2, B * T * C * sizeof(floatX)));
+cudaCheck(cudaMemcpy(d_inp1, inp1X, B * T * C * sizeof(floatX), cudaMemcpyHostToDevice));
+cudaCheck(cudaMemcpy(d_inp2, inp2X, B * T * C * sizeof(floatX), cudaMemcpyHostToDevice));

// read kernel_num from command line
int kernel_num = 1;
@@ -104,7 +157,12 @@ int main(int argc, char **argv) {
int block_size = block_sizes[j];
printf("Checking block size %d.\n", block_size);
residual_forward(kernel_num, d_out, d_inp1, d_inp2, B * T * C, block_size);
#if !defined(ENABLE_BF16) && !defined(ENABLE_FP16)
validate_result(d_out, out, "out", B * T * C, 1e-5f);
#else
// reduced precision needs a much looser tolerance than fp32
validate_result(d_out, out, "out", B * T * C, 1e-2f);
#endif
}

printf("All results match. Starting benchmarks.\n\n");
@@ -130,9 +188,11 @@ int main(int argc, char **argv) {
free(out);
free(inp1);
free(inp2);
free(inp1X);
free(inp2X);
cudaCheck(cudaFree(d_out));
cudaCheck(cudaFree(d_inp1));
cudaCheck(cudaFree(d_inp2));

return 0;
}
