Skip to content

Commit

Permalink
misc improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
www committed Jan 6, 2023
1 parent a39a175 commit 23f64ae
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 13 deletions.
22 changes: 11 additions & 11 deletions RWKV-v4neo/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
########################################################################################################

args.RUN_DEVICE = "cpu" # 'cpu' (already very fast) // 'cuda'
args.FLOAT_MODE = "fp32" # fp32 (good for cpu) // fp16 (might overflow) // bf16 (less accurate)
args.FLOAT_MODE = "fp32" # fp32 (good for CPU) // fp16 (recommended for GPU) // bf16 (less accurate)
# if args.RUN_DEVICE == "cuda":
# os.environ["RWKV_RUN_BACKEND"] = 'nvfuser' # !!!BUGGY!!! wrong output

Expand Down Expand Up @@ -87,25 +87,25 @@

###### A good prompt for chatbot ######
# context = '''
# The following is a conversation between a highly knowledgeable and intelligent AI assistant, called RWKV, and a human user, called User. In the following interactions, User and RWKV will converse in natural language, and RWKV will do its best to answer Users questions. RWKV was built to be respectful, polite and inclusive. It knows a lot, and always tells the truth. The conversation begins.
# The following is a conversation between a highly knowledgeable and intelligent AI assistant called Bot, and a human user called User. In the following interactions, User and Bot converse in natural language, and Bot always answer User's questions. Bot is very smart, polite and humorous. Bot knows a lot, and always tells the truth. The conversation begins.

# User: OK RWKV, I’m going to start by quizzing you with a few warm-up questions. Who is currently the president of the USA?
# User: who is president of usa?

# RWKV: It’s Joe Biden; he was sworn in earlier this year.
# Bot: It’s Joe Biden; he was sworn in earlier this year.

# User: What year was the French Revolution?
# User: french revolution what year

# RWKV: It started in 1789, but it lasted 10 years until 1799.
# Bot: It started in 1789, but it lasted 10 years until 1799.

# User: Can you guess who I might want to marry?
# User: guess i marry who ?

# RWKV: Only if you tell me more about yourself - what are your interests?
# Bot: Only if you tell me more about yourself - what are your interests?

# User: Aha, I’m going to refrain from that for now. Now for a science question. What can you tell me about the Large Hadron Collider (LHC)?
# User: wat is lhc

# RWKV: It’s a large and very expensive piece of science equipment. If I understand correctly, it’s a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.
# Bot: It’s a large and very expensive piece of science equipment. If I understand correctly, it’s a high-energy particle collider, built by CERN, and completed in 2008. They used it to confirm the existence of the Higgs boson in 2012.

# User:'''
# User:''' # type your question here

NUM_TRIALS = 999
LENGTH_PER_TRIAL = 333
Expand Down
4 changes: 2 additions & 2 deletions RWKV-v4neo/src/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ def on_train_epoch_end(self, trainer, pl_module):
to_save_dict,
f"{args.proj_dir}/rwkv-{args.epoch_begin + trainer.current_epoch}.pth",
)
except:
pass
except Exception as e:
print('Error\n\n', e, '\n\n')
trainer.my_log.write(f"{args.epoch_begin + trainer.current_epoch} {trainer.my_epoch_loss:.6f} {math.exp(trainer.my_epoch_loss):.4f} {trainer.my_lr:.8f} {datetime.datetime.now()} {trainer.current_epoch}\n")
trainer.my_log.flush()

Expand Down
104 changes: 104 additions & 0 deletions RWKV-v4neo/verify.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################

# this is for verifying the results of different models and make sure they agree with each other

import os, sys, types
import numpy as np
import torch
np.set_printoptions(precision=4, suppress=True, linewidth=200)
try:
os.environ["CUDA_VISIBLE_DEVICES"] = sys.argv[1]
except:
pass
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = False
torch.backends.cuda.matmul.allow_tf32 = False

os.environ['RWKV_FLOAT_MODE'] = 'bf16' # bf16 or fp32
os.environ['RWKV_RUN_DEVICE'] = 'cuda' # currently model_train requires CUDA
RUN_DEVICE = os.environ['RWKV_RUN_DEVICE']

TOKEN_MODE = 'pile'

if TOKEN_MODE == 'pile':
WORD_NAME = ['20B_tokenizer.json', '20B_tokenizer.json']
MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-3b/RWKV-4-Pile-3B-20221003-6783'
n_layer = 32
n_embd = 2560
ctx_len = 1024
UNKNOWN_CHAR = None

from src.utils import TOKENIZER
tokenizer = TOKENIZER(WORD_NAME, UNKNOWN_CHAR=UNKNOWN_CHAR)
if TOKEN_MODE == 'pile':
tokenizer.vocab_size = 50277

########################################################################################################

os.environ["RWKV_JIT_ON"] = "1"
os.environ["RWKV_T_MAX"] = str(ctx_len)

from src.model_run import RWKV_RNN
from src.model import RWKV

args = types.SimpleNamespace()
args.vocab_size = tokenizer.vocab_size
args.ctx_len = ctx_len
args.n_embd = n_embd
args.n_layer = n_layer
args.head_qk = 0
args.pre_ffn = 0
args.grad_cp = 0
args.my_pos_emb = 0
model_train = RWKV(args).to(RUN_DEVICE)

if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
model_train = model_train.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
model_train = model_train.bfloat16()

print('loading ' + MODEL_NAME)
m2 = torch.load(MODEL_NAME + '.pth', map_location='cpu')
model_train.load_state_dict(m2)

if os.environ['RWKV_FLOAT_MODE'] == 'fp16':
model_train = model_train.half()
elif os.environ['RWKV_FLOAT_MODE'] == 'bf16':
model_train = model_train.bfloat16()

args.MODEL_NAME = MODEL_NAME
args.RUN_DEVICE = RUN_DEVICE
args.FLOAT_MODE = os.environ['RWKV_FLOAT_MODE']
model_rnn = RWKV_RNN(args)

########################################################################################################

print(f"\nVerifying {os.environ['RWKV_RUN_DEVICE']} {os.environ['RWKV_FLOAT_MODE']}")

# context = '\nIn a'
context = '\nIn a shocking finding, scientist discovered a herd of dragons living in a remote, previously unexplored valley, in Tibet. Even more surprising to the researchers was the fact that the dragons spoke perfect Chinese.'

if TOKEN_MODE == 'pile':
ctx = tokenizer.tokenizer.encode(context)
print(f'input len {len(ctx)} data {ctx}')

########################################################################################################

with torch.no_grad():
print('\nRWKV-train output')
out = model_train.forward(torch.tensor([ctx]).to(RUN_DEVICE))[0].detach().cpu().float().numpy()
print(out, '\n')

print('\nRWKV-RNN output')
state = None
out = None
src_len = len(ctx)
for i in range(src_len):
x = ctx[:i+1]
out, state = model_rnn.forward(x, state)
if i < 3 or i >= src_len - 3:
print(out.detach().cpu().numpy())
if i == 2:
print('...')

0 comments on commit 23f64ae

Please sign in to comment.