diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py
index 089da07d..342ac1be 100644
--- a/RWKV-v4neo/src/dataset.py
+++ b/RWKV-v4neo/src/dataset.py
@@ -17,7 +17,7 @@ def __init__(self, args):

         if args.data_type == "binidx":
             self.vocab_size = args.vocab_size
-            print("Current vocab size =", self.vocab_size, "(make sure it's correct)")
+            rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")

             if args.data_file.endswith('/'):
                 d_all = []
@@ -25,12 +25,12 @@ def __init__(self, args):
                     if p.endswith(".idx"):
                         d_all += [p[:-4]]
                 d_all.sort()
-                print(d_all)
+                rank_zero_info(d_all)
                 exit(0)
             else:
                 self.data = MMapIndexedDataset(args.data_file)
                 self.data_size = len(self.data._bin_buffer) // 2
-                print(f"Data has {self.data_size} tokens.")
+                rank_zero_info(f"Data has {self.data_size} tokens.")

             if args.my_qa_mask == 1:
                 self.data_pile = MMapIndexedDataset('/fsx/BlinkDL/pile/pile_20B_tokenizer_text_document')
@@ -40,7 +40,7 @@ def __init__(self, args):
                 # assert self.data_size == 332115325534 and self.vocab_size == 50277
                 self.samples_per_epoch = args.epoch_steps * args.real_bsz
                 assert self.samples_per_epoch == 40320
