Skip to content

Commit

Permalink
faster bf16 & saves VRAM
Browse files Browse the repository at this point in the history
  • Loading branch information
www committed Feb 28, 2023
1 parent 93d671c commit 1d72a48
Show file tree
Hide file tree
Showing 3 changed files with 238 additions and 54 deletions.
132 changes: 132 additions & 0 deletions RWKV-v4neo/cuda/wkv_cuda_bf16.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"
#define MIN_VALUE (-1e38)
typedef at::BFloat16 bf16;

__global__ void kernel_forward(const int B, const int T, const int C,
const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
bf16 *__restrict__ const _y) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;

float u = float(_u[_c]);
float w = _w[_c];
const bf16 *__restrict__ const k = _k + _offset;
const bf16 *__restrict__ const v = _v + _offset;
bf16 *__restrict__ const y = _y + _offset;

// aa and bb are running sums divided by exp(pp) (to avoid overflow)
float aa = 0, bb = 0, pp = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
const float kk = float(k[ii]);
const float vv = float(v[ii]);

float ww = u + kk;
float p = max(pp, ww);
float e1 = exp(pp - p);
float e2 = exp(ww - p);
y[ii] = bf16((e1 * aa + e2 * vv) / (e1 * bb + e2));

ww = w + pp;
p = max(ww, kk);
e1 = exp(ww - p);
e2 = exp(kk - p);
aa = e1 * aa + e2 * vv;
bb = e1 * bb + e2;
pp = p;
}
}

__global__ void kernel_backward(const int B, const int T, const int C,
const float *__restrict__ const _w, const bf16 *__restrict__ const _u, const bf16 *__restrict__ const _k, const bf16 *__restrict__ const _v,
const bf16 *__restrict__ const _y, const bf16 *__restrict__ const _gy,
bf16 *__restrict__ const _gw, bf16 *__restrict__ const _gu, bf16 *__restrict__ const _gk, bf16 *__restrict__ const _gv) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int _b = idx / C;
const int _c = idx % C;
const int _offset = _b * T * C + _c;

float u = float(_u[_c]);
float w = _w[_c];
const bf16 *__restrict__ const k = _k + _offset;
const bf16 *__restrict__ const v = _v + _offset;
const bf16 *__restrict__ const y = _y + _offset;
const bf16 *__restrict__ const gy = _gy + _offset;
bf16 *__restrict__ const gk = _gk + _offset;
bf16 *__restrict__ const gv = _gv + _offset;

float q[Tmax], r[Tmax];

float gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
for (int i = 0; i < T; i++) {
const int ii = i * C;
const float kk = float(k[ii]);
const float vv = float(v[ii]);
const float yy = float(y[ii]);

float ww = u + kk;
float p = max(pp, ww);
float e1 = exp(pp - p);
float e2 = exp(ww - p);
const float qq = float(gy[ii]) / (e1 * bb + e2);
gw += (ga - gb * yy) * e1 * qq;
gu += (vv - yy) * e2 * qq;
q[i] = qq;
r[i] = ww - p;

ww = w + pp;
p = max(ww, kk);
e1 = exp(ww - p);
e2 = exp(kk - p);
ga = e1 * (aa + ga);
gb = e1 * (bb + gb);
aa = e1 * aa + e2 * vv;
bb = e1 * bb + e2;
pp = p;
}
const int _offsetBC = _b * C + _c;
_gw[_offsetBC] = bf16(gw * _w[_c]); // multiply by w because of w -> -exp(w) in python forward()
_gu[_offsetBC] = bf16(gu);

aa = 0, bb = 0, pp = MIN_VALUE;
for (int i = T - 1; i >= 0; i--) {
const int ii = i * C;
const float kk = float(k[ii]);
const float vv = float(v[ii]);
const float yy = float(y[ii]);
const float qq = q[i];
const float rr = r[i];

float e1 = qq * exp(rr);
float e2 = exp(kk + pp);
gk[ii] = bf16(e1 * (vv - yy) + e2 * (aa * vv + bb));
gv[ii] = bf16(e1 + e2 * aa);

const float ww = w + pp;
const float www = rr - u - kk;
const float p = max(ww, www);
e1 = exp(ww - p);
e2 = qq * exp(www - p);
aa = e1 * aa + e2;
bb = e1 * bb - e2 * yy;
pp = p;
}
}

void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
}

void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv) {
dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
assert(B * C % threadsPerBlock.x == 0);
dim3 numBlocks(B * C / threadsPerBlock.x);
kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
}
25 changes: 25 additions & 0 deletions RWKV-v4neo/cuda/wkv_op_bf16.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include <torch/extension.h>
#include "ATen/ATen.h"
typedef at::BFloat16 bf16;

