Skip to content

Commit

Permalink
better cuda kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
www committed Feb 26, 2023
1 parent 760db55 commit 93d671c
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 75 deletions.
134 changes: 71 additions & 63 deletions RWKV-v4neo/cuda/wkv_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,33 @@ __global__ void kernel_forward(const int B, const int T, const int C,
const F *__restrict__ const v = _v + _offset;
F *__restrict__ const y = _y + _offset;

F p = 0, q = 0, o = MIN_VALUE;
// p and q are running sums divided by exp(o) (to avoid overflows)
// aa and bb are running sums divided by exp(pp) (to avoid overflow)
F aa = 0, bb = 0, pp = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;

F no = max(o, u + k[ii]);
F A = exp(o - no);
F B = exp(u + k[ii] - no);
y[ii] = (A * p + B * v[ii]) / (A * q + B);

no = max(w + o, k[ii]);
A = exp(w + o - no);
B = exp(k[ii] - no);
p = A * p + B * v[ii];
q = A * q + B;
o = no;
const F kk = k[ii];
const F vv = v[ii];

F ww = u + kk;
F p = max(pp, ww);
F e1 = exp(pp - p);
F e2 = exp(ww - p);
y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);

ww = w + pp;
p = max(ww, kk);
e1 = exp(ww - p);
e2 = exp(kk - p);
aa = e1 * aa + e2 * vv;
bb = e1 * bb + e2;
pp = p;
}
}

template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
const F *__restrict__ const _y, const F *__restrict__ const _gy,
F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
Expand All @@ -50,64 +55,67 @@ __global__ void kernel_backward(const int B, const int T, const int C,
F w = _w[_c];
const F *__restrict__ const k = _k + _offset;
const F *__restrict__ const v = _v + _offset;
const F *__restrict__ const y = _y + _offset;
const F *__restrict__ const gy = _gy + _offset;

F *__restrict__ const gk = _gk + _offset;
F *__restrict__ const gv = _gv + _offset;

F y[Tmax], z[Tmax], zexp[Tmax];
F q[Tmax], r[Tmax];

F gw = 0, gu = 0;
F p = 0, q = 0;
F dpdw = 0, dqdw = 0;
F o = MIN_VALUE;
F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
F no = max(o, k[ii] + u);
F A = exp(o - no);
F B = exp(k[ii] + u - no);

F num = A * p + B * v[ii];
F iden = 1 / (A * q + B);

y[i] = num * iden;
z[i] = iden;
zexp[i] = k[ii] + u - no;

gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
gu += gy[ii] * (v[ii] - y[i]) * B * iden;

no = max(w + o, k[ii]);
A = exp(w + o - no);
B = exp(k[ii] - no);
dpdw = A * (p + dpdw);
dqdw = A * (q + dqdw);
p = A * p + B * v[ii];
q = A * q + B;
o = no;
const F kk = k[ii];
const F vv = v[ii];
const F yy = y[ii];

F ww = u + kk;
F p = max(pp, ww);
F e1 = exp(pp - p);
F e2 = exp(ww - p);
const F qq = gy[ii] / (e1 * bb + e2);
gw += (ga - gb * yy) * e1 * qq;
gu += (vv - yy) * e2 * qq;
q[i] = qq;
r[i] = ww - p;

ww = w + pp;
p = max(ww, kk);
e1 = exp(ww - p);
e2 = exp(kk - p);
ga = e1 * (aa + ga);
gb = e1 * (bb + gb);
aa = e1 * aa + e2 * vv;
bb = e1 * bb + e2;
pp = p;
}
const int _offsetBC = _b * C + _c;
_gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
_gu[_offsetBC] = gu;

F gp = 0, gq = 0;
o = MIN_VALUE;
aa = 0, bb = 0, pp = MIN_VALUE;
for (int i = T - 1; i >= 0; i--) {
const int ii = i * C;
F A = gy[ii] * z[i] * exp(zexp[i]);
F B = exp(k[ii] + o);
gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
gv[ii] = A + B * gp;

F no = max(w + o, zexp[i] - k[ii] - u);
A = exp(w + o - no);
B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
gp = A * gp + B;
gq = A * gq - B * y[i];
o = no;
const F kk = k[ii];
const F vv = v[ii];
const F yy = y[ii];
const F qq = q[i];
const F rr = r[i];

F e1 = qq * exp(rr);
F e2 = exp(kk + pp);
gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
gv[ii] = e1 + e2 * aa;

const F ww = w + pp;
const F www = rr - u - kk;
const F p = max(ww, www);
e1 = exp(ww - p);
e2 = qq * exp(www - p);
aa = e1 * aa + e2;
bb = e1 * bb - e2 * yy;
pp = p;
}

// Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass
const int _offsetBC = _b * C + _c;
_gw[_offsetBC] += gw * _w[_c];
_gu[_offsetBC] += gu;
}

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
Expand All @@ -117,9 +125,9 @@ void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, f
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}

void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}
6 changes: 3 additions & 3 deletions RWKV-v4neo/cuda/wkv_op.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
#include <torch/extension.h>

void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);

void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
Expand Down
23 changes: 14 additions & 9 deletions RWKV-v4neo/src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __nop(ob):

from torch.utils.cpp_extension import load

wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", f"-DTmax={T_MAX}"])
wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])


class WKV(torch.autograd.Function):
Expand All @@ -62,9 +62,9 @@ def forward(ctx, B, T, C, w, u, k, v):
u = u.float().contiguous()
k = k.float().contiguous()
v = v.float().contiguous()
ctx.save_for_backward(w, u, k, v)
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
wkv_cuda.forward(B, T, C, w, u, k, v, y)
ctx.save_for_backward(w, u, k, v, y)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
return y
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
Expand All @@ -79,15 +79,20 @@ def backward(ctx, gy):
C = ctx.C
assert T <= T_MAX
assert B * C % min(C, 32) == 0
w, u, k, v = ctx.saved_tensors
gw = torch.zeros((B, C), device=gy.device).contiguous()
gu = torch.zeros((B, C), device=gy.device).contiguous()
gk = torch.zeros((B, T, C), device=gy.device).contiguous()
gv = torch.zeros((B, T, C), device=gy.device).contiguous()
w, u, k, v, y = ctx.saved_tensors
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
else:
wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
del w
del u
del k
del v
del y
gw = torch.sum(gw, dim=0)
gu = torch.sum(gu, dim=0)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
Expand Down

0 comments on commit 93d671c

Please sign in to comment.