Skip to content

Commit

Permalink
RWKV-7 "Goose" x070.rc2-2409-2r7a-b0b4a
Browse files Browse the repository at this point in the history
  • Loading branch information
www committed Sep 23, 2024
1 parent ff72f84 commit e1d143f
Show file tree
Hide file tree
Showing 5 changed files with 5,566 additions and 1 deletion.
2 changes: 1 addition & 1 deletion RWKV-v7/README.md
Original file line number Diff line number Diff line change
@@ -1 +1 @@
.
# RWKV-7 "Goose" x070.rc2-2409-2r7a-b0b4a
62 changes: 62 additions & 0 deletions RWKV-v7/cuda/wkv7.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#include <stdio.h>
#include <assert.h>
#include "ATen/ATen.h"

typedef at::Half bf16;
// typedef at::BFloat16 bf16;

template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C, const int H,
const F *__restrict__ const _r, const F *__restrict__ const _w, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _a, const F *__restrict__ const _b,
F *__restrict__ const _y)
{
const int e = blockIdx.x / H;
const int h = blockIdx.x % H;
const int i = threadIdx.x;

float state[_N_] = {0};
__shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];

float v[_T_];
for (int _t = 0; _t < T; _t++)
{
const int t = e*T*C + h*_N_ + i + _t * C;
v[_t] = float(_v[t]);
}

for (int _t = 0; _t < T; _t++)
{
const int t = e*T*C + h*_N_ + i + _t * C;
__syncthreads();
r[i] = float(_r[t]);
w[i] = __expf(-__expf(float(_w[t])));
k[i] = float(_k[t]);
a[i] = float(_a[t]);
b[i] = float(_b[t]);
__syncthreads();

float sa = 0;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
sa += a[j] * state[j];
}

float vv = v[_t];
float y = 0;
#pragma unroll
for (int j = 0; j < _N_; j++)
{
float& s = state[j];
s = s * w[j] + k[j] * vv + sa * b[j];
y += s * r[j];
}
_y[t] = F(y);
}
}

void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16* w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y)
{
assert(H*_N_ == C);
kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, w, k, v, a, b, y);
}
15 changes: 15 additions & 0 deletions RWKV-v7/cuda/wkv7_op.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include <torch/extension.h>
#include "ATen/ATen.h"

typedef at::Half bf16;
// typedef at::BFloat16 bf16;

void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *w, bf16 *k, bf16 *v, bf16 *a, bf16 *b, bf16 *y);

void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &w, torch::Tensor &k, torch::Tensor &v, torch::Tensor &a, torch::Tensor &b, torch::Tensor &y) {
cuda_forward(B, T, C, H, r.data_ptr<bf16>(), w.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), a.data_ptr<bf16>(), b.data_ptr<bf16>(), y.data_ptr<bf16>());
}

TORCH_LIBRARY(wkv7, m) {
m.def("forward", forward);
}
Loading

0 comments on commit e1d143f

Please sign in to comment.