Util functions for rank-reduce experiment
turboderp committed Jan 6, 2024
1 parent fc1629d · commit 0240801
Showing 4 changed files with 58 additions and 4 deletions.
15 changes: 14 additions & 1 deletion exllamav2/linear.py
@@ -200,4 +200,17 @@ def get_weight_tensor_dq(self):

     def is_quant(self):

-        return self.q_handle is not None
+        return self.q_handle is not None
+
+
+    def rank_reduce(self, k):
+
+        assert not self.is_quant(), "Can't rank-reduce quantized layer"
+
+        weight = self.linear.weight.data.float()
+        max_rank = min(weight.shape[0], weight.shape[1])
+        desired_rank = int(max_rank * k)
+        results = torch.svd_lowrank(weight, q = desired_rank, niter = 10)
+        weight_approx = results[0] @ torch.diag(results[1]) @ results[2].T
+
+        self.linear.weight = nn.Parameter(weight_approx.half())
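For context: torch.svd_lowrank returns (U, S, V) with weight ≈ U @ diag(S) @ V.T, so weight_approx above keeps the layer's original shape and FP16 footprint; only its rank (information content) is reduced, not its memory use. A minimal standalone sketch of the same truncation on a toy matrix (shapes and k are illustrative, not from the commit):

import torch

torch.manual_seed(0)
weight = torch.randn(512, 2048)              # stand-in for a linear layer's weight
k = 0.25                                     # keep 25% of the maximum possible rank
max_rank = min(weight.shape)                 # rank is bounded by the smaller dimension
desired_rank = int(max_rank * k)             # 128 here

# torch.svd_lowrank returns (U, S, V) such that weight ~= U @ diag(S) @ V.T
U, S, V = torch.svd_lowrank(weight, q = desired_rank, niter = 10)
weight_approx = U @ torch.diag(S) @ V.T      # same shape as weight, rank <= desired_rank

rel_err = torch.linalg.norm(weight - weight_approx) / torch.linalg.norm(weight)
print(f"relative Frobenius error: {rel_err:.4f}")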
8 changes: 8 additions & 0 deletions exllamav2/mlp.py
@@ -233,3 +233,11 @@ def is_quant(self):
         return self.q_handle is not None


+    def rank_reduce(self, k):
+
+        self.gate_proj.rank_reduce(k)
+        self.up_proj.rank_reduce(k)
+        self.down_proj.rank_reduce(k)
+
+
+
7 changes: 7 additions & 0 deletions exllamav2/moe_mlp.py
@@ -280,3 +280,10 @@ def update_loras(self):
     def is_quant(self):
         return self.q_handle is not None

+
+    def rank_reduce(self, k):
+
+        for e in range(self.num_experts):
+            self.w1[e].rank_reduce(k)
+            self.w2[e].rank_reduce(k)
+            self.w3[e].rank_reduce(k)
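The MLP and MoE variants simply fan the call out to their projections, so a single rank_reduce(k) on a block touches every constituent linear layer (3 × num_experts matrices in the MoE case). If you would rather pick k from the singular-value spectrum than guess a fraction, here is a hedged sketch (rank_for_energy is a hypothetical helper, not part of the commit):

import torch

def rank_for_energy(weight: torch.Tensor, energy: float = 0.95) -> int:
    # Singular values, returned in descending order
    S = torch.linalg.svdvals(weight.float())
    # Cumulative fraction of squared singular-value "energy"
    cum = torch.cumsum(S ** 2, dim = 0) / torch.sum(S ** 2)
    # Smallest rank whose retained energy reaches the threshold
    return int(torch.searchsorted(cum, energy).item()) + 1

# e.g. reduce a layer to whatever fraction keeps 95% of the energy:
# layer.rank_reduce(rank_for_energy(layer.linear.weight.data) / min(layer.linear.weight.shape))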
32 changes: 29 additions & 3 deletions test_inference.py
@@ -13,9 +13,9 @@
     ExLlamaV2Sampler
 )

-from exllamav2.attn import (
-    ExLlamaV2Attention
-)
+from exllamav2.attn import ExLlamaV2Attention
+from exllamav2.mlp import ExLlamaV2MLP
+from exllamav2.moe_mlp import ExLlamaV2MoEMLP

 import argparse, os, math, time
 import pandas, fastparquet
@@ -50,6 +50,7 @@
 parser.add_argument("-nwu", "--no_warmup", action = "store_true", help = "Skip warmup before testing model")
 parser.add_argument("-sl", "--stream_layers", action = "store_true", help = "Load model layer by layer (perplexity evaluation only)")
 parser.add_argument("-sp", "--standard_perplexity", choices = ["wiki2"], help = "Run standard (HF) perplexity test, stride 512 (experimental)")
+parser.add_argument("-rr", "--rank_reduce", type = str, help = "Rank-reduction for MLP layers of model, in reverse order (for experimentation)")

 # Initialize model and tokenizer

@@ -100,6 +101,31 @@
     model.config.max_batch_size = stream_batch_size
     model.load(lazy = True)

+# Rank reduction
+
+if args.rank_reduce:
+
+    if args.stream_layers:
+        print(" ## --rank_reduce can not be combined with --stream_layers")
+        sys.exit()
+
+    rr = args.rank_reduce.split(",")
+    idx = len(model.modules) - 1
+    for r in rr:
+        k = float(r)
+
+        while True:
+            idx -= 1
+            module = model.modules[idx]
+            if isinstance(module, ExLlamaV2MLP): break
+            if isinstance(module, ExLlamaV2MoEMLP): break
+            if idx < 0:
+                print(" ## Not enough layers")
+                sys.exit()
+
+        print(f" -- Reducing {module.key} ({module.name}) to {k * 100:.2f}%")
+        module.rank_reduce(k)
+
 # Replacement

 if args.mix_layers:
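With these pieces in place, a run might look like the following (paths and factors are placeholders, assuming the script's usual -m model-directory and -ed perplexity-dataset flags):

python test_inference.py -m /path/to/model -ed /path/to/wikitext.parquet -rr 0.9,0.8,0.7

The factors are consumed in order while walking backwards from the last module, so 0.9 applies to the last MLP (or MoE MLP) block, 0.8 to the second-to-last, and so on; any earlier blocks are left untouched. Note that ExLlamaV2Linear.rank_reduce asserts the layer is not quantized, so this path applies to FP16 models only.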
