with non-cuda rwkv7 kernel
www committed Nov 24, 2024
1 parent c5f9f15 commit ce7be6e
Showing 1 changed file with 69 additions and 38 deletions.
107 changes: 69 additions & 38 deletions RWKV-v7/rwkv_v7_demo.py
@@ -16,55 +16,86 @@

# model download: https://huggingface.co/BlinkDL/temp-latest-training-models/tree/main
MODEL_PATH = "/mnt/e/rwkv-x070-rc4a-172m-pile-20241120-ctx4k.pth"

args.n_layer = 12
args.n_embd = 768

args.vocab_size = 50304 # "pile" model: 50277 padded to 50304
from tokenizers import Tokenizer
tokenizer = Tokenizer.from_file("../RWKV-v4neo/20B_tokenizer.json")
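# e.g. tokenizer.encode("Hello").ids returns the token id list and tokenizer.decode(ids) maps ids back to text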

# DTYPE = torch.bfloat16
DTYPE = torch.half
RESCALE_LAYER = -1

########################################################################################################
# CUDA Kernel
########################################################################################################
DTYPE = torch.half # better

RESCALE_LAYER = -1 # not used here
args.head_size_a = 64 # don't change
args.head_size_divisor = 8 # don't change
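# note: with n_embd = 768 and head_size_a = 64, the model runs n_embd // head_size_a = 768 // 64 = 12 heads per layer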

from torch.utils.cpp_extension import load
HEAD_SIZE = args.head_size_a

load(name="wkv7", sources=["cuda/wkv7_op.cpp", f"cuda/wkv7.cu"], is_python_module=False,
verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"])
class WKV_7(torch.autograd.Function):
@staticmethod
def forward(ctx, r, w, k, v, a, b):
with torch.no_grad():
B, T, C = r.size()
H = C // HEAD_SIZE
N = HEAD_SIZE
assert HEAD_SIZE == C // H
assert r.dtype == DTYPE
assert w.dtype == DTYPE
assert k.dtype == DTYPE
assert v.dtype == DTYPE
assert a.dtype == DTYPE
assert b.dtype == DTYPE
assert r.is_contiguous()
assert w.is_contiguous()
assert k.is_contiguous()
assert v.is_contiguous()
assert a.is_contiguous()
assert b.is_contiguous()
y = torch.empty((B, T, C), device=k.device, dtype=DTYPE, memory_format=torch.contiguous_format)
torch.ops.wkv7.forward(B, T, C, H, r, w, k, v, a, b, y)
return y

def RUN_CUDA_RWKV7(r, w, k, v, a, b):
return WKV_7.apply(r, w, k, v, a, b)
USE_CUDA_KERNEL = True # False => UNOPTIMIZED, VERY SLOW
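# when False, the pure-PyTorch RWKV7_OP below is used; it avoids compiling the CUDA extension but runs the recurrence step by step in Python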

########################################################################################################
# CUDA Kernel
########################################################################################################

if USE_CUDA_KERNEL:

    from torch.utils.cpp_extension import load

    load(name="wkv7", sources=["cuda/wkv7_op.cpp", f"cuda/wkv7.cu"], is_python_module=False,
         verbose=True, extra_cuda_cflags=["-res-usage", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization", f"-D_N_={HEAD_SIZE}"])
    class WKV_7(torch.autograd.Function):
        @staticmethod
        def forward(ctx, r, w, k, v, a, b):
            with torch.no_grad():
                B, T, C = r.size()
                H = C // HEAD_SIZE
                N = HEAD_SIZE
                assert HEAD_SIZE == C // H
                assert r.dtype == DTYPE
                assert w.dtype == DTYPE
                assert k.dtype == DTYPE
                assert v.dtype == DTYPE
                assert a.dtype == DTYPE
                assert b.dtype == DTYPE
                assert r.is_contiguous()
                assert w.is_contiguous()
                assert k.is_contiguous()
                assert v.is_contiguous()
                assert a.is_contiguous()
                assert b.is_contiguous()
                y = torch.empty((B, T, C), device=k.device, dtype=DTYPE, memory_format=torch.contiguous_format)
                torch.ops.wkv7.forward(B, T, C, H, r, w, k, v, a, b, y)
                return y

    def RWKV7_OP(r, w, k, v, a, b):
        return WKV_7.apply(r, w, k, v, a, b)

else:

    def RWKV7_OP(r, w, k, v, a, b):
        B, T, C = r.size()
        H = C // HEAD_SIZE
        N = HEAD_SIZE
        r = r.view(B, T, H, N).float()
        k = k.view(B, T, H, N).float()
        v = v.view(B, T, H, N).float()
        a = a.view(B, T, H, N).float()
        b = b.view(B, T, H, N).float()
        w = torch.exp(-torch.exp(w.view(B, T, H, N).float()))
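        # note: exp(-exp(w)) maps any real-valued w into (0, 1), so w behaves as a per-channel decay applied to the state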
        out = torch.zeros((B, T, H, N), device=r.device, dtype=torch.float)
        state = torch.zeros((B, H, N, N), device=r.device, dtype=torch.float)

        for t in range(T):
            kk = k[:, t, :]
            rr = r[:, t, :]
            vv = v[:, t, :]
            aa = a[:, t, :]
            bb = b[:, t, :]
            sab = torch.einsum('bhik,bhk,bhj->bhij', state, aa, bb)
            state = state * w[:, t, :, None, :] + sab + torch.einsum('bhj,bhi->bhij', kk, vv)
            out[:, t, :] = torch.einsum('bhj,bhij->bhi', rr, state)

        return out.view(B, T, C).to(dtype=DTYPE)

########################################################################################################
# RWKV TimeMix
@@ -147,7 +178,7 @@ def forward(self, x, v0):
        kk = F.normalize(kk.view(B,T,H,-1), dim=-1, p=2.0).view(B,T,C)
        k = k * (1 + (a-1) * self.time_misc_a)

        x = RUN_CUDA_RWKV7(r, w, k, v, -kk, kk*a)
        x = RWKV7_OP(r, w, k, v, -kk, kk*a)

        x = self.ln_x(x.view(B * T, C)).view(B, T, C)

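For quick sanity checking outside the demo, the following standalone sketch copies the pure-PyTorch fallback from this commit into a small script (the function name rwkv7_naive and the toy sizes are assumptions made here, not part of the commit) and runs it on random inputs to illustrate the (B, T, C) -> (B, T, C) shape contract with C = H * HEAD_SIZE.

import torch

# toy sizes for illustration only; the demo itself uses head_size_a = 64 and n_embd = 768
HEAD_SIZE = 4
B, T, H = 2, 5, 3
C = H * HEAD_SIZE

def rwkv7_naive(r, w, k, v, a, b):
    # same recurrence as the non-CUDA RWKV7_OP branch above, kept entirely in float32
    B, T, C = r.size()
    H = C // HEAD_SIZE
    N = HEAD_SIZE
    r = r.view(B, T, H, N).float()
    k = k.view(B, T, H, N).float()
    v = v.view(B, T, H, N).float()
    a = a.view(B, T, H, N).float()
    b = b.view(B, T, H, N).float()
    w = torch.exp(-torch.exp(w.view(B, T, H, N).float()))  # decay in (0, 1)
    out = torch.zeros((B, T, H, N), dtype=torch.float)
    state = torch.zeros((B, H, N, N), dtype=torch.float)
    for t in range(T):
        kk, rr, vv, aa, bb = k[:, t], r[:, t], v[:, t], a[:, t], b[:, t]
        sab = torch.einsum('bhik,bhk,bhj->bhij', state, aa, bb)
        state = state * w[:, t, :, None, :] + sab + torch.einsum('bhj,bhi->bhij', kk, vv)
        out[:, t] = torch.einsum('bhj,bhij->bhi', rr, state)
    return out.view(B, T, C)

r, w, k, v, a, b = (torch.randn(B, T, C) for _ in range(6))
print(rwkv7_naive(r, w, k, v, a, b).shape)  # torch.Size([2, 5, 12])

With the real checkpoint, flipping USE_CUDA_KERNEL and comparing the two paths' outputs (e.g. via torch.allclose with a loose tolerance) is a reasonable way to confirm they agree.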
