Commit 653ae70: misc
www committed Sep 13, 2023
1 parent 552cd66 commit 653ae70
Showing 3 changed files with 25 additions and 26 deletions.
6 changes: 3 additions & 3 deletions RWKV-v4neo/cuda/wkv5_cuda.cu
@@ -44,7 +44,7 @@ __global__ void kernel_forward(const int B, const int T, const int C, const int
template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C, const int H,
const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy,
- F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, float *__restrict__ _gw, float *__restrict__ _gu)
+ F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
{
const int b = blockIdx.x / H;
const int h = blockIdx.x % H;
@@ -105,7 +105,7 @@ __global__ void kernel_backward(const int B, const int T, const int C, const int
gw += r * ww * saaaa[j] * gy[j];
}
}
- _gw[_t] = gw;
+ _gw[_t] = F(gw);
}

#pragma unroll
@@ -167,7 +167,7 @@ void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *
kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);
}

- void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, float *gw, float *gu)
+ void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
{
assert(H*_N_ == C);
kernel_backward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu);
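About the kernel change above: the added cast in `_gw[_t] = F(gw)` suggests the per-thread accumulator `gw` stays in float and is only converted to the template type `F` (bf16 in this build) at write-out, which is what lets the output pointers switch from `float *` to `F *`. A minimal PyTorch sketch of that accumulate-in-float32, cast-on-store pattern, for illustration only (not the kernel itself):

import torch

# Illustration only: sum contributions in float32, cast to bfloat16 once at the
# final store, mirroring `_gw[_t] = F(gw)` in kernel_backward above.
contribs = torch.randn(4096, dtype=torch.float32)   # stand-in per-timestep terms
acc = contribs.sum()                                 # float32 accumulation
gw_out = acc.to(torch.bfloat16)                      # single cast at the store
print(float(acc), float(gw_out))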
4 changes: 2 additions & 2 deletions RWKV-v4neo/cuda/wkv5_op.cpp
@@ -3,13 +3,13 @@
typedef at::BFloat16 bf16;

void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
- void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, float *gw, float *gu);
+ void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);

void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
}
void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
- cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<float>(), gu.data_ptr<float>());
+ cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "wkv5 forward");
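For context, the operator above is normally compiled and loaded from Python with torch.utils.cpp_extension, and after this commit the caller has to hand the binding bfloat16 buffers for gw and gu. A hedged sketch of that pattern follows; the source paths, build flags, and the -D_N_ head-size define are illustrative assumptions rather than the repo's exact values.

import torch
from torch.utils.cpp_extension import load

HEAD_SIZE = 64  # assumed head size; the kernel is compiled for a fixed _N_
wkv5_cuda = load(
    name="wkv5",
    sources=["cuda/wkv5_op.cpp", "cuda/wkv5_cuda.cu"],      # paths assumed
    verbose=True,
    extra_cuda_cflags=["-O3", "--use_fast_math", f"-D_N_={HEAD_SIZE}"],
)

# With the new signature, gw/gu must be bfloat16; a float32 tensor here would be
# misread through data_ptr<bf16>() on the C++ side.
B, T, C = 2, 16, 128
gw = torch.empty((B, T, C), device="cuda", dtype=torch.bfloat16)
gu = torch.empty((B, T, C), device="cuda", dtype=torch.bfloat16)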
41 changes: 20 additions & 21 deletions RWKV-v4neo/src/model.py
@@ -320,49 +320,47 @@ class WKV_5(torch.autograd.Function):
@staticmethod
def forward(ctx, B, T, C, H, r, k, v, w, u):
with torch.no_grad():
- assert HEAD_SIZE == C // H
assert r.dtype == torch.bfloat16
assert k.dtype == torch.bfloat16
assert v.dtype == torch.bfloat16
assert w.dtype == torch.bfloat16
assert u.dtype == torch.bfloat16
+ assert HEAD_SIZE == C // H
ctx.B = B
ctx.T = T
ctx.C = C
ctx.H = H
- r = r.contiguous()
- k = k.contiguous()
- v = v.contiguous()
- w = w.float().contiguous()
- u = u.contiguous()
- ew = -torch.exp(w)
- eew = torch.exp(ew)
+ assert r.is_contiguous()
+ assert k.is_contiguous()
+ assert v.is_contiguous()
+ assert w.is_contiguous()
+ assert u.is_contiguous()
+ ew = (-torch.exp(w.float())).contiguous()
+ eew = (torch.exp(ew)).contiguous()
ctx.save_for_backward(r, k, v, eew, ew, u)
- y = torch.empty((B, T, C), device=w.device, dtype=torch.bfloat16, memory_format=torch.contiguous_format)
- wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y)
- return y
+ y = torch.empty((B, T, C), device=r.device, dtype=torch.bfloat16, memory_format=torch.contiguous_format) # .uniform_(-1, 1)
+ wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y)
+ return y

@staticmethod
def backward(ctx, gy):
with torch.no_grad():
- assert gy.dtype == torch.bfloat16
B = ctx.B
T = ctx.T
C = ctx.C
H = ctx.H
- gy = gy.contiguous()
+ assert gy.dtype == torch.bfloat16
+ assert gy.is_contiguous()
r, k, v, eew, ew, u = ctx.saved_tensors
- gr = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)
- gk = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)
- gv = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format)
- gw = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.float, memory_format=torch.contiguous_format)
- gu = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.float, memory_format=torch.contiguous_format)

+ gr = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format) # .uniform_(-1, 1)
+ gk = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format) # .uniform_(-1, 1)
+ gv = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format) # .uniform_(-1, 1)
+ gw = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format) # .uniform_(-1, 1)
+ gu = torch.empty((B, T, C), device=gy.device, requires_grad=False, dtype=torch.bfloat16, memory_format=torch.contiguous_format) # .uniform_(-1, 1)
wkv5_cuda.backward(B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu)

gw = torch.sum(gw.view(B*T, H, C//H), 0)
gu = torch.sum(gu.view(B*T, H, C//H), 0)
- return (None, None, None, None, gr, gk, gv, gw.bfloat16(), gu.bfloat16())
+ return (None, None, None, None, gr, gk, gv, gw, gu)

def RUN_CUDA(B, T, C, H, r, k, v, w, u):
return WKV_5.apply(B, T, C, H, r, k, v, w, u)
@@ -376,6 +374,7 @@ def __init__(self, args, layer_id):
self.layer_id = layer_id

self.head_size = args.head_size_a
+ assert HEAD_SIZE == self.head_size # change HEAD_SIZE to match args.head_size_a
self.n_head = args.dim_att // self.head_size
assert args.dim_att % self.n_head == 0
self.head_size_divisor = args.head_size_divisor
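A note on the model.py diff above: the kernel now returns gw and gu as per-token bfloat16 buffers of shape (B, T, C), the Python side reduces them to per-head parameter gradients of shape (H, C//H), and the explicit .bfloat16() casts in the return tuple are therefore gone. A small shape-and-dtype sketch, with the sizes picked only for illustration:

import torch

B, T, C, H = 2, 16, 64, 4                      # assumed sizes for the sketch
gw = torch.randn(B, T, C).to(torch.bfloat16)   # stand-in for the kernel output (already bf16)
gw_param = torch.sum(gw.view(B * T, H, C // H), 0)
print(gw_param.shape, gw_param.dtype)          # torch.Size([4, 16]) torch.bfloat16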
