Commit 9da2729

update encoder_forward with the latest and greatest Packed128 and incorporate into train_gpt2cu

karpathy committed May 2, 2024
1 parent 8a9510f commit 9da2729

Showing 2 changed files with 91 additions and 59 deletions.
121 changes: 73 additions & 48 deletions dev/cuda/encoder_forward.cu
@@ -20,6 +20,22 @@ version 3 is like version 2 but uses float4 reads/writes
 #include "common.h"
 #include <cassert>
 
+// turn on bf16 as default, done up here for now
+#define ENABLE_BF16
+
+#if defined(ENABLE_BF16)
+typedef __nv_bfloat16 floatX;
+typedef __nv_bfloat16 floatN;
+#elif defined(ENABLE_FP16)
+typedef half floatX;
+typedef half floatN;
+#else
+typedef float floatX;
+typedef float floatN;
+#endif
+
+typedef Packed128<floatX> x128;
+
 // ----------------------------------------------------------------------------
 // CPU code reference
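Note: the x128 type above comes from common.h, not from this diff. As a minimal sketch of the idea, assuming the llm.c definition (abridged): a Packed128<ElementType> is a 16-byte-aligned bundle of as many elements as fit into 128 bits, so x128::size is 8 for bf16/fp16 and 4 for fp32.

    // sketch of Packed128 from common.h (an assumption for illustration, abridged; needs <cstring> for memcpy)
    template<class ElementType>
    struct alignas(16) Packed128 {
        static constexpr const size_t size = sizeof(int4) / sizeof(ElementType); // elements per 128 bits
        ElementType payload[size];
        __device__ Packed128() = default;
        __device__ explicit Packed128(int4 bits) { memcpy(&payload, &bits, sizeof(bits)); }
        __device__ ElementType& operator[](int index) { return payload[index]; }
        __device__ const ElementType& operator[](int index) const { return payload[index]; }
        __device__ int4 get_bits() const { int4 b; memcpy(&b, &payload, sizeof(b)); return b; }
    };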

@@ -44,28 +60,28 @@ void encoder_forward_cpu(float* out,
 // GPU kernels
 
 // naive implementation into kernel, parallelize over B,T, loop over C
-__global__ void encoder_forward_kernel1(float* out,
-                                        const int* inp, const float* wte, const float* wpe,
+__global__ void encoder_forward_kernel1(floatX* out,
+                                        const int* inp, const floatX* wte, const floatX* wpe,
                                         int B, int T, int C) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     int N = B * T;
 
     if (idx < N) {
         int b = idx / T;
         int t = idx % T;
-        float* out_bt = out + b * T * C + t * C;
+        floatX* out_bt = out + b * T * C + t * C;
         int ix = inp[b * T + t];
-        const float* wte_ix = wte + ix * C;
-        const float* wpe_t = wpe + t * C;
+        const floatX* wte_ix = wte + ix * C;
+        const floatX* wpe_t = wpe + t * C;
         for (int i = 0; i < C; i++) {
-            out_bt[i] = wte_ix[i] + wpe_t[i];
+            out_bt[i] = (floatX)((float)wte_ix[i] + (float)wpe_t[i]);
         }
     }
 }
 
 // optimized implementation: parallelize over all of B,T,C
-__global__ void encoder_forward_kernel2(float* out,
-                                        const int* inp, const float* wte, const float* wpe,
+__global__ void encoder_forward_kernel2(floatX* out,
+                                        const int* inp, const floatX* wte, const floatX* wpe,
                                         int B, int T, int C) {
     int idx = blockIdx.x * blockDim.x + threadIdx.x;
     int N = B * T * C;
@@ -78,40 +94,46 @@ __global__ void encoder_forward_kernel2(float* out,
 
         int ix = inp[b * T + t];
 
-        float* out_btc = out + b * T * C + t * C + c;
-        const float* wte_ix = wte + ix * C + c;
-        const float* wpe_tc = wpe + t * C + c;
-        *out_btc = *wte_ix + *wpe_tc;
+        floatX* out_btc = out + b * T * C + t * C + c;
+        const floatX* wte_ix = wte + ix * C + c;
+        const floatX* wpe_tc = wpe + t * C + c;
+        *out_btc = (floatX)((float)*wte_ix + (float)*wpe_tc);
     }
 }
 
-__device__ inline float4 add_float4(const float4& a, const float4& b) {
-    return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
-}
-
 // use of float4 leads to using 128-bit LDG / STG instructions in SASS,
 // very helpful in memory-bound kernels like encoder_forward
-__global__ void encoder_forward_kernel3(float4* out,
-                                        const int* inp, const float4* wte, const float4* wpe,
+__global__ void encoder_forward_kernel3(floatX* out,
+                                        const int* inp, const floatX* wte, const floatX* wpe,
                                         int B, int T, int C) {
-    int C4 = C / 4;
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
-    int N = B * T * C4;
+    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
+    int N = B * T * C;
     if (idx < N) {
-        int bt = idx / C4;
+        int bt = idx / C;
         int b = bt / T;
         int t = bt % T;
-        int c4 = idx % C4;
+        int c = idx % C;
 
         int ix = inp[b * T + t];
-        out[b * T * C4 + t * C4 + c4] = add_float4(wte[ix * C4 + c4], wpe[t * C4 + c4]);
+
+        floatX* out_btc = out + b * T * C + t * C + c;
+        const floatX* wte_ix = wte + ix * C + c;
+        const floatX* wpe_tc = wpe + t * C + c;
+
+        x128 packed_out;
+        x128 wte = load128cs(wte_ix);
+        x128 wpe = load128cs(wpe_tc);
+        #pragma unroll wte.size
+        for (int k = 0; k < wte.size; k++) {
+            packed_out[k] = (floatX)((float)wte[k] + (float)wpe[k]);
+        }
+        store128(out_btc, packed_out);
     }
 }
 
 // ----------------------------------------------------------------------------
 // kernel launcher
 
-void encoder_forward1(float* out,
-                      const int* inp, const float* wte, const float* wpe,
+void encoder_forward1(floatX* out,
+                      const int* inp, const floatX* wte, const floatX* wpe,
                       int B, int T, int C,
                       const int block_size) {
     const int N = B * T;
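Note: load128cs and store128 are also common.h helpers. Roughly, and again as a sketch assuming the llm.c definitions: they reinterpret the element pointer as an int4 so the compiler emits a single 128-bit LDG/STG, and the "cs" variant goes through the __ldcs streaming (evict-first) cache hint, which suits this kernel because the loaded values are not reused soon.

    // sketch of the common.h load/store helpers (assumed, for illustration)
    template<class ElementType>
    __device__ Packed128<ElementType> load128cs(const ElementType* address) {
        return Packed128<ElementType>{__ldcs(reinterpret_cast<const int4*>(address))};
    }
    template<class ElementType>
    __device__ void store128(ElementType* target, Packed128<ElementType> value) {
        *reinterpret_cast<int4*>(target) = value.get_bits();
    }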
@@ -120,8 +142,8 @@ void encoder_forward1(float* out,
     cudaCheck(cudaGetLastError());
 }
 
-void encoder_forward2(float* out,
-                      const int* inp, const float* wte, const float* wpe,
+void encoder_forward2(floatX* out,
+                      const int* inp, const floatX* wte, const floatX* wpe,
                       int B, int T, int C,
                       const int block_size) {
     const int N = B * T * C;
@@ -130,21 +152,20 @@ void encoder_forward2(float* out,
     cudaCheck(cudaGetLastError());
 }
 
-void encoder_forward3(float* out,
-                      const int* inp, const float* wte, const float* wpe,
+void encoder_forward3(floatX* out,
+                      const int* inp, const floatX* wte, const floatX* wpe,
                       int B, int T, int C,
                       const int block_size) {
-    assert(C % 4 == 0);
     const int N = B * T * C;
-    const int grid_size = ceil_div(N / 4, block_size);
-    encoder_forward_kernel3<<<grid_size, block_size>>>((float4*) out, inp, (float4*) wte, (float4*) wpe, B, T, C);
+    const int grid_size = ceil_div(N, (int)(block_size * x128::size));
+    encoder_forward_kernel3<<<grid_size, block_size>>>(out, inp, wte, wpe, B, T, C);
     cudaCheck(cudaGetLastError());
 }
 
 // kernel version dispatch
 void encoder_forward(int kernel_num,
-                     float* out,
-                     const int* inp, const float* wte, const float* wpe,
+                     floatX* out,
+                     const int* inp, const floatX* wte, const floatX* wpe,
                      int B, int T, int C,
                      const int block_size) {
     switch (kernel_num) {
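As a quick check on the new launch math: each thread now covers x128::size elements (16 bytes / sizeof(floatX), i.e. 8 in bf16 mode). With this file's defaults of B = 8, T = 1024 and, say, C = 768, N = 8 * 1024 * 768 = 6,291,456; a 256-thread block covers 256 * 8 = 2048 elements, so grid_size = ceil_div(6291456, 2048) = 3072 blocks, with no partial packet because C is a multiple of x128::size.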
@@ -166,7 +187,7 @@ void encoder_forward(int kernel_num,
 // ----------------------------------------------------------------------------
 
 int main(int argc, char **argv) {
-    srand(0);
+    setup_main();
 
     int B = 8;
     int T = 1024;
@@ -183,17 +204,17 @@ int main(int argc, char **argv) {
     float* wpe = make_random_float(T * C);
 
     // move to GPU
-    float* d_out;
+    floatX* d_out;
     int* d_inp;
-    float* d_wte;
-    float* d_wpe;
-    cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(float)));
+    floatX* d_wte;
+    floatX* d_wpe;
+    cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(floatX)));
     cudaCheck(cudaMalloc(&d_inp, B * T * sizeof(int)));
-    cudaCheck(cudaMalloc(&d_wte, V * C * sizeof(float)));
-    cudaCheck(cudaMalloc(&d_wpe, T * C * sizeof(float)));
+    cudaCheck(cudaMalloc(&d_wte, V * C * sizeof(floatX)));
+    cudaCheck(cudaMalloc(&d_wpe, T * C * sizeof(floatX)));
     cudaCheck(cudaMemcpy(d_inp, inp, B * T * sizeof(int), cudaMemcpyHostToDevice));
-    cudaCheck(cudaMemcpy(d_wte, wte, V * C * sizeof(float), cudaMemcpyHostToDevice));
-    cudaCheck(cudaMemcpy(d_wpe, wpe, T * C * sizeof(float), cudaMemcpyHostToDevice));
+    cudaCheck(memcpy_convert(d_wte, wte, V * C));
+    cudaCheck(memcpy_convert(d_wpe, wpe, T * C));
 
     // read kernel_num from command line
     int kernel_num = 2;
@@ -205,15 +226,19 @@ int main(int argc, char **argv) {
     // first check the correctness of the kernel
     encoder_forward_cpu(out, inp, wte, wpe, B, T, C);
 
-
     // time the kernel at different block sizes
     int block_sizes[] = {32, 64, 128, 256, 512, 1024};
 
     for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
         int block_size = block_sizes[j];
         printf("Checking block size %d.\n", block_size);
         encoder_forward(kernel_num, d_out, d_inp, d_wte, d_wpe, B, T, C, block_size);
-        validate_result(d_out, out, "out", B * T * C, 1e-5f);
+#if !defined(ENABLE_BF16) && !defined(ENABLE_FP16)
+        float tol = 1e-5;
+#else
+        float tol = 1e-2f;
+#endif
+        validate_result(d_out, out, "out", B * T * C, tol);
     }
 
     printf("All results match. Starting benchmarks.\n\n");
29 changes: 18 additions & 11 deletions train_gpt2.cu
@@ -732,12 +732,11 @@ void attention_backward_cudnn(floatX* dqkvr,
 // ----------------------------------------------------------------------------
 // all the kernels
 
-__global__ void encoder_forward_kernel2(floatX* out,
-                                        int* inp, floatX* wte, floatX* wpe,
+__global__ void encoder_forward_kernel3(floatX* out,
+                                        const int* inp, const floatX* wte, const floatX* wpe,
                                         int B, int T, int C) {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int idx = (blockIdx.x * blockDim.x + threadIdx.x) * x128::size;
     int N = B * T * C;
-
     if (idx < N) {
         int bt = idx / C;
         int b = bt / T;
@@ -747,9 +746,17 @@ __global__ void encoder_forward_kernel2(floatX* out,
         int ix = inp[b * T + t];
 
         floatX* out_btc = out + b * T * C + t * C + c;
-        floatX* wte_ix = wte + ix * C + c;
-        floatX* wpe_tc = wpe + t * C + c;
-        *out_btc = (floatX)((float)*wte_ix + (float)*wpe_tc);
+        const floatX* wte_ix = wte + ix * C + c;
+        const floatX* wpe_tc = wpe + t * C + c;
+
+        x128 packed_out;
+        x128 wte = load128cs(wte_ix);
+        x128 wpe = load128cs(wpe_tc);
+        #pragma unroll wte.size
+        for (int k = 0; k < wte.size; k++) {
+            packed_out[k] = (floatX)((float)wte[k] + (float)wpe[k]);
+        }
+        store128(out_btc, packed_out);
     }
 }
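Note in passing: the local x128 variables wte and wpe shadow the kernel's wte/wpe pointer parameters for the rest of the if block. The code is correct because the pointers are not needed after the loads, but a rename along these lines (a hypothetical tweak, not part of the commit) would avoid the shadowing:

    x128 packed_wte = load128cs(wte_ix);
    x128 packed_wpe = load128cs(wpe_tc);
    #pragma unroll
    for (int k = 0; k < x128::size; k++) {
        packed_out[k] = (floatX)((float)packed_wte[k] + (float)packed_wpe[k]);
    }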

@@ -1344,12 +1351,12 @@ __global__ void copy_and_cast_kernel(float* dst, const floatX* src, size_t n) {
 // kernel launchers
 
 void encoder_forward(floatX* out,
-                     int* inp, floatX* wte, floatX* wpe,
+                     const int* inp, const floatX* wte, const floatX* wpe,
                      int B, int T, int C) {
-    const int N = B * T * C;
     const int block_size = 256;
-    const int grid_size = CEIL_DIV(N, block_size);
-    encoder_forward_kernel2<<<grid_size, block_size>>>(out, inp, wte, wpe, B, T, C);
+    const int N = B * T * C;
+    const int grid_size = CEIL_DIV(N, (int)(block_size * x128::size));
+    encoder_forward_kernel3<<<grid_size, block_size>>>(out, inp, wte, wpe, B, T, C);
     cudaCheck(cudaGetLastError());
 }
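One subtlety of the vectorized kernel: each thread reads and writes a whole 128-bit packet, so correctness requires C to be a multiple of x128::size (8 in bf16 mode) and the tensors to be 16-byte aligned; GPT-2's channel sizes (768, 1024, and so on) satisfy this. The old assert(C % 4 == 0) was dropped along the way, so a guard along these lines (a hypothetical addition, not in the commit) would keep the assumption explicit:

    assert(C % x128::size == 0); // each thread handles one whole 128-bit packet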
