testing
www committed Jan 30, 2023
1 parent 8bf7061 commit f79d082
Showing 3 changed files with 41 additions and 22 deletions.
26 changes: 13 additions & 13 deletions RWKV-v4neo/src/dataset.py
@@ -17,20 +17,20 @@ def __init__(self, args):

if args.data_type == "binidx":
self.vocab_size = args.vocab_size
- print("Current vocab size =", self.vocab_size, "(make sure it's correct)")
+ rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")

if args.data_file.endswith('/'):
d_all = []
for p in os.listdir(args.data_file):
if p.endswith(".idx"):
d_all += [p[:-4]]
d_all.sort()
- print(d_all)
+ rank_zero_info(d_all)
exit(0)
else:
self.data = MMapIndexedDataset(args.data_file)
self.data_size = len(self.data._bin_buffer) // 2
- print(f"Data has {self.data_size} tokens.")
+ rank_zero_info(f"Data has {self.data_size} tokens.")

if args.my_qa_mask == 1:
self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document')
@@ -40,31 +40,31 @@ def __init__(self, args):
# assert self.data_size == 332115325534 and self.vocab_size == 50277
self.samples_per_epoch = args.epoch_steps * args.real_bsz
assert self.samples_per_epoch == 40320
- print(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
+ rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
dataset_slot = self.data_size // args.ctx_len
assert MaybeIsPrime(args.magic_prime)
assert args.magic_prime % 3 == 2
assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
elif args.data_type == "numpy":
self.data = np.load(args.data_file).astype("int")
self.vocab_size = args.vocab_size
- print("Current vocab size =", self.vocab_size, "(make sure it's correct)")
+ rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
self.data_size = len(self.data)
- print(f"Data has {self.data_size} tokens.")
+ rank_zero_info(f"Data has {self.data_size} tokens.")
elif args.data_type == "uint16":
self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
self.vocab_size = args.vocab_size
- print("Current vocab size =", self.vocab_size, "(make sure it's correct)")
+ rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
self.data_size = self.data.shape[0]
- print(f"Data has {self.data_size} samples.")
+ rank_zero_info(f"Data has {self.data_size} samples.")
elif args.data_type == "wds_img":
self.vocab_size = -1
self.data_size = -1
self.data = None
self.error_count = 0
else:
if args.data_type == "dummy":
- print("Building dummy data...")
+ rank_zero_info("Building dummy data...")
self.data = ""
for i in range(100000):
aa = (i) % 10000
@@ -73,13 +73,13 @@ def __init__(self, args):
self.data += f".{aa}+{bb}={cc}."
else:
self.data = open(args.data_file, "r", encoding=args.data_type).read()
- print("Building token list...")
+ rank_zero_info("Building token list...")
unique = sorted(list(set(self.data)))
self.vocab_size = len(unique)
- # print()
+ # rank_zero_info()
# for u in unique:
# print(u, end=' ')
- # print('\n\n')
+ # rank_zero_info('\n\n')
xx = 0
xxObj = {}
for u in unique:
@@ -88,7 +88,7 @@ def __init__(self, args):
with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
self.data_size = len(self.data)
- print("Data has %d tokens, %d vocab size." % (self.data_size, self.vocab_size))
+ rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.")
self.stoi = {ch: i for i, ch in enumerate(unique)}
self.itos = {i: ch for i, ch in enumerate(unique)}

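Note (not part of the diff): rank_zero_info from pytorch_lightning.utilities is rank-zero-guarded, so in a multi-GPU run each message is logged once by rank 0 instead of once per process, which is why the print calls above are swapped out. One caveat: unlike print, it forwards its arguments to Python's logging, so the comma-separated calls in the numpy and uint16 branches probably want to be folded into single f-strings as well. A minimal sketch of the pattern, with a placeholder value:

from pytorch_lightning.utilities import rank_zero_info

data_size = 123_456  # placeholder value, just for this demo
rank_zero_info(f"Data has {data_size} tokens.")  # emitted once, by rank 0 only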
14 changes: 11 additions & 3 deletions RWKV-v4neo/src/trainer.py
@@ -1,9 +1,17 @@
- import os, math, time, datetime
+ import os, math, time, datetime, subprocess
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

+ def my_save(dd, ff):
+     if '14b-run1' not in ff:
+         torch.save(dd, ff)
+     else:
+         fn = ff.split('/')[-1]
+         fff = '/dev/shm/' + fn
+         torch.save(dd, fff)
+         subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b/{fn} --quiet", shell=True)

class train_callback(pl.Callback):
def __init__(self, args):
@@ -100,7 +108,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if args.magic_prime > 0:
if int(real_step) == int(args.magic_prime * (1 + args.my_qa_mask) // args.real_bsz) - 1:
to_save_dict = pl_module.state_dict()
- torch.save(
+ my_save(
to_save_dict,
f"{args.proj_dir}/rwkv-final.pth",
)
@@ -128,7 +136,7 @@ def on_train_epoch_end(self, trainer, pl_module):
else:
to_save_dict = pl_module.state_dict()
try:
- torch.save(
+ my_save(
to_save_dict,
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
)
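Note (not part of the diff): the new my_save helper keeps the usual torch.save path for most runs, but for the '14b-run1' job it stages the checkpoint in RAM-backed /dev/shm and hands the S3 upload to a background process, so the training loop never blocks on network I/O. A hedged sketch of the same staging idea with generic names (the bucket and filename below are placeholders, not taken from this commit):

import subprocess
import torch

def save_and_upload(state_dict, filename, bucket="s3://example-bucket"):
    local = f"/dev/shm/{filename}"  # tmpfs: fast, RAM-backed local write
    torch.save(state_dict, local)
    # Fire-and-forget: "aws s3 mv" uploads and removes the local copy in a
    # separate process while training continues immediately.
    subprocess.Popen(["aws", "s3", "mv", local, f"{bucket}/{filename}", "--quiet"])

The commit passes a single shell string with shell=True; the argument-list form above does the same thing without invoking a shell.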
23 changes: 17 additions & 6 deletions RWKV-v4neo/train.py
@@ -5,8 +5,9 @@
if __name__ == "__main__":
from argparse import ArgumentParser
from pytorch_lightning import Trainer
+ from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

- print("########## work in progress ##########")
+ rank_zero_info("########## work in progress ##########")

########################################################################################################
#
@@ -101,7 +102,7 @@
parser.add_argument("--load_partial", default=0, type=int)
parser.add_argument("--magic_prime", default=0, type=int)
parser.add_argument("--my_qa_mask", default=0, type=int)
- parser.add_argument("--my_testing", default=0, type=int)
+ parser.add_argument("--my_testing", default='', type=str)

parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
@@ -115,7 +116,6 @@
import deepspeed
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
- from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

if args.random_seed >= 0:
print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
@@ -138,6 +138,7 @@
args.betas = (args.beta1, args.beta2)
args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
os.environ["RWKV_T_MAX"] = str(args.ctx_len)
+ os.environ["RWKV_MY_TESTING"] = args.my_testing

if args.data_type == "wds_img":
args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
@@ -276,19 +277,19 @@
generate_init_weight(model, init_weight_name) # save initial weights
args.load_model = init_weight_name

- print(f"########## Loading {args.load_model}... ##########")
+ rank_zero_info(f"########## Loading {args.load_model}... ##########")
try:
load_dict = torch.load(args.load_model, map_location="cpu")
except:
- print(f"Bad checkpoint {args.load_model}")
+ rank_zero_info(f"Bad checkpoint {args.load_model}")
if args.my_pile_stage >= 2: # try again using another checkpoint
max_p = args.my_pile_prev_p
if max_p == -1:
args.load_model = f"{args.proj_dir}/rwkv-init.pth"
else:
args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth"
args.epoch_begin = max_p + 1
- print(f"Trying {args.load_model}")
+ rank_zero_info(f"Trying {args.load_model}")
load_dict = torch.load(args.load_model, map_location="cpu")

if args.load_partial == 1:
@@ -302,6 +303,16 @@
args,
callbacks=[train_callback(args)],
)

+ if trainer.global_rank == 0:
+     for n in model.state_dict():
+         shape = model.state_dict()[n].shape
+         shape = [i for i in shape if i != 1]
+         if len(shape) > 1:
+             print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}")
+         else:
+             print(f"{str(shape[0]).ljust(5)} {n}")

if "deepspeed" in args.strategy:
trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000
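Note (not part of the diff): --my_testing changes from an int to a free-form string and is exported as the RWKV_MY_TESTING environment variable (alongside RWKV_T_MAX), presumably so the model code can read it at import time and select experimental variants by name. The code that reads the variable is not shown in this commit; the sketch below is a hypothetical consumer, and the "exp-a" tag is illustrative only:

import os

MY_TESTING = os.environ.get("RWKV_MY_TESTING", "")  # set by train.py above

def pick_variant() -> str:
    # Select an experimental code path by substring match on the flag.
    if "exp-a" in MY_TESTING:
        return "experimental-a"
    return "default"

print(pick_variant())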
