Commit: rwkv-7 rc4

www committed Nov 15, 2024
1 parent 3b6080c commit c453b42
Showing 2 changed files with 42 additions and 36 deletions.
4 changes: 2 additions & 2 deletions RWKV-v7/README.md
@@ -1,5 +1,5 @@
# RWKV-7 "Goose" x070.rc3-2409-2r7a-d1
# RWKV-7 "Goose" x070.rc4-2411 (WIP, final release soon)

Please try https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v7/rwkv_v7_demo.py

UPDATE: check https://github.com/BlinkDL/modded-nanogpt-rwkv for RWKV-7 CUDA kernel.
Please check https://github.com/BlinkDL/modded-nanogpt-rwkv for RWKV-7 efficient fwd & bwd CUDA kernel.
74 changes: 40 additions & 34 deletions RWKV-v7/rwkv_v7_demo.py
@@ -9,13 +9,13 @@
np.set_printoptions(precision=4, suppress=True, linewidth=200)

'''
This will load RWKV-7 "Goose" x070.rc3-2409-2r7a-d1 and inference in GPT-mode (slower than RNN-mode for autoregressive generation)
This will load RWKV-7 "Goose" x070.rc4-2411 and inference in GPT-mode (slower than RNN-mode for autoregressive generation)
'''

args = types.SimpleNamespace()

# model download: https://huggingface.co/BlinkDL/temp-latest-training-models/tree/main
MODEL_PATH = "/mnt/e/rwkv-x070-rc3-172m-pile-20241011-ctx4k.pth"
MODEL_PATH = "/mnt/e/rwkv-x070-rc4-172m-pile-20241115-ctx4k.pth"
args.n_layer = 12
args.ctx_len = 4096
args.n_embd = 768
@@ -98,6 +98,8 @@ def __init__(self, args, layer_id):
self.time_faaaa = nn.Parameter(torch.empty(self.n_head,self.head_size))
self.time_aaaaa = nn.Parameter(torch.empty(1,1,args.dim_att))

### TOO MANY LORAs HERE. I WILL REMOVE MOST OF THEM IN RWKV-7 FINAL :) ###

D_MIX_LORA = 32
self.time_maa_w1 = nn.Parameter(torch.empty(args.n_embd, D_MIX_LORA*6))
self.time_maa_w2 = nn.Parameter(torch.empty(6, D_MIX_LORA, args.n_embd))
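A side note for readers skimming the parameter soup: each `*_w1`/`*_w2` pair in this block (and the `mv_w1`/`mv_w2` pair added further down) is the same low-rank "LoRA-style" construction, a bottleneck down-projection followed by an up-projection that produces a small data-dependent offset. Below is a minimal sketch of that pattern with made-up sizes and initialization, not the repo's exact init (some branches also wrap the bottleneck in a tanh or sigmoid):

```python
import torch

B, T, n_embd, dim_att, D = 2, 16, 768, 768, 32    # illustrative; the 172M demo uses n_embd = dim_att = 768

x  = torch.randn(B, T, n_embd)
w1 = torch.randn(n_embd, D) * 0.01                 # down-projection into a rank-D bottleneck
w2 = torch.zeros(D, dim_att)                       # up-projection, zero so the offset starts at 0

delta = torch.tanh(x @ w1) @ w2                    # data-dependent low-rank offset, shape (B, T, dim_att)
```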
@@ -110,7 +112,7 @@ def __init__(self, args, layer_id):
self.time_aaa_w1 = nn.Parameter(torch.empty(args.n_embd, D_AAA_LORA))
self.time_aaa_w2 = nn.Parameter(torch.empty(D_AAA_LORA, args.dim_att))

D_KKK_LORA = 64
D_KKK_LORA = 32
self.time_kkk_w1 = nn.Parameter(torch.empty(args.n_embd, D_KKK_LORA))
self.time_kkk_w2 = nn.Parameter(torch.empty(D_KKK_LORA, args.dim_att))

@@ -124,9 +126,13 @@ def __init__(self, args, layer_id):
D_MA_LORA = 16
self.ma_w1 = nn.Parameter(torch.empty(args.n_embd, D_MA_LORA))
self.ma_w2 = nn.Parameter(torch.empty(D_MA_LORA, args.dim_att))
D_MV_LORA = 32
self.mv_w1 = nn.Parameter(torch.empty(args.n_embd, D_MV_LORA))
self.mv_w2 = nn.Parameter(torch.empty(D_MV_LORA, args.dim_att))

self.time_misc_k = nn.Parameter(torch.empty(1,1,args.n_embd))
self.time_misc_a = nn.Parameter(torch.empty(1,1,args.n_embd))
self.time_misc_v = nn.Parameter(torch.empty(1,1,args.n_embd))

self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
@@ -135,7 +141,7 @@ def __init__(self, args, layer_id):
self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
self.ln_x = nn.GroupNorm(self.n_head, args.dim_att, eps=(1e-5)*(args.head_size_divisor**2))

def forward(self, x):
def forward(self, x, v0):
B, T, C = x.size()
H = self.n_head
xx = self.time_shift(x) - x
@@ -156,7 +162,11 @@ def forward(self, x):
w = -F.softplus(-(self.time_decay + torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2)) - 0.5 # soft-clamp to (-inf, -0.5)
k = self.key(xk)
v = self.value(xv)
g = torch.tanh(xg @ self.gate_w1) @ self.gate_w2
if self.layer_id == 0:
v0 = v
else:
v = v + (v0 - v) * torch.sigmoid(self.time_misc_v + (xv @ self.mv_w1) @ self.mv_w2)
g = torch.sigmoid(xg @ self.gate_w1) @ self.gate_w2

kk = k + torch.tanh(xk @ self.time_kkk_w1) @ self.time_kkk_w2
kk = F.normalize(kk.view(B,T,H,-1), dim=-1, p=2.0).view(B,T,C)
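The main change in this hunk is value-residual mixing: layer 0 stashes its value projection as `v0`, and every later layer pulls its own `v` back toward that first-layer value with a learned, input-dependent sigmoid weight (the gate `g` also switches from `tanh` to `sigmoid`). A minimal sketch of just the mixing step, with illustrative sizes and random stand-in weights:

```python
import torch

B, T, C = 2, 16, 768                        # illustrative sizes
v  = torch.randn(B, T, C)                   # this layer's value projection
v0 = torch.randn(B, T, C)                   # layer 0's value projection, threaded through the stack
xv = torch.randn(B, T, C)                   # token-shift-mixed input of the value branch

time_misc_v = torch.zeros(1, 1, C)          # per-channel bias, matching the new parameter above
mv_w1 = torch.randn(C, 32) * 0.01           # D_MV_LORA = 32 bottleneck
mv_w2 = torch.zeros(32, C)

lam = torch.sigmoid(time_misc_v + (xv @ mv_w1) @ mv_w2)   # mixing weight in (0, 1)
v_mixed = v + (v0 - v) * lam                # same formula as the diff
```

Since `v + (v0 - v) * lam` equals `(1 - lam) * v + lam * v0`, each later layer learns a per-channel interpolation between its own values and the first layer's values.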
@@ -171,39 +181,33 @@ def forward(self, x):
x = self.ln_x(x.view(B * T, C)).view(B, T, C)

x = x + ((r.view(B,T,H,-1)*k.view(B,T,H,-1)*self.time_faaaa).sum(dim=-1, keepdim=True) * v.view(B,T,H,-1)).view(B,T,C)

x = self.output(x * g)
return x
return x, v0

########################################################################################################
# RWKV ChannelMix
########################################################################################################

class RWKV_CMix_x060(nn.Module):
class RWKV_CMix_x070(nn.Module):
def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

with torch.no_grad():
ddd = torch.empty(1, 1, args.n_embd)
self.time_maa_k = nn.Parameter(ddd)
self.time_maa_r = nn.Parameter(ddd)
self.time_maa_k = nn.Parameter(torch.empty(1, 1, args.n_embd))

self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
self.key = nn.Linear(args.n_embd, args.n_embd * 4, bias=False)
self.value = nn.Linear(args.n_embd * 4, args.n_embd, bias=False)

def forward(self, x):
xx = self.time_shift(x) - x
xk = x + xx * self.time_maa_k
xr = x + xx * self.time_maa_r

k = self.key(xk)
k = torch.relu(k) ** 2
kv = self.value(k)
return torch.sigmoid(self.receptance(xr)) * kv

k = x + xx * self.time_maa_k
k = torch.relu(self.key(k)) ** 2
return self.value(k)
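Collecting only the added lines: the rewritten channel-mix drops the receptance gate and the second token-shift mix, leaving one token-shift interpolation followed by a squared-ReLU MLP with a fixed 4x expansion (the configurable `dim_ffn` is gone). A rough consolidation for readability; the class name and zero init here are illustrative, and the real weights come from the checkpoint:

```python
import torch
import torch.nn as nn

class CMixSketch(nn.Module):
    """Rough consolidation of the RWKV_CMix_x070 added in this diff (init omitted)."""
    def __init__(self, n_embd):
        super().__init__()
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))              # shift so each token sees the previous one
        self.time_maa_k = nn.Parameter(torch.zeros(1, 1, n_embd))  # token-shift mixing weight
        self.key   = nn.Linear(n_embd, n_embd * 4, bias=False)     # fixed 4x expansion, dim_ffn removed
        self.value = nn.Linear(n_embd * 4, n_embd, bias=False)

    def forward(self, x):
        xx = self.time_shift(x) - x
        k = x + xx * self.time_maa_k
        k = torch.relu(self.key(k)) ** 2                           # squared-ReLU activation
        return self.value(k)                                       # no receptance/sigmoid gate anymore
```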

########################################################################################################
# RWKV Block
@@ -222,23 +226,24 @@ def __init__(self, args, layer_id):
self.ln0 = nn.LayerNorm(args.n_embd)

self.att = RWKV_Tmix_x070(args, layer_id)
self.ffn = RWKV_CMix_x060(args, layer_id)
self.ffn = RWKV_CMix_x070(args, layer_id)

def forward(self, x):
def forward(self, x, v0):

if self.layer_id == 0:
x = self.ln0(x)

x = x + self.att(self.ln1(x))
xx, v0 = self.att(self.ln1(x), v0)
x = x + xx
x = x + self.ffn(self.ln2(x))

if RESCALE_LAYER > 0:
if (self.layer_id+1) % RESCALE_LAYER == 0:
x = x / 2
# if RESCALE_LAYER > 0:
# if (self.layer_id+1) % RESCALE_LAYER == 0:
# x = x / 2
# if self.layer_id == args.n_layer-1:
# print(torch.min(x).item(), torch.max(x).item())

return x
return x, v0

########################################################################################################
# RWKV Model
@@ -266,8 +271,9 @@ def forward(self, idx):

x = self.emb(idx)

v0 = torch.empty_like(x)
for block in self.blocks:
x = block(x)
x, v0 = block(x, v0)

x = self.ln_out(x)
x = self.head(x)
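At the model level, the only plumbing change is that `v0` now travels alongside `x`: it is allocated once (uninitialized, since the first block overwrites it before reading it) and passed into and out of every block. A tiny sketch of the new calling convention, using a stand-in block rather than the real one:

```python
import torch
import torch.nn as nn

class DummyBlock(nn.Module):
    """Stand-in for the real Block: same (x, v0) -> (x, v0) signature, no computation."""
    def forward(self, x, v0):
        return x, v0

blocks = nn.ModuleList([DummyBlock() for _ in range(12)])   # n_layer = 12 in the demo config

x = torch.randn(1, 8, 768)
v0 = torch.empty_like(x)       # placeholder; layer 0 replaces it with its value projection
for block in blocks:
    x, v0 = block(x, v0)       # v0 threads through every layer, as in the updated forward()
```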
@@ -285,11 +291,11 @@ def forward(self, idx):

if '.time_faaaa' in k: model_params[k] = model_params[k].reshape(-1, args.head_size_a)

if RESCALE_LAYER > 0:
if 'att.output.weight' in k:
model_params[k] = model_params[k] / (2 ** int(layer_id // RESCALE_LAYER))
if 'ffn.value.weight' in k:
model_params[k] = model_params[k] / (2 ** int(layer_id // RESCALE_LAYER))
# if RESCALE_LAYER > 0:
# if 'att.output.weight' in k:
# model_params[k] = model_params[k] / (2 ** int(layer_id // RESCALE_LAYER))
# if 'ffn.value.weight' in k:
# model_params[k] = model_params[k] / (2 ** int(layer_id // RESCALE_LAYER))

with torch.no_grad():
