Skip to content

Commit

Permalink
pile v2
Browse files Browse the repository at this point in the history
  • Loading branch information
www committed Mar 20, 2023
1 parent 3d2b04b commit 1945cb5
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 25 deletions.
38 changes: 26 additions & 12 deletions RWKV-v4neo/src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,26 @@ def __init__(self, args):
self.vocab_size = args.vocab_size
rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")

if args.data_file.endswith('/'):
d_all = []
for p in os.listdir(args.data_file):
if p.endswith(".idx"):
d_all += [p[:-4]]
d_all.sort()
rank_zero_info(d_all)
exit(0)
else:
if args.my_pile_version == 1:
self.data = MMapIndexedDataset(args.data_file)
self.data_size = len(self.data._bin_buffer) // 2
self.data_size = len(self.data._bin_buffer) // self.data._index._dtype_size
rank_zero_info(f"Data has {self.data_size} tokens.")
else:
data_list = open(args.data_file, "r", encoding='utf-8').read().strip().split('\n')
data_list = [i.strip().split(' ') for i in data_list]
self.data = []
self.data_size = int(data_list[-1][-1])
rank_zero_info(f"Data has {self.data_size} chunks.")
for d in data_list:
data = MMapIndexedDataset(d[0])
data_size = len(data._bin_buffer) // data._index._dtype_size
assert (data_size - args.ctx_len) == int(d[1])
self.data += [[int(d[-1]), int(d[1]), data]]
# rank_zero_info(self.data)

if args.my_qa_mask > 0:
self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document')
self.data_pile_size = len(self.data_pile._bin_buffer) // 2
self.data_pile_size = len(self.data_pile._bin_buffer) // self.data._index._dtype_size

if args.my_pile_stage > 0:
# assert self.data_size == 332115325534 and self.vocab_size == 50277
Expand Down Expand Up @@ -184,7 +188,17 @@ def worker_seed():
i = np.random.randint(0, self.data_size - req_len)

if args.data_type == "binidx":
dix = data.get(idx=0, offset=i, length=req_len).astype(int)
if args.my_pile_version == 1:
dix = data.get(idx=0, offset=i, length=req_len).astype(int)
else:
# self.data : cutoff, chunk_count, data
for j in range(len(data)):
if i < data[j][0]:
ii = i
i = (i - (data[j-1][0] if j > 0 else 0)) % data[j][1]
dix = data[j][2].get(idx=0, offset=i, length=req_len).astype(int)
# print(ii, j, i)
break
elif args.data_type == "numpy":
dix = data[i : i + req_len]
else:
Expand Down
36 changes: 23 additions & 13 deletions RWKV-v4neo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@
parser.add_argument("--beta1", default=0.9, type=float)
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
parser.add_argument("--adam_eps", default=1e-8, type=float)

parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower

parser.add_argument("--my_pile_version", default=1, type=int) # my special pile version
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
parser.add_argument("--my_pile_shift", default=-1, type=int) # my special pile mode - text shift
parser.add_argument("--my_pile_edecay", default=0, type=int)
Expand Down Expand Up @@ -157,18 +158,27 @@

if args.my_pile_stage > 0:
magic_prime_bak = args.magic_prime
if args.ctx_len == 1024:
args.magic_prime = 324331313
args.epoch_count = 8043
elif args.ctx_len == 2048:
args.magic_prime = 162165671
args.epoch_count = 4021
elif args.ctx_len == 4096:
args.magic_prime = 81082817
args.epoch_count = 2010
elif args.ctx_len == 8192:
args.magic_prime = 40541399
args.epoch_count = 1005

if args.my_pile_version == 1:
if args.ctx_len == 1024:
args.magic_prime = 324331313
args.epoch_count = 8043
elif args.ctx_len == 2048:
args.magic_prime = 162165671
args.epoch_count = 4021
elif args.ctx_len == 4096:
args.magic_prime = 81082817
args.epoch_count = 2010
elif args.ctx_len == 8192:
args.magic_prime = 40541399
args.epoch_count = 1005
else:
if args.ctx_len == 4096:
args.magic_prime = 423736637
args.epoch_count = 10508
elif args.ctx_len == 8192:
args.magic_prime = 211868309
args.epoch_count = 5253
if args.my_pile_shift < 0:
args.my_pile_shift = 0

Expand Down

0 comments on commit 1945cb5

Please sign in to comment.