diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py index 609fdffd..13b50ab4 100644 --- a/RWKV-v4neo/src/trainer.py +++ b/RWKV-v4neo/src/trainer.py @@ -129,7 +129,7 @@ def on_train_epoch_start(self, trainer, pl_module): def on_train_epoch_end(self, trainer, pl_module): args = self.args if trainer.is_global_zero: # logging & save state_dict - if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1: + if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or (trainer.current_epoch == args.epoch_count - 1): if args.data_type == 'wds_img': raw_dict = pl_module.state_dict() to_save_dict = {} @@ -150,6 +150,8 @@ def on_train_epoch_end(self, trainer, pl_module): trainer.my_loss_sum = 0 trainer.my_loss_count = 0 + if (args.epoch_begin + trainer.current_epoch) >= args.my_exit: + exit(0) @rank_zero_only diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py index ef97fd9d..9bcf1761 100644 --- a/RWKV-v4neo/train.py +++ b/RWKV-v4neo/train.py @@ -107,6 +107,7 @@ parser.add_argument("--my_qa_mask", default=0, type=int) parser.add_argument("--my_random_steps", default=0, type=int) parser.add_argument("--my_testing", default='', type=str) + parser.add_argument("--my_exit", default=99999999, type=int) parser = Trainer.add_argparse_args(parser) args = parser.parse_args() @@ -204,11 +205,12 @@ for p in os.listdir(args.proj_dir): if p.startswith("rwkv") and p.endswith(".pth"): p = ((p.split("-"))[1].split("."))[0] - if p == "init": - p = -1 - else: - p = int(p) - list_p += [p] + if p != "final": + if p == "init": + p = -1 + else: + p = int(p) + list_p += [p] list_p.sort() max_p = list_p[-1] if len(list_p) > 1: