From 93d671c2876e1d14fce185590dba562e304810d0 Mon Sep 17 00:00:00 2001
From: BlinkDL
Date: Sun, 26 Feb 2023 20:15:24 +0000
Subject: [PATCH] better cuda kernel

---
 RWKV-v4neo/cuda/wkv_cuda.cu | 134 +++++++++++++++++++-----------------
 RWKV-v4neo/cuda/wkv_op.cpp  |   6 +-
 RWKV-v4neo/src/model.py     |  23 ++++---
 3 files changed, 88 insertions(+), 75 deletions(-)

diff --git a/RWKV-v4neo/cuda/wkv_cuda.cu b/RWKV-v4neo/cuda/wkv_cuda.cu
index 6acd0f36..3d5dadbd 100644
--- a/RWKV-v4neo/cuda/wkv_cuda.cu
+++ b/RWKV-v4neo/cuda/wkv_cuda.cu
@@ -18,28 +18,33 @@ __global__ void kernel_forward(const int B, const int T, const int C,
     const F *__restrict__ const v = _v + _offset;
     F *__restrict__ const y = _y + _offset;
 
-    F p = 0, q = 0, o = MIN_VALUE;
-    // p and q are running sums divided by exp(o) (to avoid overflows)
+    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+    F aa = 0, bb = 0, pp = MIN_VALUE;
     for (int i = 0; i < T; i++) {
         const int ii = i * C;
-
-        F no = max(o, u + k[ii]);
-        F A = exp(o - no);
-        F B = exp(u + k[ii] - no);
-        y[ii] = (A * p + B * v[ii]) / (A * q + B);
-
-        no = max(w + o, k[ii]);
-        A = exp(w + o - no);
-        B = exp(k[ii] - no);
-        p = A * p + B * v[ii];
-        q = A * q + B;
-        o = no;
+        const F kk = k[ii];
+        const F vv = v[ii];
+
+        F ww = u + kk;
+        F p = max(pp, ww);
+        F e1 = exp(pp - p);
+        F e2 = exp(ww - p);
+        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
+
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
     }
 }
 
 template <typename F>
 __global__ void kernel_backward(const int B, const int T, const int C,
-                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
+                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
+                                const F *__restrict__ const _y, const F *__restrict__ const _gy,
                                 F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
     const int idx = blockIdx.x * blockDim.x + threadIdx.x;
     const int _b = idx / C;
@@ -50,64 +55,67 @@ __global__ void kernel_backward(const int B, const int T, const int C,
     F w = _w[_c];
     const F *__restrict__ const k = _k + _offset;
     const F *__restrict__ const v = _v + _offset;
+    const F *__restrict__ const y = _y + _offset;
     const F *__restrict__ const gy = _gy + _offset;
-
     F *__restrict__ const gk = _gk + _offset;
     F *__restrict__ const gv = _gv + _offset;
 
-    F y[Tmax], z[Tmax], zexp[Tmax];
+    F q[Tmax], r[Tmax];
 
-    F gw = 0, gu = 0;
-    F p = 0, q = 0;
-    F dpdw = 0, dqdw = 0;
-    F o = MIN_VALUE;
+    F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
     for (int i = 0; i < T; i++) {
         const int ii = i * C;
-        F no = max(o, k[ii] + u);
-        F A = exp(o - no);
-        F B = exp(k[ii] + u - no);
-
-        F num = A * p + B * v[ii];
-        F iden = 1 / (A * q + B);
-
-        y[i] = num * iden;
-        z[i] = iden;
-        zexp[i] = k[ii] + u - no;
-
-        gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
-        gu += gy[ii] * (v[ii] - y[i]) * B * iden;
-
-        no = max(w + o, k[ii]);
-        A = exp(w + o - no);
-        B = exp(k[ii] - no);
-        dpdw = A * (p + dpdw);
-        dqdw = A * (q + dqdw);
-        p = A * p + B * v[ii];
-        q = A * q + B;
-        o = no;
+        const F kk = k[ii];
+        const F vv = v[ii];
+        const F yy = y[ii];
+
+        F ww = u + kk;
+        F p = max(pp, ww);
+        F e1 = exp(pp - p);
+        F e2 = exp(ww - p);
+        const F qq = gy[ii] / (e1 * bb + e2);
+        gw += (ga - gb * yy) * e1 * qq;
+        gu += (vv - yy) * e2 * qq;
+        q[i] = qq;
+        r[i] = ww - p;
+
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        ga = e1 * (aa + ga);
+        gb = e1 * (bb + gb);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
     }
+    const int _offsetBC = _b * C + _c;
+    _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
+    _gu[_offsetBC] = gu;
 
-    F gp = 0, gq = 0;
-    o = MIN_VALUE;
+    aa = 0, bb = 0, pp = MIN_VALUE;
     for (int i = T - 1; i >= 0; i--) {
         const int ii = i * C;
-        F A = gy[ii] * z[i] * exp(zexp[i]);
-        F B = exp(k[ii] + o);
-        gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
-        gv[ii] = A + B * gp;
-
-        F no = max(w + o, zexp[i] - k[ii] - u);
-        A = exp(w + o - no);
-        B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
-        gp = A * gp + B;
-        gq = A * gq - B * y[i];
-        o = no;
+        const F kk = k[ii];
+        const F vv = v[ii];
+        const F yy = y[ii];
+        const F qq = q[i];
+        const F rr = r[i];
+
+        F e1 = qq * exp(rr);
+        F e2 = exp(kk + pp);
+        gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
+        gv[ii] = e1 + e2 * aa;
+
+        const F ww = w + pp;
+        const F www = rr - u - kk;
+        const F p = max(ww, www);
+        e1 = exp(ww - p);
+        e2 = qq * exp(www - p);
+        aa = e1 * aa + e2;
+        bb = e1 * bb - e2 * yy;
+        pp = p;
     }
-
-    // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass
-    const int _offsetBC = _b * C + _c;
-    _gw[_offsetBC] += gw * _w[_c];
-    _gu[_offsetBC] += gu;
 }
 
 void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
@@ -117,9 +125,9 @@ void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
 }
 
-void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
     dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
     assert(B * C % threadsPerBlock.x == 0);
     dim3 numBlocks(B * C / threadsPerBlock.x);
-    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
+    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
 }
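[Reviewer note, not part of the patch] The rewritten forward loop is a streaming log-sum-exp: aa and bb accumulate the numerator and denominator of the WKV average, both scaled by exp(-pp), where pp tracks the largest exponent seen so far, so exp() never receives a large positive argument. For checking the kernel against a ground truth, here is a float64 PyTorch sketch of the same recurrence; wkv_forward_ref and its signature are illustrative names, not something the repo provides:

    import torch

    def wkv_forward_ref(w, u, k, v):
        # One (batch, channel) slice per kernel thread; vectorized over C here.
        # w is the already-negated decay, i.e. -exp(w_param), exactly as the
        # python side passes it to the kernel.  Shapes: w, u (C,); k, v (T, C).
        T, C = k.shape
        aa = torch.zeros(C, dtype=torch.float64)           # sum of exp(k) * v, scaled by exp(-pp)
        bb = torch.zeros(C, dtype=torch.float64)           # sum of exp(k), scaled by exp(-pp)
        pp = torch.full((C,), -1e38, dtype=torch.float64)  # running max exponent, mirrors MIN_VALUE
        ys = []
        for t in range(T):
            ww = u + k[t]                       # the bonus u applies to the current token only
            p = torch.maximum(pp, ww)
            e1, e2 = torch.exp(pp - p), torch.exp(ww - p)
            ys.append((e1 * aa + e2 * v[t]) / (e1 * bb + e2))
            ww = w + pp                         # then decay the state (w < 0 after -exp)
            p = torch.maximum(ww, k[t])
            e1, e2 = torch.exp(ww - p), torch.exp(k[t] - p)
            aa = e1 * aa + e2 * v[t]
            bb = e1 * bb + e2
            pp = p
        return torch.stack(ys)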
diff --git a/RWKV-v4neo/cuda/wkv_op.cpp b/RWKV-v4neo/cuda/wkv_op.cpp
index efe56d8d..802021f0 100644
--- a/RWKV-v4neo/cuda/wkv_op.cpp
+++ b/RWKV-v4neo/cuda/wkv_op.cpp
@@ -1,13 +1,13 @@
 #include <torch/extension.h>
 
 void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
-void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
 
 void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
     cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
 }
-void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
-    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
+void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
+    cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
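[Reviewer note, not part of the patch] The binding change only threads the saved forward output y through to cuda_backward, which is what lets the kernel read y instead of recomputing it per thread (the local arrays shrink from y/z/zexp to q/r). One way to sanity-check the gradients it produces is to differentiate through the float64 reference above; the sizes here are arbitrary and wkv_forward_ref is the illustrative helper from the previous note:

    import torch

    T, C = 8, 4
    w0 = torch.randn(C, dtype=torch.float64, requires_grad=True)  # pre-exp decay parameter
    u = torch.randn(C, dtype=torch.float64, requires_grad=True)
    k = torch.randn(T, C, dtype=torch.float64, requires_grad=True)
    v = torch.randn(T, C, dtype=torch.float64, requires_grad=True)

    # The python side applies w -> -exp(w) before calling the kernel, so
    # differentiate through that mapping here as well.
    y = wkv_forward_ref(-torch.exp(w0), u, k, v)
    y.sum().backward()  # equivalent to gy = ones

    # w0.grad corresponds to the kernel's _gw output: autograd supplies the
    # d(-exp(w0))/dw0 = -exp(w0) factor that the kernel applies explicitly
    # via _gw[_offsetBC] = gw * _w[_c].  u.grad, k.grad, v.grad line up with
    # _gu, _gk, _gv (model.py additionally sums gw/gu over the batch).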
diff --git a/RWKV-v4neo/src/model.py b/RWKV-v4neo/src/model.py
index 67804c85..71ddc340 100644
--- a/RWKV-v4neo/src/model.py
+++ b/RWKV-v4neo/src/model.py
@@ -41,7 +41,7 @@ def __nop(ob):
 
 from torch.utils.cpp_extension import load
 
-wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", f"-DTmax={T_MAX}"])
+wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
 
 
 class WKV(torch.autograd.Function):
@@ -62,9 +62,9 @@ def forward(ctx, B, T, C, w, u, k, v):
             u = u.float().contiguous()
             k = k.float().contiguous()
             v = v.float().contiguous()
-        ctx.save_for_backward(w, u, k, v)
         y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
         wkv_cuda.forward(B, T, C, w, u, k, v, y)
+        ctx.save_for_backward(w, u, k, v, y)
         if "32" in os.environ["RWKV_FLOAT_MODE"]:
             return y
         elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
@@ -79,15 +79,20 @@ def backward(ctx, gy):
         C = ctx.C
         assert T <= T_MAX
        assert B * C % min(C, 32) == 0
-        w, u, k, v = ctx.saved_tensors
-        gw = torch.zeros((B, C), device=gy.device).contiguous()
-        gu = torch.zeros((B, C), device=gy.device).contiguous()
-        gk = torch.zeros((B, T, C), device=gy.device).contiguous()
-        gv = torch.zeros((B, T, C), device=gy.device).contiguous()
+        w, u, k, v, y = ctx.saved_tensors
+        gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
+        gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
+        gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
+        gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
         if "32" in os.environ["RWKV_FLOAT_MODE"]:
-            wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
+            wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
         else:
-            wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
+            wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
+        del w
+        del u
+        del k
+        del v
+        del y
         gw = torch.sum(gw, dim=0)
         gu = torch.sum(gu, dim=0)
         if "32" in os.environ["RWKV_FLOAT_MODE"]:
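[Reviewer note, not part of the patch] With y saved by forward() and reused by backward(), the gradient buffers can be torch.empty instead of torch.zeros: the kernel now assigns _gw/_gu once per (b, c) instead of accumulating into them, and gk/gv are fully written in the reverse loop. A minimal end-to-end smoke test follows; the environment-variable names and the import path are assumptions about how the repo is launched, so adjust them to your checkout (run from RWKV-v4neo/ with a CUDA device):

    import os
    os.environ["RWKV_FLOAT_MODE"] = "fp32"
    os.environ["RWKV_T_MAX"] = "256"   # assumed; must be set before src.model compiles the kernel

    import torch
    from src.model import WKV          # assumed import path

    B, T, C = 2, 16, 64
    w = torch.randn(C, device="cuda")  # raw decay; WKV.forward applies -exp(w) itself
    u = torch.randn(C, device="cuda")
    k = torch.randn(B, T, C, device="cuda", requires_grad=True)
    v = torch.randn(B, T, C, device="cuda", requires_grad=True)

    y = WKV.apply(B, T, C, w, u, k, v)
    y.sum().backward()                 # exercises the new backward that consumes the saved y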