misc
www committed Jan 21, 2023
1 parent bc47cb9 commit b2a240d
Showing 4 changed files with 112 additions and 12 deletions.
60 changes: 54 additions & 6 deletions RWKV-v4neo/src/dataset.py
@@ -32,6 +32,10 @@ def __init__(self, args):
self.data_size = len(self.data._bin_buffer) // 2
print(f"Data has {self.data_size} tokens.")

if args.my_qa_mask == 1:
self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document')
self.data_pile_size = len(self.data_pile._bin_buffer) // 2
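# the auxiliary Pile copy (20B-tokenizer binidx) is interleaved with the QA data in
# __getitem__ below and also used as a fallback when a sample contains no QA mask;
# the /fsx/BlinkDL/... path is specific to the author's setup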

if args.my_pile_stage > 0:
# assert self.data_size == 332115325534 and self.vocab_size == 50277
self.samples_per_epoch = args.epoch_steps * args.real_bsz
@@ -146,25 +150,69 @@ def worker_seed():
else:
ctx_len = args.ctx_len
req_len = ctx_len + 1
magic_prime = args.magic_prime
data = self.data

if args.my_pile_stage > 0:
ii = 1 + epoch * self.samples_per_epoch + (idx * world_size) + rank

if args.my_qa_mask == 1:
ii_orig = ii
if ii % 2 == 0:
ii = (ii // 2) * args.magic_prime
magic_prime = 324331313
data = self.data_pile
else:
ii = ii // 2

factor = (math.sqrt(5) - 1) / 2
factor = int(args.magic_prime * factor)
i = ((factor * ii * ii * ii) % args.magic_prime) * ctx_len
i = i + args.my_pile_shift
factor = int(magic_prime * factor)
i = ((factor * ii * ii * ii) % magic_prime) * ctx_len
if (args.my_qa_mask == 0) or (data == self.data_pile):
i = i + args.my_pile_shift
# print(f"epoch {epoch} idx {idx} rank {rank}/{world_size} ii {ii} pos {round(i / self.data_size, 3)}")
else:
# cheat: pick a random spot in dataset
i = np.random.randint(0, self.data_size - req_len)
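# my reading of the my_pile_stage branch above (not documented in this commit):
# factor is magic_prime scaled by the golden-ratio conjugate (sqrt(5)-1)/2, and
# i = ((factor * ii**3) % magic_prime) * ctx_len maps each sample counter ii to a
# ctx_len-aligned window; for a prime with magic_prime % 3 == 2 the cubic map is a
# bijection mod magic_prime, so every window is visited once per magic_prime samples
# in a scrambled order. With my_qa_mask == 1, even ii are redirected to the Pile
# copy (with its own hard-coded prime 324331313) and odd ii to the QA data, giving
# a 50/50 interleave of the two sources.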

if args.data_type == "binidx":
dix = self.data.get(idx=0, offset=i, length=req_len).astype(int)
dix = data.get(idx=0, offset=i, length=req_len).astype(int)
elif args.data_type == "numpy":
dix = self.data[i : i + req_len]
dix = data[i : i + req_len]
else:
dix = [self.stoi[s] for s in self.data[i : i + req_len]]
dix = [self.stoi[s] for s in data[i : i + req_len]]

if args.my_qa_mask == 1:
if data == self.data_pile:
z = [1] * ctx_len
else:
z = [0] * ctx_len
z_sum = 0
isGood = False
for i in range(3, ctx_len):
if dix[i] == 27 and dix[i-1] == 34 and dix[i-2] == 187 and dix[i-3] == 187:
isGood = True
if dix[i] == 0:
isGood = False
if isGood:
z[i] = 1
z_sum += 1
if z_sum == 0:
z = [1] * ctx_len
i = np.random.randint(0, self.data_pile_size - req_len)
dix = self.data_pile.get(idx=0, offset=i, length=req_len).astype(int)
z = torch.tensor(z, dtype=torch.bfloat16)
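# my reading of the mask logic above: the token run 187,187,34,27 appears to mark
# the start of an answer in the QA data (plausibly "\n\n" + "A" + ":" under the 20B
# tokenizer; the exact ids are an assumption); z[i] is set to 1 from the first match
# until a 0 (padding) token, so only answer tokens carry loss, whereas Pile samples
# get an all-ones mask; if no marker is found, the sample is swapped for a random
# Pile window with the mask fully on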

x = torch.tensor(dix[:-1], dtype=torch.long)
y = torch.tensor(dix[1:], dtype=torch.long)
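# x is the ctx_len-token input window and y the same window shifted by one token
# (next-token-prediction targets)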

# if ii_orig < 50:
# # if rank == 1:
# print('rank', rank, 'i', ii_orig, ii, i, 'x', x[:5], '...', x[-5:])
# else:
# exit(0)

if args.my_qa_mask == 1:
return x, y, z

return x, y
61 changes: 56 additions & 5 deletions RWKV-v4neo/src/model.py
@@ -108,7 +108,7 @@ def __init__(self, args, layer_id):
self.layer_id = layer_id
self.ctx_len = args.ctx_len
self.n_embd = args.n_embd

self.my_testing = self.args.my_testing
attn_sz = args.n_embd

with torch.no_grad(): # fancy init
@@ -142,6 +142,9 @@ def __init__(self, args, layer_id):

self.output = nn.Linear(attn_sz, args.n_embd, bias=False)

# if self.my_testing > 0:
# self.aaa = nn.Parameter(torch.zeros(1, 1, args.n_embd))

@MyFunction
def jit_func(self, x):

@@ -174,7 +177,7 @@ def __init__(self, args, layer_id):
super().__init__()
self.args = args
self.layer_id = layer_id

self.my_testing = self.args.my_testing
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

with torch.no_grad(): # fancy init of time_mix
@@ -192,6 +195,12 @@ def __init__(self, args, layer_id):
self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
self.value = nn.Linear(hidden_sz, args.n_embd, bias=False)

# if self.my_testing in [1]:
# self.aaa = nn.Parameter(torch.zeros(1, 1, hidden_sz))
# elif self.my_testing in [2]:
# self.aaa = nn.Parameter(torch.zeros(1, 1, args.n_embd))


@MyFunction
def forward(self, x):
xx = self.time_shift(x)
Expand All @@ -205,6 +214,19 @@ def forward(self, x):
rkv = torch.sigmoid(self.receptance(xr)) * kv
return rkv

# k = self.key(xk)
# # if self.my_testing in [0, 2]:
# k = torch.square(torch.relu(k))
# # elif self.my_testing == 1:
# # k = torch.square(torch.relu(k)) + k * self.aaa
# kv = self.value(k)
# r = self.receptance(xr)
# # if self.my_testing == 0:
# r = torch.sigmoid(r)
# # elif self.my_testing == 2:
# # r = torch.sigmoid(r) + r * self.aaa
# rkv = r * kv
# return rkv
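# the commented block above sketches the --my_testing variants: blending a learned
# per-channel parameter self.aaa into the key path (my_testing == 1) or the
# receptance path (my_testing == 2) on top of the default relu()**2 / sigmoid gating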

########################################################################################################
# The RWKV Model with our blocks
@@ -401,9 +423,38 @@ def forward(self, idx):

def training_step(self, batch, batch_idx):
args = self.args
idx, targets = batch
logits = self(idx)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
if args.my_qa_mask == 0:
idx, targets = batch
logits = self(idx)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
else:
idx, targets, mask = batch
mask = mask.view(-1)
sum_mask = torch.sum(mask).item()
# if sum_mask == 0:
# return torch.tensor([0.0], requires_grad=True)

logits = self(idx)
if sum_mask == mask.shape[0]:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
# print('rank', self.global_rank, 'loss', loss.item())
else:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), reduction='none')
# loss_raw = loss
loss = torch.sum(loss * mask) / sum_mask
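# in other words (a restatement, not new logic): with my_qa_mask the per-token
# cross-entropy is computed with reduction='none' and averaged only over positions
# where mask == 1 (answer tokens, or every token of a plain Pile sample); when the
# mask is all ones this reduces to the ordinary mean cross-entropy of the branch above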

# torch.set_printoptions(threshold=10000)
# if True: #self.global_rank == 1:
# tmp = ''
# sss = 0
# ccc = 0
# for i in range(mask.shape[0]):
# if mask[i] > 0:
# tmp += str(idx.view(-1)[i].item()) + ','
# sss += loss_raw.view(-1)[i].float().item()
# ccc += 1
# print('rank', self.global_rank, 'loss', loss.item(), 'lavg', sss / ccc)#, 'tmp', tmp, 'input', idx)

return L2Wrap.apply(loss, logits)

def training_step_end(self, batch_parts):
2 changes: 1 addition & 1 deletion RWKV-v4neo/src/trainer.py
@@ -98,7 +98,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
lll["kt/s"] = kt_s
trainer.my_wandb.log(lll, step=int(real_step))
if args.magic_prime > 0:
if int(real_step) == int(args.magic_prime // args.real_bsz) - 1:
if int(real_step) == int(args.magic_prime * (1 + args.my_qa_mask) // args.real_bsz) - 1:
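# presumably the (1 + args.my_qa_mask) factor doubles the stop step because the
# QA-mask dataset interleaves QA and Pile windows 50/50, so one pass over the QA
# data takes twice as many samples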
to_save_dict = pl_module.state_dict()
torch.save(
to_save_dict,
1 change: 1 addition & 0 deletions RWKV-v4neo/train.py
@@ -100,6 +100,7 @@
parser.add_argument("--my_pos_emb", default=0, type=int)
parser.add_argument("--load_partial", default=0, type=int)
parser.add_argument("--magic_prime", default=0, type=int)
parser.add_argument("--my_qa_mask", default=0, type=int)
parser.add_argument("--my_testing", default=0, type=int)

parser = Trainer.add_argparse_args(parser)
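# hypothetical usage of the new flag (values illustrative only; the other flags are
# whatever the run would normally use):
#   python train.py --data_type binidx --my_qa_mask 1 --magic_prime <prime for your data> ...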
