diff --git a/RWKV-v4neo/src/dataset.py b/RWKV-v4neo/src/dataset.py index 93409132..4dff63e3 100644 --- a/RWKV-v4neo/src/dataset.py +++ b/RWKV-v4neo/src/dataset.py @@ -57,13 +57,13 @@ def __init__(self, args): elif args.data_type == "numpy": self.data = np.load(args.data_file).astype("int") self.vocab_size = args.vocab_size - rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)") + rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)") self.data_size = len(self.data) rank_zero_info(f"Data has {self.data_size} tokens.") elif args.data_type == "uint16": self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len) self.vocab_size = args.vocab_size - rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)") + rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)") self.data_size = self.data.shape[0] rank_zero_info(f"Data has {self.data_size} samples.") elif args.data_type == "wds_img": diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py index 72406d8c..a9fca323 100644 --- a/RWKV-v4neo/src/trainer.py +++ b/RWKV-v4neo/src/trainer.py @@ -52,6 +52,24 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): # if trainer.is_global_zero: # print(trainer.global_step, decay_step, decay_total, w_step, progress, lr) + if args.my_exit_tokens > 0: # cosine decay + if trainer.global_step < w_step: + lr = args.lr_init * (0.2 + 0.8 * trainer.global_step / w_step) + else: + real_tokens = real_step * args.ctx_len * args.real_bsz + warmup_tokens = w_step * args.ctx_len * args.real_bsz + progress = (real_tokens - warmup_tokens) / (args.my_exit_tokens - warmup_tokens) + progress = max(0, min(1, progress)) + lr_final_factor = 0.1 + lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / 2) * math.cos(math.pi * progress) + lr = args.lr_init * lr_mult + if progress >= 1: + my_save( + 
pl_module.state_dict(), + f"{args.proj_dir}/rwkv-final.pth", + ) + exit(0) + for param_group in trainer.optimizers[0].param_groups: if args.layerwise_lr > 0: param_group["lr"] = lr * param_group["my_lr_scale"] diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py index a2aa7e8d..a05c1367 100644 --- a/RWKV-v4neo/train.py +++ b/RWKV-v4neo/train.py @@ -81,6 +81,7 @@ parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence parser.add_argument("--adam_eps", default=1e-8, type=float) parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower + parser.add_argument("--dropout", default=0, type=float) parser.add_argument("--my_pile_version", default=1, type=int) # my special pile version parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode @@ -108,13 +109,14 @@ parser.add_argument("--my_random_steps", default=0, type=int) parser.add_argument("--my_testing", default='', type=str) parser.add_argument("--my_exit", default=99999999, type=int) + parser.add_argument("--my_exit_tokens", default=-1, type=int) parser = Trainer.add_argparse_args(parser) args = parser.parse_args() ######################################################################################################## - import os, warnings, math, datetime, sys, time, importlib + import os, warnings, math, datetime, sys, time import numpy as np import torch from torch.utils.data import DataLoader @@ -186,12 +188,15 @@ if magic_prime_bak > 0: args.magic_prime = magic_prime_bak - args.epoch_count = args.magic_prime // 40320 + if args.my_qa_mask == 2: + args.epoch_count = 2 * args.magic_prime // 40320 + else: + args.epoch_count = args.magic_prime // 40320 args.epoch_steps = 40320 // args.real_bsz assert args.epoch_steps * args.real_bsz == 40320 - if args.my_pile_stage == 2: - assert args.lr_final == args.lr_init + # if args.my_pile_stage == 2: + # assert args.lr_final == args.lr_init if 
args.my_pile_stage >= 2: # find latest saved model list_p = [] for p in os.listdir(args.proj_dir): @@ -220,6 +225,11 @@ samples_per_epoch = args.epoch_steps * args.real_bsz tokens_per_epoch = samples_per_epoch * args.ctx_len + try: + deepspeed_version = deepspeed.__version__ + except Exception: + deepspeed_version = None + pass rank_zero_info( f""" ############################################################################ @@ -237,7 +247,7 @@ # Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps} # # Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer -# Found deepspeed {deepspeed.__version__ if importlib.util.find_spec('deepspeed') else 'None'}, recommend 0.7.0 (faster than newer versions) +# Found deepspeed {deepspeed_version}, recommend 0.7.0 (faster than newer versions) # Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer # ############################################################################ @@ -290,8 +300,12 @@ from src.model_img import RWKV_IMG model = RWKV_IMG(args) else: - from src.model import RWKV - model = RWKV(args) + if args.dropout > 0: + from src.model_drop2 import RWKV + model = RWKV(args) + else: + from src.model import RWKV + model = RWKV(args) if len(args.load_model) == 0 or args.my_pile_stage == 1: # shall we build the initial weights? init_weight_name = f"{args.proj_dir}/rwkv-init.pth"