void cuda_forward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y);
void cuda_backward(int B, int T, int C, float *w, bf16 *u, bf16 *k, bf16 *v, bf16 *y, bf16 *gy, bf16 *gw, bf16 *gu, bf16 *gk, bf16 *gv);

void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y,
torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), y.data_ptr<bf16>(),
gy.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>());
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "wkv forward");
m.def("backward", &backward, "wkv backward");
}

TORCH_LIBRARY(wkv, m) {
m.def("forward", forward);
m.def("backward", backward);
}
135 changes: 81 additions & 54 deletions RWKV-v4neo/src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,66 +41,93 @@ def __nop(ob):

from torch.utils.cpp_extension import load

wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])


class WKV(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, w, u, k, v):
ctx.B = B
ctx.T = T
ctx.C = C
assert T <= T_MAX
assert B * C % min(C, 32) == 0
if "32" in os.environ["RWKV_FLOAT_MODE"]:
w = -torch.exp(w.contiguous())
if os.environ["RWKV_FLOAT_MODE"] == "bf16":
wkv_cuda = load(name=f"wkv_{T_MAX}_bf16", sources=["cuda/wkv_op_bf16.cpp", "cuda/wkv_cuda_bf16.cu"], verbose=True, extra_cuda_cflags=["-lineinfo", "-std=c++17", "-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
class WKV(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, w, u, k, v):
ctx.B = B
ctx.T = T
ctx.C = C
assert T <= T_MAX
assert B * C % min(C, 32) == 0
w = -torch.exp(w.float().contiguous())
u = u.contiguous()
k = k.contiguous()
v = v.contiguous()
else:
w = -torch.exp(w.float().contiguous())
u = u.float().contiguous()
k = k.float().contiguous()
v = v.float().contiguous()
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
wkv_cuda.forward(B, T, C, w, u, k, v, y)
ctx.save_for_backward(w, u, k, v, y)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
wkv_cuda.forward(B, T, C, w, u, k, v, y)
ctx.save_for_backward(w, u, k, v, y)
return y
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
return y.half()
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
return y.bfloat16()

@staticmethod
def backward(ctx, gy):
B = ctx.B
T = ctx.T
C = ctx.C
assert T <= T_MAX
assert B * C % min(C, 32) == 0
w, u, k, v, y = ctx.saved_tensors
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
@staticmethod
def backward(ctx, gy):
B = ctx.B
T = ctx.T
C = ctx.C
assert T <= T_MAX
assert B * C % min(C, 32) == 0
w, u, k, v, y = ctx.saved_tensors
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format, dtype=torch.bfloat16)
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
else:
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
del w
del u
del k
del v
del y
gw = torch.sum(gw, dim=0)
gu = torch.sum(gu, dim=0)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
gw = torch.sum(gw, dim=0)
gu = torch.sum(gu, dim=0)
return (None, None, None, gw, gu, gk, gv)
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
else:
wkv_cuda = load(name=f"wkv_{T_MAX}", sources=["cuda/wkv_op.cpp", "cuda/wkv_cuda.cu"], verbose=True, extra_cuda_cflags=["-res-usage", "--maxrregcount 60", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-DTmax={T_MAX}"])
class WKV(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, w, u, k, v):
ctx.B = B
ctx.T = T
ctx.C = C
assert T <= T_MAX
assert B * C % min(C, 32) == 0
if "32" in os.environ["RWKV_FLOAT_MODE"]:
w = -torch.exp(w.contiguous())
u = u.contiguous()
k = k.contiguous()
v = v.contiguous()
else:
w = -torch.exp(w.float().contiguous())
u = u.float().contiguous()
k = k.float().contiguous()
v = v.float().contiguous()
y = torch.empty((B, T, C), device=w.device, memory_format=torch.contiguous_format)
wkv_cuda.forward(B, T, C, w, u, k, v, y)
ctx.save_for_backward(w, u, k, v, y)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
return y
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
return y.half()
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
return y.bfloat16()
@staticmethod
def backward(ctx, gy):
B = ctx.B
T = ctx.T
C = ctx.C
assert T <= T_MAX
assert B * C % min(C, 32) == 0
w, u, k, v, y = ctx.saved_tensors
gw = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
gu = torch.empty((B, C), device=gy.device, memory_format=torch.contiguous_format)
gk = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
gv = torch.empty((B, T, C), device=gy.device, memory_format=torch.contiguous_format)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.contiguous(), gw, gu, gk, gv)
else:
wkv_cuda.backward(B, T, C, w, u, k, v, y, gy.float().contiguous(), gw, gu, gk, gv)
gw = torch.sum(gw, dim=0)
gu = torch.sum(gu, dim=0)
if "32" in os.environ["RWKV_FLOAT_MODE"]:
return (None, None, None, gw, gu, gk, gv)
elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())


def RUN_CUDA(B, T, C, w, u, k, v):
Expand Down

0 comments on commit 1d72a48

Please sign in to comment.