Skip to content

Commit

Permalink
Faster load for inference.
Browse files: browse the repository at this point in the history
  • Loading branch information
Blealtan committed Mar 4, 2023
1 parent 26b4a5a commit d04dcd7
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
4 changes: 2 additions & 2 deletions RWKV-v4neo/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
args.ctx_len = 1024

# Modify this to use LoRA models; lora_r = 0 will not use LoRA weights.
args.MODEL_LORA = '/home/blealtancao/rwkv-models/lora-full-1e-4/rwkv-30'
args.lora_r = 4
args.MODEL_LORA = '/home/blealtancao/rwkv-models/lora-full-1e-4/rwkv-33'
args.lora_r = 8
args.lora_alpha = 16

# args.MODEL_NAME = '/fsx/BlinkDL/HF-MODEL/rwkv-4-pile-7b/RWKV-4-Pile-7B-20221115-8047'
Expand Down
5 changes: 5 additions & 0 deletions RWKV-v4neo/src/model_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ def __init__(self, args):
assert lora_B in keys
print(f'merging {lora_A} and {lora_B} into {k}')
assert w[lora_B].shape[1] == w[lora_A].shape[0] == args.lora_r
# merging needs matmul, which is slow on cpu; work on gpu if possible
if args.RUN_DEVICE == 'cuda':
w[k] = w[k].cuda()
w[lora_A] = w[lora_A].cuda()
w[lora_B] = w[lora_B].cuda()
w[k] += w[lora_B] @ w[lora_A] * (args.lora_alpha / args.lora_r)
del w[lora_A]
del w[lora_B]
Expand Down

0 comments on commit d04dcd7

Please sign in to comment.