Skip to content

Commit

Permalink
misc
Browse files Browse the repository at this point in the history
  • Loading branch information
www committed May 20, 2023
1 parent a8a6809 commit f8c5f6c
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 13 deletions.
2 changes: 1 addition & 1 deletion RWKV-v4neo/src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, args):
self.data = MMapIndexedDataset(args.data_file)
self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size
rank_zero_info(f"Data has {self.data_size} tokens.")
else:
elif args.my_pile_version == 2:
data_list = open(args.data_file, "r", encoding='utf-8').read().strip().split('\n')
data_list = [i.strip().split(' ') for i in data_list]
self.data = []
Expand Down
12 changes: 9 additions & 3 deletions RWKV-v4neo/src/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,19 @@
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

def my_save(dd, ff):
    """Save checkpoint object ``dd``, either locally or via /dev/shm to S3.

    Routing is decided purely by substrings of the target path ``ff``:

    * ``'14b-run1'`` in ``ff``: the object is written to
      ``/dev/shm/<basename>`` and then moved to
      ``s3://rwkv-14b-4k/<basename>``.
    * ``'world/14b'`` or ``'world/7b'`` in ``ff``: the object is written to
      ``/dev/shm/<aa>-<basename>`` and then moved to
      ``s3://rwkv-world/<aa>-<basename>``, where ``aa`` is the second
      ``/``-separated component of ``ff``.
    * otherwise: the object is written directly to ``ff`` with ``torch.save``.

    The S3 move is started with ``subprocess.Popen`` and is not waited on,
    so this function returns while the upload may still be running and
    upload failures are not reported to the caller.

    Args:
        dd: object to serialize with ``torch.save``.
        ff: destination path string; also used as the routing key.
    """
    # Local import: shlex is only needed here to quote the shell command.
    import shlex

    fn = ff.split('/')[-1]

    if '14b-run1' in ff:
        fff = '/dev/shm/' + fn
        dest = f's3://rwkv-14b-4k/{fn}'
    elif ('world/14b' in ff) or ('world/7b' in ff):
        # NOTE(review): index 1 presumably picks the run directory
        # (e.g. "14b" from "world/14b/x.pth") and so assumes ``ff`` is a
        # relative path starting with "world/" -- confirm against callers.
        aa = ff.split('/')[1]
        fff = f'/dev/shm/{aa}-{fn}'
        dest = f's3://rwkv-world/{aa}-{fn}'
    else:
        torch.save(dd, ff)
        return

    torch.save(dd, fff)
    # Quote both arguments so a path containing spaces or shell
    # metacharacters cannot split or alter the command line.
    subprocess.Popen(
        f" aws s3 mv {shlex.quote(fff)} {shlex.quote(dest)} --quiet",
        shell=True,
    )

class train_callback(pl.Callback):
def __init__(self, args):
Expand Down
10 changes: 1 addition & 9 deletions RWKV-v4neo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,37 +164,29 @@
if args.my_pile_version == 1:
if args.ctx_len == 1024:
args.magic_prime = 324331313
args.epoch_count = 8043
elif args.ctx_len == 2048:
args.magic_prime = 162165671
args.epoch_count = 4021
elif args.ctx_len == 4096:
args.magic_prime = 81082817
args.epoch_count = 2010
elif args.ctx_len == 8192:
args.magic_prime = 40541399
args.epoch_count = 1005
else:
if args.ctx_len == 1024:
args.magic_prime = 1670239709
args.epoch_count = 41423
elif args.ctx_len == 2048:
args.magic_prime = 835119767
args.epoch_count = 20711
elif args.ctx_len == 4096:
args.magic_prime = 417559889
args.epoch_count = 10355
elif args.ctx_len == 6144:
args.magic_prime = 278373239
args.epoch_count = 6903
elif args.ctx_len == 8192:
args.magic_prime = 208779911
args.epoch_count = 5177
if args.my_pile_shift < 0:
args.my_pile_shift = 0

if magic_prime_bak > 0:
args.magic_prime = magic_prime_bak
args.epoch_count = args.magic_prime // 40320

args.epoch_steps = 40320 // args.real_bsz
assert args.epoch_steps * args.real_bsz == 40320
Expand Down

0 comments on commit f8c5f6c

Please sign in to comment.