diff --git a/RWKV-v4neo/src/binidx.py b/RWKV-v4neo/src/binidx.py index 404a581c..f8365f3d 100644 --- a/RWKV-v4neo/src/binidx.py +++ b/RWKV-v4neo/src/binidx.py @@ -49,6 +49,58 @@ class MMapIndexedDataset(torch.utils.data.Dataset): class Index(object): _HDR_MAGIC = b"MMIDIDX\x00\x00" + @classmethod + def writer(cls, path, dtype): + class _Writer(object): + def __enter__(self): + self._file = open(path, "wb") + + # Write Magic string so we can check the file format then opening it again. + self._file.write(cls._HDR_MAGIC) + # Write version number + # Little endian unsigned 64 Bit integer + self._file.write(struct.pack(" 0: - assert self.data_size == 332115325534 and self.vocab_size == 50277 + # assert self.data_size == 332115325534 and self.vocab_size == 50277 self.samples_per_epoch = args.epoch_steps * args.real_bsz assert self.samples_per_epoch == 40320 print(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########") dataset_slot = self.data_size // args.ctx_len assert MaybeIsPrime(args.magic_prime) assert args.magic_prime % 3 == 2 - assert args.magic_prime / dataset_slot > 0.999999 and args.magic_prime / dataset_slot <= 1 + assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1 elif args.data_type == "numpy": self.data = np.load(args.data_file).astype("int") self.vocab_size = args.vocab_size diff --git a/RWKV-v4neo/src/trainer.py b/RWKV-v4neo/src/trainer.py index 4b09639d..a109724c 100644 --- a/RWKV-v4neo/src/trainer.py +++ b/RWKV-v4neo/src/trainer.py @@ -97,6 +97,14 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): if kt_s > 0: lll["kt/s"] = kt_s trainer.my_wandb.log(lll, step=int(real_step)) + if args.magic_prime > 0: + if int(real_step) == int(args.magic_prime // args.real_bsz) - 1: + to_save_dict = pl_module.state_dict() + torch.save( + to_save_dict, + f"{args.proj_dir}/rwkv-final.pth", + ) + def on_train_epoch_start(self, trainer, pl_module): args = self.args diff --git a/RWKV-v4neo/train.py b/RWKV-v4neo/train.py index 1adf6972..2b2d3e81 100644 --- a/RWKV-v4neo/train.py +++ b/RWKV-v4neo/train.py @@ -99,6 +99,7 @@ parser.add_argument("--my_att_shift", default=1, type=int) parser.add_argument("--my_pos_emb", default=0, type=int) parser.add_argument("--load_partial", default=0, type=int) + parser.add_argument("--magic_prime", default=0, type=int) parser = Trainer.add_argparse_args(parser) args = parser.parse_args() @@ -145,6 +146,7 @@ os.makedirs(args.proj_dir) if args.my_pile_stage > 0: + magic_prime_bak = args.magic_prime if args.ctx_len == 1024: args.magic_prime = 324331313 args.epoch_count = 8043 @@ -162,6 +164,9 @@ elif args.ctx_len == 4096: args.my_pile_shift = 768 + if magic_prime_bak > 0: + args.magic_prime = magic_prime_bak + args.epoch_steps = 40320 // args.real_bsz assert args.epoch_steps * args.real_bsz == 40320 if args.my_pile_stage == 2: