
Commit

test
www committed Feb 12, 2023
1 parent c7b1900 commit e2ec7ae
Showing 3 changed files with 42 additions and 41 deletions.
RWKV-v4neo/src/dataset.py: 4 changes (2 additions, 2 deletions)
@@ -32,7 +32,7 @@ def __init__(self, args):
self.data_size = len(self.data._bin_buffer) // 2
rank_zero_info(f"Data has {self.data_size} tokens.")

- if args.my_qa_mask == 1:
+ if args.my_qa_mask > 0:
self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document')
self.data_pile_size = len(self.data_pile._bin_buffer) // 2

@@ -156,7 +156,7 @@ def worker_seed():
if args.my_pile_stage > 0:
ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank

- if args.my_qa_mask == 1:
+ if args.my_qa_mask > 0:
ii_orig = ii
if ii % 2 == 0:
ii = (ii // 2) * args.magic_prime
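Note on the sampling trick above: stepping through the data with a prime stride is a standard way to get a deterministic full permutation. A minimal sketch of the underlying fact, not the repository's exact sampling code (n here is a hypothetical sample count; the trick only requires gcd(n, magic_prime) == 1):

import math

# If magic_prime is coprime with n, the map i -> (i * magic_prime) % n
# visits every index in [0, n) exactly once, giving a cheap,
# reproducible shuffle with no stored permutation table.
def permuted_indices(n: int, magic_prime: int):
    assert math.gcd(n, magic_prime) == 1
    return [(i * magic_prime) % n for i in range(n)]

assert sorted(permuted_indices(10, 7)) == list(range(10))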
RWKV-v4neo/src/model.py: 53 changes (18 additions, 35 deletions)
@@ -108,12 +108,14 @@ def __init__(self, args, layer_id):
self.layer_id = layer_id
self.ctx_len = args.ctx_len
self.n_embd = args.n_embd
- self.my_testing = self.args.my_testing

with torch.no_grad(): # fancy init
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0

+ ddd = torch.ones(1, 1, args.n_embd)
+ for i in range(args.n_embd):
+     ddd[0, 0, i] = i / args.n_embd

# fancy time_decay
decay_speed = torch.ones(args.dim_att)
for h in range(args.dim_att):
@@ -126,12 +128,9 @@ def __init__(self, args, layer_id):
self.time_first = nn.Parameter(torch.ones(args.dim_att) * math.log(0.3) + zigzag)

# fancy time_mix
- x = torch.ones(1, 1, args.n_embd)
- for i in range(args.n_embd):
-     x[0, 0, i] = i / args.n_embd
- self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
- self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
- self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
+ self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+ self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+ self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))

self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
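Note on the ddd initialization this commit factors out: ddd is a per-channel ramp from 0 to 1, and the layer-dependent exponent controls how much each channel mixes in the shifted token. A self-contained sketch of the pattern (n_embd and n_layer values are hypothetical; the formula matches the code above):

import torch

n_embd, n_layer = 512, 12  # hypothetical sizes
for layer_id in range(n_layer):
    ratio_1_to_almost0 = 1.0 - (layer_id / n_layer)  # 1 at layer 0, ~0 at the top
    ddd = (torch.arange(n_embd) / n_embd).reshape(1, 1, n_embd)  # same ramp as the loop above
    time_mix_k = torch.pow(ddd, ratio_1_to_almost0)
    # used as xk = x * time_mix_k + x_shifted * (1 - time_mix_k): as the
    # exponent shrinks with depth, time_mix_k approaches 1, so deeper
    # layers lean more on the current token and less on the shifted one.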
@@ -147,24 +146,17 @@ def __init__(self, args, layer_id):
self.vv = nn.Linear(args.n_embd, d_qkv, bias=False)
self.oo = nn.Linear(d_qkv, args.n_embd, bias=False)
with torch.no_grad():
- x = torch.ones(1, 1, args.n_embd)
- for i in range(args.n_embd):
-     x[0, 0, i] = i / args.n_embd
- self.time_mix_qq = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
- self.time_mix_kk = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
- self.time_mix_vv = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
+ self.time_mix_qq = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+ self.time_mix_kk = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+ self.time_mix_vv = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)

if 'a' not in os.environ["RWKV_MY_TESTING"]:
@MyFunction
def jit_func(self, x):

- # Mix x with the previous timestep to produce xk, xv, xr
- xx = self.time_shift(x)
+ xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)

- # Use xk, xv, xr to produce k, v, r
k = self.key(xk)
v = self.value(xv)
r = self.receptance(xr)
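Note on time_shift, used in jit_func above: it is nn.ZeroPad2d((0, 0, 1, -1)) (defined earlier in this file), which on a (B, T, C) tensor pads one zero step at the front of the time axis and crops the last step, so each position sees the previous token. A quick sketch verifying that behavior:

import torch
import torch.nn as nn

x = torch.arange(6.0).reshape(1, 3, 2)    # (B=1, T=3, C=2)
time_shift = nn.ZeroPad2d((0, 0, 1, -1))  # pad one step before T, drop the last step
xx = time_shift(x)
assert torch.equal(xx[:, 1:], x[:, :-1])  # xx[:, t] == x[:, t-1]
assert torch.equal(xx[:, 0], torch.zeros(1, 2))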
@@ -188,25 +180,20 @@ def QKV(self, q, k, v):

@MyFunction
def jit_funcQKV(self, x):
- # Mix x with the previous timestep to produce xk, xv, xr
- xx = self.time_shift(x)
+ xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
xqq = x * self.time_mix_qq + xx * (1 - self.time_mix_qq)
xkk = x * self.time_mix_kk + xx * (1 - self.time_mix_kk)
xvv = x * self.time_mix_vv + xx * (1 - self.time_mix_vv)

- # Use xk, xv, xr to produce k, v, r
k = self.key(xk)
v = self.value(xv)
r = self.receptance(xr)
sr = torch.sigmoid(r)

qq = self.qq(xqq)
kk = self.kk(xkk)
vv = self.vv(xvv)

return sr, k, v, qq, kk, vv

def forward(self, x):
@@ -223,19 +210,16 @@ def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
- self.my_testing = self.args.my_testing
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

with torch.no_grad(): # fancy init of time_mix
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0

- x = torch.ones(1, 1, args.n_embd)
+ ddd = torch.ones(1, 1, args.n_embd)
  for i in range(args.n_embd):
-     x[0, 0, i] = i / args.n_embd
- self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
- self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
+     ddd[0, 0, i] = i / args.n_embd
+ self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
+ self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))

self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
@@ -255,7 +239,6 @@ def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id
- self.my_testing = self.args.my_testing
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

with torch.no_grad():
@@ -478,7 +461,7 @@ def forward(self, idx):

def training_step(self, batch, batch_idx):
args = self.args
- if args.my_qa_mask == 0:
+ if args.my_qa_mask != 1:
idx, targets = batch
logits = self(idx)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
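The qa-mask branch itself sits outside this hunk, so the following is only a hedged sketch of what a masked loss branch typically looks like (the name masked_ce and the mask argument are illustrative, not from the repository): per-token cross-entropy averaged over the positions the mask selects.

import torch
import torch.nn.functional as F

def masked_ce(logits, targets, mask):
    # per-token loss, then average only where mask == 1
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                           targets.view(-1), reduction='none')
    mask = mask.view(-1).float()
    return (loss * mask).sum() / mask.sum().clamp(min=1)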
RWKV-v4neo/src/trainer.py: 26 changes (22 additions, 4 deletions)
@@ -154,14 +154,32 @@ def generate_init_weight(model, init_weight_name):
mm = model.generate_init_weight()

if model.args.my_pile_stage == 1:
- try:
+ if len(model.args.load_model) > 0:
print(f"Combine weights from {model.args.load_model}...")
load_dict = torch.load(model.args.load_model, map_location="cpu")
for k in load_dict:
assert k in mm
-     mm[k] = load_dict[k].reshape(mm[k].shape)
- except:
-     print(f"\n\n!!! FAIL !!!\n\n")
+ src = load_dict[k]
+ try:
+     mm[k] = src.reshape(mm[k].shape)
+ except:
+     tmp = mm[k].squeeze().clone()
+     print(k, src.shape, '-->', mm[k].shape)
+     ss = src.shape[0]
+     dd = tmp.shape[0]
+     for i in range(dd):
+         pos = i / dd * ss
+         if pos >= ss - 1:
+             tmp[i] = src[ss-1]
+         else:
+             p0 = int(math.floor(pos))
+             ii = pos - p0
+             tmp[i] = src[p0] * (1-ii) + src[p0+1] * (ii)
+     mm[k] = tmp.reshape(mm[k].shape)
+     sss = src.squeeze().float().cpu().numpy()
+     print(sss[:10], '...', sss[-10:])
+     mmm = mm[k].squeeze().float().cpu().numpy()
+     print(mmm[:10], '...', mmm[-10:])

print(f"Save to {init_weight_name}...")
torch.save(mm, init_weight_name)
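The new fallback above resizes a mismatched 1-D weight by linear interpolation instead of failing outright. A standalone sketch of the same logic (the name resize_1d is mine, not the repository's):

import math
import torch

def resize_1d(src: torch.Tensor, dd: int) -> torch.Tensor:
    # stretch or shrink a length-ss vector to length dd by sampling it at
    # fractional positions and blending the two nearest source entries
    ss = src.shape[0]
    out = torch.empty(dd, dtype=src.dtype)
    for i in range(dd):
        pos = i / dd * ss
        if pos >= ss - 1:
            out[i] = src[ss - 1]
        else:
            p0 = int(math.floor(pos))
            frac = pos - p0
            out[i] = src[p0] * (1 - frac) + src[p0 + 1] * frac
    return out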
