Commit 02bc67a
better names
www committed Nov 28, 2024
1 parent 7384042 commit 02bc67a
Showing 1 changed file with 43 additions and 43 deletions.
RWKV-v7/rwkv_v7_demo.py (86 changes: 43 additions & 43 deletions)
@@ -128,32 +128,32 @@ def __init__(self, args, layer_id):
         N = self.head_size
         C = args.n_embd

-        self.xx_r = nn.Parameter(torch.empty(1,1,C))
-        self.xx_w = nn.Parameter(torch.empty(1,1,C))
-        self.xx_k = nn.Parameter(torch.empty(1,1,C))
-        self.xx_v = nn.Parameter(torch.empty(1,1,C))
-        self.xx_a = nn.Parameter(torch.empty(1,1,C))
-        self.xx_g = nn.Parameter(torch.empty(1,1,C))
+        self.x_r = nn.Parameter(torch.empty(1,1,C))
+        self.x_w = nn.Parameter(torch.empty(1,1,C))
+        self.x_k = nn.Parameter(torch.empty(1,1,C))
+        self.x_v = nn.Parameter(torch.empty(1,1,C))
+        self.x_a = nn.Parameter(torch.empty(1,1,C))
+        self.x_g = nn.Parameter(torch.empty(1,1,C))

-        self.ww_b = nn.Parameter(torch.empty(1,1,C))
-        self.ww_w1 = nn.Parameter(torch.empty(C, D_DECAY_LORA))
-        self.ww_w2 = nn.Parameter(torch.empty(D_DECAY_LORA, C))
+        self.w0 = nn.Parameter(torch.empty(1,1,C))
+        self.w1 = nn.Parameter(torch.empty(C, D_DECAY_LORA))
+        self.w2 = nn.Parameter(torch.empty(D_DECAY_LORA, C))

-        self.aa_b = nn.Parameter(torch.empty(1,1,C))
-        self.aa_w1 = nn.Parameter(torch.empty(C, D_AAA_LORA))
-        self.aa_w2 = nn.Parameter(torch.empty(D_AAA_LORA, C))
+        self.a0 = nn.Parameter(torch.empty(1,1,C))
+        self.a1 = nn.Parameter(torch.empty(C, D_AAA_LORA))
+        self.a2 = nn.Parameter(torch.empty(D_AAA_LORA, C))

         if layer_id > 0:
-            self.vv_b = nn.Parameter(torch.empty(1,1,C))
-            self.vv_w1 = nn.Parameter(torch.empty(C, D_MV_LORA))
-            self.vv_w2 = nn.Parameter(torch.empty(D_MV_LORA, C))
+            self.v0 = nn.Parameter(torch.empty(1,1,C))
+            self.v1 = nn.Parameter(torch.empty(C, D_MV_LORA))
+            self.v2 = nn.Parameter(torch.empty(D_MV_LORA, C))

-        self.gg_w1 = nn.Parameter(torch.empty(C, D_GATE_LORA))
-        self.gg_w2 = nn.Parameter(torch.empty(D_GATE_LORA, C))
+        self.g1 = nn.Parameter(torch.empty(C, D_GATE_LORA))
+        self.g2 = nn.Parameter(torch.empty(D_GATE_LORA, C))

-        self.kk_s = nn.Parameter(torch.empty(1,1,C))
-        self.ka_s = nn.Parameter(torch.empty(1,1,C))
-        self.rk_s = nn.Parameter(torch.empty(H,N))
+        self.k_k = nn.Parameter(torch.empty(1,1,C))
+        self.k_a = nn.Parameter(torch.empty(1,1,C))
+        self.r_k = nn.Parameter(torch.empty(H,N))

         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
         self.receptance = nn.Linear(C, C, bias=False)
@@ -162,40 +162,40 @@ def __init__(self, args, layer_id):
         self.output = nn.Linear(C, C, bias=False)
         self.ln_x = nn.GroupNorm(H, C, eps=64e-5) # !!! notice eps value !!!

-    def forward(self, x, v0):
+    def forward(self, x, v_first):
         B, T, C = x.size()
         H = self.n_head
         xx = self.time_shift(x) - x
-        xr = x + xx * self.xx_r
-        xw = x + xx * self.xx_w
-        xk = x + xx * self.xx_k
-        xv = x + xx * self.xx_v
-        xa = x + xx * self.xx_a
-        xg = x + xx * self.xx_g
+        xr = x + xx * self.x_r
+        xw = x + xx * self.x_w
+        xk = x + xx * self.x_k
+        xv = x + xx * self.x_v
+        xa = x + xx * self.x_a
+        xg = x + xx * self.x_g

         r = self.receptance(xr)
-        w = -F.softplus(-(self.ww_b + torch.tanh(xw @ self.ww_w1) @ self.ww_w2)) - 0.5 # soft-clamp to (-inf, -0.5)
+        w = -F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5 # soft-clamp to (-inf, -0.5)
         k = self.key(xk)
         v = self.value(xv)
         if self.layer_id == 0:
-            v0 = v
+            v_first = v
         else:
-            v = v + (v0 - v) * torch.sigmoid(self.vv_b + (xv @ self.vv_w1) @ self.vv_w2)
-        a = torch.sigmoid(self.aa_b + (xa @ self.aa_w1) @ self.aa_w2) # a is "in-context learning rate"
-        g = torch.sigmoid(xg @ self.gg_w1) @ self.gg_w2
+            v = v + (v_first - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2)
+        a = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # a is "in-context learning rate"
+        g = torch.sigmoid(xg @ self.g1) @ self.g2

-        kk = k * self.kk_s
+        kk = k * self.k_k
         kk = F.normalize(kk.view(B,T,H,-1), dim=-1, p=2.0).view(B,T,C)
-        k = k * (1 + (a-1) * self.ka_s)
+        k = k * (1 + (a-1) * self.k_a)

         x = RWKV7_OP(r, w, k, v, -kk, kk*a)

         x = self.ln_x(x.view(B * T, C)).view(B, T, C)

-        x = x + ((r.view(B,T,H,-1)*k.view(B,T,H,-1)*self.rk_s).sum(dim=-1, keepdim=True) * v.view(B,T,H,-1)).view(B,T,C)
+        x = x + ((r.view(B,T,H,-1)*k.view(B,T,H,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(B,T,H,-1)).view(B,T,C)

         x = self.output(x * g)
-        return x, v0
+        return x, v_first

 ########################################################################################################
 # RWKV ChannelMix
@@ -209,15 +209,15 @@ def __init__(self, args, layer_id):
         self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

         with torch.no_grad():
-            self.xx_k = nn.Parameter(torch.empty(1, 1, args.n_embd))
+            self.x_k = nn.Parameter(torch.empty(1, 1, args.n_embd))

         self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
         self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

     def forward(self, x):
         xx = self.time_shift(x) - x

-        k = x + xx * self.xx_k
+        k = x + xx * self.x_k
         k = torch.relu(self.key(k)) ** 2
         return self.value(k)

@@ -240,16 +240,16 @@ def __init__(self, args, layer_id):
         self.att = RWKV_Tmix_x070(args, layer_id)
         self.ffn = RWKV_CMix_x070(args, layer_id)

-    def forward(self, x, v0):
+    def forward(self, x, v_first):

         if self.layer_id == 0:
             x = self.ln0(x)

-        xx, v0 = self.att(self.ln1(x), v0)
+        xx, v_first = self.att(self.ln1(x), v_first)
         x = x + xx
         x = x + self.ffn(self.ln2(x))

-        return x, v0
+        return x, v_first

 ########################################################################################################
 # RWKV Model
@@ -271,9 +271,9 @@ def forward(self, idx):

         x = self.emb(idx)

-        v0 = torch.empty_like(x)
+        v_first = torch.empty_like(x)
         for block in self.blocks:
-            x, v0 = block(x, v0)
+            x, v_first = block(x, v_first)

         x = self.ln_out(x)
         x = self.head(x)
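A note on what the renamed decay parameters do: w0/w1/w2 (formerly ww_b/ww_w1/ww_w2) are a bias plus a LoRA pair that produce the per-channel log-decay, and the surrounding -F.softplus(-x) - 0.5 soft-clamps it to (-inf, -0.5). A standalone sketch of just that clamp, with made-up tensor sizes (B, T, C, D_DECAY_LORA below are illustrative, not the demo's defaults):

import torch
import torch.nn.functional as F

B, T, C, D_DECAY_LORA = 2, 5, 8, 4       # illustrative sizes only

xw = torch.randn(B, T, C)                # token-shifted input for the decay branch
w0 = torch.zeros(1, 1, C)                # bias (the renamed ww_b)
w1 = torch.randn(C, D_DECAY_LORA) * 0.1  # LoRA down-projection (the renamed ww_w1)
w2 = torch.randn(D_DECAY_LORA, C) * 0.1  # LoRA up-projection (the renamed ww_w2)

# -softplus(-x) is a smooth version of min(x, 0), so w always lands in (-inf, -0.5):
w = -F.softplus(-(w0 + torch.tanh(xw @ w1) @ w2)) - 0.5
assert w.max().item() < -0.5

# Read as a log-decay, the effective decay factor exp(w) stays in (0, exp(-0.5)) ≈ (0, 0.607).
print(w.exp().max().item())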

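The v0 → v_first rename is the most useful one: the tensor threaded through every block is the first layer's value, and each later layer lerps its own value toward it with a learned sigmoid gate (whose bias is now, somewhat confusingly, called v0). A toy version of that interpolation, again with made-up shapes:

import torch

B, T, C, D_MV_LORA = 2, 5, 8, 4   # illustrative sizes only

v       = torch.randn(B, T, C)   # this layer's value
v_first = torch.randn(B, T, C)   # value produced by layer 0
xv      = torch.randn(B, T, C)   # token-shifted input for the value branch

v0 = torch.zeros(1, 1, C)             # gate bias (the renamed vv_b)
v1 = torch.randn(C, D_MV_LORA) * 0.1  # renamed vv_w1
v2 = torch.randn(D_MV_LORA, C) * 0.1  # renamed vv_w2

gate = torch.sigmoid(v0 + (xv @ v1) @ v2)   # per-channel gate in (0, 1)
v_mixed = v + (v_first - v) * gate          # gate=0 keeps v, gate=1 gives v_first

# Same thing written as an explicit interpolation:
assert torch.allclose(v_mixed, (1 - gate) * v + gate * v_first, atol=1e-6)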
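Likewise, k_k (old kk_s) scales a copy of the key that is L2-normalized per head before being handed to RWKV7_OP as the removal/replacement direction (-kk, kk*a), while k_a (old ka_s) lets the "in-context learning rate" a modulate the key itself. A shape-only sketch of those two lines, with illustrative sizes:

import torch
import torch.nn.functional as F

B, T, H, N = 2, 5, 4, 8   # batch, time, heads, head size (illustrative)
C = H * N

k   = torch.randn(B, T, C)
k_k = torch.randn(1, 1, C)                 # renamed kk_s
k_a = torch.randn(1, 1, C)                 # renamed ka_s
a   = torch.sigmoid(torch.randn(B, T, C))  # stand-in for the learned gate a

kk = k * k_k
kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C)

# Each head's slice of kk is now (approximately) a unit vector:
norms = kk.view(B, T, H, N).norm(dim=-1)
assert torch.allclose(norms, torch.ones_like(norms), atol=1e-4)

k = k * (1 + (a - 1) * k_a)   # k_a blends per channel between k (k_a=0) and a*k (k_a=1)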