-                print(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
+                rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
                 dataset_slot = self.data_size // args.ctx_len
                 assert MaybeIsPrime(args.magic_prime)
                 assert args.magic_prime % 3 == 2
@@ -48,15 +48,15 @@ def __init__(self, args):
         elif args.data_type == "numpy":
             self.data = np.load(args.data_file).astype("int")
             self.vocab_size = args.vocab_size
-            print("Current vocab size =", self.vocab_size, "(make sure it's correct)")
+            rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
             self.data_size = len(self.data)
-            print(f"Data has {self.data_size} tokens.")
+            rank_zero_info(f"Data has {self.data_size} tokens.")
         elif args.data_type == "uint16":
             self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
             self.vocab_size = args.vocab_size
-            print("Current vocab size =", self.vocab_size, "(make sure it's correct)")
+            rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
             self.data_size = self.data.shape[0]
-            print(f"Data has {self.data_size} samples.")
+            rank_zero_info(f"Data has {self.data_size} samples.")
         elif args.data_type == "wds_img":
             self.vocab_size = -1
             self.data_size = -1
@@ -64,7 +64,7 @@ def __init__(self, args):
             self.error_count = 0
         else:
             if args.data_type == "dummy":
-                print("Building dummy data...")
+                rank_zero_info("Building dummy data...")
                 self.data = ""
                 for i in range(100000):
                     aa = (i) % 10000
@@ -73,13 +73,13 @@ def __init__(self, args):
                     self.data += f".{aa}+{bb}={cc}."
             else:
                 self.data = open(args.data_file, "r", encoding=args.data_type).read()
-            print("Building token list...")
+            rank_zero_info("Building token list...")
             unique = sorted(list(set(self.data)))
             self.vocab_size = len(unique)
-            # print()
+            # rank_zero_info()
             # for u in unique:
             #     print(u, end=' ')
-            # print('\n\n')
+            # rank_zero_info('\n\n')
             xx = 0
             xxObj = {}
             for u in unique:
@@ -88,7 +88,7 @@ def __init__(self, args):
             with open(f"{args.proj_dir}/vocab.json", "w", encoding="utf-16le") as vocab_file:
                 vocab_file.write(json.dumps(xxObj, ensure_ascii=False))
             self.data_size = len(self.data)
-            print("Data has %d tokens, %d vocab size." % (self.data_size, self.vocab_size))
+            rank_zero_info(f"Data has {self.data_size} tokens, {self.vocab_size} vocab size.")
             self.stoi = {ch: i for i, ch in enumerate(unique)}
             self.itos = {i: ch for i, ch in enumerate(unique)}

diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py
index ac82d382..4b616218 100644
--- a/RWKV-v4neo/src/trainer.py
+++ b/RWKV-v4neo/src/trainer.py
@@ -1,9 +1,17 @@
-import os, math, time, datetime
+import os, math, time, datetime, subprocess
 import torch
 from torch.utils.data import DataLoader
 import pytorch_lightning as pl
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

+def my_save(dd, ff):
+    if '14b-run1' not in ff:
+        torch.save(dd, ff)
+    else:
+        fn = ff.split('/')[-1]
+        fff = '/dev/shm/' + fn
+        torch.save(dd, fff)
+        subprocess.Popen(f" aws s3 mv {fff} s3://rwkv-14b/{fn} --quiet", shell=True)

 class train_callback(pl.Callback):
     def __init__(self, args):
@@ -100,7 +108,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
             if args.magic_prime > 0:
                 if int(real_step) == int(args.magic_prime * (1 + args.my_qa_mask) // args.real_bsz) - 1:
                     to_save_dict = pl_module.state_dict()
-                    torch.save(
+                    my_save(
                         to_save_dict,
                         f"{args.proj_dir}/rwkv-final.pth",
                     )
@@ -128,7 +136,7 @@ def on_train_epoch_end(self, trainer, pl_module):
                 else:
                     to_save_dict = pl_module.state_dict()
                 try:
-                    torch.save(
+                    my_save(
                         to_save_dict,
                         f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
                     )

diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py
index 073a5f7e..9b63deaf 100644
--- a/RWKV-v4neo/train.py
+++ b/RWKV-v4neo/train.py
@@ -5,8 +5,9 @@
 if __name__ == "__main__":
     from argparse import ArgumentParser
     from pytorch_lightning import Trainer
+    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

-    print("########## work in progress ##########")
+    rank_zero_info("########## work in progress ##########")

     ########################################################################################################
     #
@@ -101,7 +102,7 @@
     parser.add_argument("--load_partial", default=0, type=int)
     parser.add_argument("--magic_prime", default=0, type=int)
     parser.add_argument("--my_qa_mask", default=0, type=int)
-    parser.add_argument("--my_testing", default=0, type=int)
+    parser.add_argument("--my_testing", default='', type=str)
     parser = Trainer.add_argparse_args(parser)
     args = parser.parse_args()

@@ -115,7 +116,6 @@
     import deepspeed
     import pytorch_lightning as pl
     from pytorch_lightning import seed_everything
-    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

     if args.random_seed >= 0:
         print(f"########## WARNING: GLOBAL SEED {args.random_seed} THIS WILL AFFECT MULTIGPU SAMPLING ##########\n" * 3)
@@ -138,6 +138,7 @@
     args.betas = (args.beta1, args.beta2)
     args.real_bsz = int(args.num_nodes) * int(args.devices) * args.micro_bsz
     os.environ["RWKV_T_MAX"] = str(args.ctx_len)
+    os.environ["RWKV_MY_TESTING"] = args.my_testing

     if args.data_type == "wds_img":
         args.run_name = f"v{args.my_img_version}-{args.my_img_size}-{args.my_img_bit}bit-{args.my_img_clip}x{args.my_img_clip_scale}"
@@ -276,11 +277,11 @@
             generate_init_weight(model, init_weight_name)  # save initial weights
             args.load_model = init_weight_name

-    print(f"########## Loading {args.load_model}... ##########")
+    rank_zero_info(f"########## Loading {args.load_model}... ##########")
##########") try: load_dict = torch.load(args.load_model, map_location="cpu") except: - print(f"Bad checkpoint {args.load_model}") + rank_zero_info(f"Bad checkpoint {args.load_model}") if args.my_pile_stage >= 2: # try again using another checkpoint max_p = args.my_pile_prev_p if max_p == -1: @@ -288,7 +289,7 @@ else: args.load_model = f"{args.proj_dir}/rwkv-{max_p}.pth" args.epoch_begin = max_p + 1 - print(f"Trying {args.load_model}") + rank_zero_info(f"Trying {args.load_model}") load_dict = torch.load(args.load_model, map_location="cpu") if args.load_partial == 1: @@ -302,6 +303,16 @@ args, callbacks=[train_callback(args)], ) + + if trainer.global_rank == 0: + for n in model.state_dict(): + shape = model.state_dict()[n].shape + shape = [i for i in shape if i != 1] + if len(shape) > 1: + print(f"{str(shape[0]).ljust(5)} {str(shape[1]).ljust(5)} {n}") + else: + print(f"{str(shape[0]).ljust(5)} {n}") + if "deepspeed" in args.strategy: trainer.strategy.config["zero_optimization"]["allgather_bucket_size"] = args.ds_bucket_mb * 1000 * 1000 trainer.strategy.config["zero_optimization"]["reduce_bucket_size"] = args.ds_bucket_mb * 1000 * 1000