Skip to content

Commit

Permalink
misc
Browse files Browse the repository at this point in the history
  • Loading branch information
www committed Apr 27, 2023
1 parent 12391fe commit 2f57660
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
4 changes: 3 additions & 1 deletion RWKV-v4neo/src/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def on_train_epoch_start(self, trainer, pl_module):
def on_train_epoch_end(self, trainer, pl_module):
args = self.args
if trainer.is_global_zero: # logging & save state_dict
if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or trainer.current_epoch == args.epoch_count - 1:
if (args.epoch_save > 0 and trainer.current_epoch % args.epoch_save == 0) or (trainer.current_epoch == args.epoch_count - 1):
if args.data_type == 'wds_img':
raw_dict = pl_module.state_dict()
to_save_dict = {}
Expand All @@ -150,6 +150,8 @@ def on_train_epoch_end(self, trainer, pl_module):

trainer.my_loss_sum = 0
trainer.my_loss_count = 0
if (args.epoch_begin + trainer.current_epoch) >= args.my_exit:
exit(0)


@rank_zero_only
Expand Down
12 changes: 7 additions & 5 deletions RWKV-v4neo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
parser.add_argument("--my_qa_mask", default=0, type=int)
parser.add_argument("--my_random_steps", default=0, type=int)
parser.add_argument("--my_testing", default='', type=str)
parser.add_argument("--my_exit", default=99999999, type=int)

parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
Expand Down Expand Up @@ -204,11 +205,12 @@
for p in os.listdir(args.proj_dir):
if p.startswith("rwkv") and p.endswith(".pth"):
p = ((p.split("-"))[1].split("."))[0]
if p == "init":
p = -1
else:
p = int(p)
list_p += [p]
if p != "final":
if p == "init":
p = -1
else:
p = int(p)
list_p += [p]
list_p.sort()
max_p = list_p[-1]
if len(list_p) > 1:
Expand Down

0 comments on commit 2f57660

Please sign in to comment.