Skip to content

Commit

Permalink
better
Browse files Browse the repository at this point in the history
  • Loading branch information
www committed Mar 6, 2023
1 parent 8e99ac1 commit 6d4dec7
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
7 changes: 4 additions & 3 deletions RWKV-v4neo/src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ def __init__(self, args):
assert self.samples_per_epoch == 40320
rank_zero_info(f"########## Pile 20b-tokenized stage {args.my_pile_stage} ##########")
dataset_slot = self.data_size // args.ctx_len
assert MaybeIsPrime(args.magic_prime)
assert args.magic_prime % 3 == 2
assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
if args.my_pile_stage != 4:
assert MaybeIsPrime(args.magic_prime)
assert args.magic_prime % 3 == 2
assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
elif args.data_type == "numpy":
self.data = np.load(args.data_file).astype("int")
self.vocab_size = args.vocab_size
Expand Down
7 changes: 4 additions & 3 deletions RWKV-v4neo/src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

import os, math, gc
import os, math, gc, importlib
import torch
# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
Expand All @@ -11,8 +11,9 @@
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only
from pytorch_lightning.strategies import DeepSpeedStrategy
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
if importlib.util.find_spec('deepspeed'):
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam

# from deepspeed.runtime.fp16.onebit.zoadam import ZeroOneAdam

Expand Down
7 changes: 4 additions & 3 deletions RWKV-v4neo/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,12 @@

########################################################################################################

import os, warnings, math, datetime, sys, time
import os, warnings, math, datetime, sys, time, importlib
import numpy as np
import torch
from torch.utils.data import DataLoader
import deepspeed
if "deepspeed" in args.strategy:
import deepspeed
import pytorch_lightning as pl
from pytorch_lightning import seed_everything

Expand Down Expand Up @@ -223,7 +224,7 @@
# Adam = lr {args.lr_init} to {args.lr_final}, warmup {args.warmup_steps} steps, beta {args.betas}, eps {args.adam_eps}
#
# Found torch {torch.__version__}, recommend 1.13.1+cu117 or newer
# Found deepspeed {deepspeed.__version__}, recommend 0.7.0 (faster than newer versions)
# Found deepspeed {deepspeed.__version__ if importlib.util.find_spec('deepspeed') else 'None'}, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning {pl.__version__}, recommend 1.9.1 or newer
#
############################################################################
Expand Down

0 comments on commit 6d4dec7

Please sign in to comment.