Skip to content

Commit

Permalink
better
Browse files Browse the repository at this point in the history
  • Loading branch information
www committed Jan 12, 2023
1 parent 935d8d3 commit cf34026
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 6 deletions.
52 changes: 52 additions & 0 deletions RWKV-v4neo/src/binidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,58 @@ class MMapIndexedDataset(torch.utils.data.Dataset):
class Index(object):
_HDR_MAGIC = b"MMIDIDX\x00\x00"

@classmethod
def writer(cls, path, dtype):
    """Return a context manager that writes an index (.idx) file at *path*.

    On ``__enter__`` the header is emitted: the magic string, the format
    version (little-endian u64) and the dtype code (u8).  ``write(sizes,
    doc_idx)`` then appends the two element counts (u64 each), the
    sequence lengths as int32, the per-sequence byte pointers as int64,
    and the document offsets as int64.  ``__exit__`` closes the file.
    """

    class _Writer(object):
        def __enter__(self):
            self._file = open(path, "wb")

            # Magic string first, so the format can be validated when the
            # file is opened again later.
            self._file.write(cls._HDR_MAGIC)
            # Version number: little-endian unsigned 64-bit integer.
            self._file.write(struct.pack("<Q", 1))
            # Dtype code: little-endian unsigned 8-bit integer.
            self._file.write(struct.pack("<B", code(dtype)))

            return self

        @staticmethod
        def _get_pointers(sizes):
            # Byte offset of each sequence = running sum of all earlier
            # sequence lengths times the element size of *dtype*.
            item_bytes = dtype().itemsize
            offsets = []
            running = 0
            for length in sizes:
                offsets.append(running)
                running += length * item_bytes
            return offsets

        def write(self, sizes, doc_idx):
            byte_offsets = self._get_pointers(sizes)

            # Counts, each a little-endian unsigned 64-bit integer.
            self._file.write(struct.pack("<Q", len(sizes)))
            self._file.write(struct.pack("<Q", len(doc_idx)))

            # Payload arrays, each dumped as a contiguous C-order blob:
            # lengths (int32), then pointers and doc offsets (int64).
            self._file.write(np.array(sizes, dtype=np.int32).tobytes(order="C"))
            self._file.write(np.array(byte_offsets, dtype=np.int64).tobytes(order="C"))
            self._file.write(np.array(doc_idx, dtype=np.int64).tobytes(order="C"))

        def __exit__(self, exc_type, exc_val, exc_tb):
            self._file.close()

    return _Writer()

def __init__(self, path, skip_warmup=False):
with open(path, "rb") as stream:
magic_test = stream.read(9)
Expand Down
22 changes: 16 additions & 6 deletions RWKV-v4neo/src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import json, math, random
import json, math, random, os, sys
import numpy as np
import torch
from torch.utils.data import Dataset
Expand All @@ -16,21 +16,31 @@ def __init__(self, args):
self.args = args

if args.data_type == "binidx":
self.data = MMapIndexedDataset(args.data_file)
self.vocab_size = args.vocab_size
print("Current vocab size =", self.vocab_size, "(make sure it's correct)")
self.data_size = len(self.data._bin_buffer) // 2
print(f"Data has {self.data_size} tokens.")

if args.data_file.endswith('/'):
d_all = []
for p in os.listdir(args.data_file):
if p.endswith(".idx"):
d_all += [p[:-4]]
d_all.sort()
print(d_all)
exit(0)
else:
self.data = MMapIndexedDataset(args.data_file)
self.data_size = len(self.data._bin_buffer) // 2
print(f"Data has {self.data_size} tokens.")

if args.my_pile_stage > 0:
assert self.data_size == 332115325534 and self.vocab_size == 50277
# assert self.data_size == 332115325534 and self.vocab_size == 50277
self.samples_per_epoch = args.epoch_steps * args.real_bsz
assert self.samples_per_epoch == 40320
print(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
dataset_slot = self.data_size // args.ctx_len
assert MaybeIsPrime(args.magic_prime)
assert args.magic_prime % 3 == 2
assert args.magic_prime / dataset_slot > 0.999999 and args.magic_prime / dataset_slot <= 1
assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
elif args.data_type == "numpy":
self.data = np.load(args.data_file).astype("int")
self.vocab_size = args.vocab_size
Expand Down
8 changes: 8 additions & 0 deletions RWKV-v4neo/src/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,14 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if kt_s > 0:
lll["kt/s"] = kt_s
trainer.my_wandb.log(lll, step=int(real_step))
if args.magic_prime > 0:
if int(real_step) == int(args.magic_prime // args.real_bsz) - 1:
to_save_dict = pl_module.state_dict()
torch.save(
to_save_dict,
f"{args.proj_dir}/rwkv-final.pth",
)


def on_train_epoch_start(self, trainer, pl_module):
args = self.args
Expand Down
5 changes: 5 additions & 0 deletions RWKV-v4neo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
parser.add_argument("--my_att_shift", default=1, type=int)
parser.add_argument("--my_pos_emb", default=0, type=int)
parser.add_argument("--load_partial", default=0, type=int)
parser.add_argument("--magic_prime", default=0, type=int)

parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
Expand Down Expand Up @@ -145,6 +146,7 @@
os.makedirs(args.proj_dir)

if args.my_pile_stage > 0:
magic_prime_bak = args.magic_prime
if args.ctx_len == 1024:
args.magic_prime = 324331313
args.epoch_count = 8043
Expand All @@ -162,6 +164,9 @@
elif args.ctx_len == 4096:
args.my_pile_shift = 768

if magic_prime_bak > 0:
args.magic_prime = magic_prime_bak

args.epoch_steps = 40320 // args.real_bsz
assert args.epoch_steps * args.real_bsz == 40320
if args.my_pile_stage == 2:
Expand Down

0 comments on commit cf34026

Please sign in to comment.