Commit cca1b5e: misc
www committed Jun 18, 2023 (1 parent: 50c6855)
Showing 3 changed files with 41 additions and 9 deletions.
4 changes: 2 additions & 2 deletions RWKV-v4neo/src/dataset.py
@@ -57,13 +57,13 @@ def __init__(self, args):
         elif args.data_type == "numpy":
             self.data = np.load(args.data_file).astype("int")
             self.vocab_size = args.vocab_size
-            rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
+            rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
             self.data_size = len(self.data)
             rank_zero_info(f"Data has {self.data_size} tokens.")
         elif args.data_type == "uint16":
             self.data = np.fromfile(args.data_file, dtype=np.uint16).astype("int32").reshape(-1, args.my_sample_len)
             self.vocab_size = args.vocab_size
-            rank_zero_info("Current vocab size =", self.vocab_size, "(make sure it's correct)")
+            rank_zero_info(f"Current vocab size = {self.vocab_size} (make sure it's correct)")
             self.data_size = self.data.shape[0]
             rank_zero_info(f"Data has {self.data_size} samples.")
         elif args.data_type == "wds_img":
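The dataset.py change is a correctness fix, not just style: rank_zero_info forwards its arguments to Python's logging machinery rather than to print, so the extra positional arguments are interpreted as %-format parameters and, with no placeholder in the message, the handler reports a string-formatting error instead of the vocab size. A minimal stdlib-only sketch of the failure mode (a plain logger here, not the pytorch_lightning helper):

    import logging

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger("demo")

    vocab_size = 50277
    # print-style call: the extra args become %-format parameters, but the
    # message has no placeholders, so the handler emits a logging error
    # ("not all arguments converted during string formatting") to stderr.
    log.info("Current vocab size =", vocab_size, "(make sure it's correct)")
    # f-string call: the message is fully rendered before it reaches the logger.
    log.info(f"Current vocab size = {vocab_size} (make sure it's correct)")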
18 changes: 18 additions & 0 deletions RWKV-v4neo/src/trainer.py
@@ -52,6 +52,24 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
         # if trainer.is_global_zero:
         #     print(trainer.global_step, decay_step, decay_total, w_step, progress, lr)

+        if args.my_exit_tokens > 0:  # cosine decay
+            if trainer.global_step < w_step:
+                lr = args.lr_init * (0.2 + 0.8 * trainer.global_step / w_step)
+            else:
+                real_tokens = real_step * args.ctx_len * args.real_bsz
+                warmup_tokens = w_step * args.ctx_len * args.real_bsz
+                progress = (real_tokens - warmup_tokens) / (args.my_exit_tokens - warmup_tokens)
+                progress = max(0, min(1, progress))
+                lr_final_factor = 0.1
+                lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / 2) * math.cos(math.pi * progress)
+                lr = args.lr_init * lr_mult
+                if progress >= 1:
+                    my_save(
+                        pl_module.state_dict(),
+                        f"{args.proj_dir}/rwkv-final.pth",
+                    )
+                    exit(0)
+
         for param_group in trainer.optimizers[0].param_groups:
             if args.layerwise_lr > 0:
                 param_group["lr"] = lr * param_group["my_lr_scale"]
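The added branch is a token-budget schedule: linear warmup from 0.2 * lr_init to lr_init over w_step steps, then cosine decay whose multiplier starts at (0.5 + 0.05) + (0.5 - 0.05) * cos(0) = 1.0 and bottoms out at lr_final_factor = 0.1 once my_exit_tokens tokens have been consumed, at which point the final checkpoint is written and the process exits (real_step is set earlier in on_train_batch_start, outside this hunk). A self-contained sketch of the same math, with illustrative argument names rather than the trainer's API:

    import math

    def warmup_cosine_lr(step, lr_init, warmup_steps, tokens_per_step,
                         exit_tokens, lr_final_factor=0.1):
        if step < warmup_steps:
            # Linear warmup from 0.2 * lr_init up to lr_init.
            return lr_init * (0.2 + 0.8 * step / warmup_steps)
        # Progress is measured in tokens, so the decay length follows the
        # token budget rather than a step count.
        warmup_tokens = warmup_steps * tokens_per_step
        progress = (step * tokens_per_step - warmup_tokens) / (exit_tokens - warmup_tokens)
        progress = max(0.0, min(1.0, progress))
        # Multiplier: 1.0 at progress 0, lr_final_factor at progress 1.
        lr_mult = (0.5 + lr_final_factor / 2) + (0.5 - lr_final_factor / 2) * math.cos(math.pi * progress)
        return lr_init * lr_mult

    # e.g. ctx_len 1024 * real_bsz 64 = 65536 tokens per optimizer step:
    print(warmup_cosine_lr(step=10000, lr_init=6e-4, warmup_steps=100,
                           tokens_per_step=65536, exit_tokens=3_000_000_000))

At progress 1 the cosine term flips sign, so the two halves collapse to exactly lr_final_factor; no separate lr_final bookkeeping is needed, which is why the lr_final assert in train.py is commented out below.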
28 changes: 21 additions & 7 deletions RWKV-v4neo/train.py
@@ -81,6 +81,7 @@
parser.add_argument("--beta2", default=0.99, type=float) # use 0.999 when your model is close to convergence
parser.add_argument("--adam_eps", default=1e-8, type=float)
parser.add_argument("--grad_cp", default=0, type=int) # gradient checkpt: saves VRAM, but slower
parser.add_argument("--dropout", default=0, type=float)

parser.add_argument("--my_pile_version", default=1, type=int) # my special pile version
parser.add_argument("--my_pile_stage", default=0, type=int) # my special pile mode
@@ -108,13 +109,14 @@
     parser.add_argument("--my_random_steps", default=0, type=int)
     parser.add_argument("--my_testing", default='', type=str)
     parser.add_argument("--my_exit", default=99999999, type=int)
+    parser.add_argument("--my_exit_tokens", default=-1, type=int)

     parser = Trainer.add_argparse_args(parser)
     args = parser.parse_args()

     ########################################################################################################

-    import os, warnings, math, datetime, sys, time, importlib
+    import os, warnings, math, datetime, sys, time
     import numpy as np
     import torch
     from torch.utils.data import DataLoader
@@ -186,12 +188,15 @@

         if magic_prime_bak > 0:
             args.magic_prime = magic_prime_bak
-            args.epoch_count = args.magic_prime // 40320
+            if args.my_qa_mask == 2:
+                args.epoch_count = 2 * args.magic_prime // 40320
+            else:
+                args.epoch_count = args.magic_prime // 40320

         args.epoch_steps = 40320 // args.real_bsz
         assert args.epoch_steps * args.real_bsz == 40320
-        if args.my_pile_stage == 2:
-            assert args.lr_final == args.lr_init
+        # if args.my_pile_stage == 2:
+        #     assert args.lr_final == args.lr_init
         if args.my_pile_stage >= 2:  # find latest saved model
             list_p = []
             for p in os.listdir(args.proj_dir):
@@ -220,6 +225,11 @@

     samples_per_epoch = args.epoch_steps * args.real_bsz
     tokens_per_epoch = samples_per_epoch * args.ctx_len
+    try:
+        deepspeed_version = deepspeed.__version__
+    except:
+        deepspeed_version = None
+        pass
     rank_zero_info(
         f"""
 ############################################################################
@@ -237,7 +247,7 @@
 # Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
 #
 # Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
-# Found deepspeed {deepspeed.__version__ if importlib.util.find_spec('deepspeed') else 'None'}, recommend 0.7.0 (faster than newer versions)
+# Found deepspeed {deepspeed_version}, recommend 0.7.0 (faster than newer versions)
 # Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer
 #
 ############################################################################
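Together these two hunks replace the importlib.util.find_spec lookup embedded in the banner f-string (and the importlib import it needed) with a deepspeed_version variable computed up front. The committed bare except: with a trailing pass works, since deepspeed may never have been imported at this point, but it is unidiomatic; a tighter probe that avoids importing the package at all, assuming Python 3.8+ for importlib.metadata, would be:

    from importlib import metadata

    def installed_version(package):
        # Version string of an installed distribution, or None if absent.
        try:
            return metadata.version(package)
        except metadata.PackageNotFoundError:
            return None

    print(installed_version("deepspeed"))  # e.g. "0.7.0", or None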
@@ -290,8 +300,12 @@
         from src.model_img import RWKV_IMG
         model = RWKV_IMG(args)
     else:
-        from src.model import RWKV
-        model = RWKV(args)
+        if args.dropout > 0:
+            from src.model_drop2 import RWKV
+            model = RWKV(args)
+        else:
+            from src.model import RWKV
+            model = RWKV(args)

     if len(args.load_model) == 0 or args.my_pile_stage == 1:  # shall we build the initial weights?
         init_weight_name = f"{args.proj_dir}/rwkv-init.pth"
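With the new flags wired up, --dropout > 0 swaps in the RWKV class from src.model_drop2, and --my_exit_tokens > 0 ends the run on a token budget via the trainer change above. An illustrative launch; every value below is a placeholder, not a recommended setting:

    python train.py --proj_dir out --data_file data.npy --data_type numpy \
        --vocab_size 50277 --ctx_len 1024 --micro_bsz 8 --n_layer 24 --n_embd 1024 \
        --lr_init 6e-4 --lr_final 6e-5 --dropout 0.1 --my_exit_tokens 3000000000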
