Skip to content

Commit

Permalink
better cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
www committed Nov 23, 2024
1 parent 69bb585 commit bf4025e
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 17 deletions.
9 changes: 1 addition & 8 deletions RWKV-v7/cuda/wkv7.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,6 @@ __global__ void kernel_forward(const int B, const int T, const int C, const int
float state[_N_] = {0};
__shared__ float r[_N_], k[_N_], w[_N_], a[_N_], b[_N_];

float v[_T_];
for (int _t = 0; _t < T; _t++)
{
const int t = e*T*C + h*_N_ + i + _t * C;
v[_t] = float(_v[t]);
}

for (int _t = 0; _t < T; _t++)
{
const int t = e*T*C + h*_N_ + i + _t * C;
Expand All @@ -42,7 +35,7 @@ __global__ void kernel_forward(const int B, const int T, const int C, const int
sa += a[j] * state[j];
}

float vv = v[_t];
float vv = float(_v[t]);
float y = 0;
#pragma unroll
for (int j = 0; j < _N_; j++)
Expand Down
10 changes: 1 addition & 9 deletions RWKV-v7/rwkv_v7_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# model download: https://huggingface.co/BlinkDL/temp-latest-training-models/tree/main
MODEL_PATH = "/mnt/e/rwkv-x070-rc4a-172m-pile-20241120-ctx4k.pth"
args.n_layer = 12
args.ctx_len = 4096
args.n_embd = 768

args.vocab_size = 50304 # "pile" model: 50277 padded to 50304
Expand All @@ -36,11 +35,10 @@
args.head_size_divisor = 8 # don't change

from torch.utils.cpp_extension import load
T = args.ctx_len
HEAD_SIZE = args.head_size_a

load(name="wkv7", sources=["cuda/wkv7_op.cpp", f"cuda/wkv7.cu"], is_python_module=False,
verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}", f"-D_T_={T}"])
verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"])
class WKV_7(torch.autograd.Function):
@staticmethod
def forward(ctx, r, w, k, v, a, b):
Expand Down Expand Up @@ -225,14 +223,8 @@ def forward(self, x, v0):
class RWKV(nn.Module):
def __init__(self, args):
super().__init__()
self.args = args
args.dim_att = args.n_embd
args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)

assert args.n_embd % 32 == 0
assert args.dim_att % 32 == 0
assert args.dim_ffn % 32 == 0

self.emb = nn.Embedding(args.vocab_size, args.n_embd)

self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
Expand Down

0 comments on commit bf4025e

Please sign in to comment.