
Commit

update
Eric-mingjie committed Oct 7, 2023
1 parent 069ac97 commit 4651156
Showing 6 changed files with 98 additions and 42 deletions.
20 changes: 11 additions & 9 deletions README.md
@@ -16,6 +16,7 @@ Compared to magnitude pruning which removes weights solely based on their magnit
## Update
- [x] (9.22.2023) Add [support](https://github.com/locuslab/wanda#pruning-llama-2) for LLaMA-2.
- [x] (9.22.2023) Add [code](https://github.com/locuslab/wanda#ablation-on-obs-weight-update) to reproduce the ablation study on OBS weight update in the paper.
- [x] (10.6.2023) Add new [support](https://github.com/locuslab/wanda#ablation-on-obs-weight-update) for the weight update analysis in the ablation study. Feel free to try it out!

## Setup
Installation instructions can be found in [INSTALL.md](INSTALL.md).
@@ -61,7 +62,7 @@ python main.py \
--sparsity_type unstructured \
--save out/llama2_7b/unstructured/wanda/
```
LLaMA-2 results: (LLaMA-2-30b is not released as of 9.22.2023)
LLaMA-2 results: (LLaMA-2-34b is not released as of 9.22.2023)
|sparsity| ppl | llama2-7b | llama2-13b | llama2-70b |
|------|------------------|----------|------------|------------|
|-| dense | 5.12 | 4.57 | 3.12 |
@@ -76,18 +77,19 @@ LLaMA-2 results: (LLaMA-2-30b is not released as of 9.22.2023)
|2:4| wanda | 11.02 | **8.27** | **5.16** |
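If the pruned model is also exported with `--save_model`, a quick sanity check is to reload it and measure the fraction of zeroed weights. The snippet below is a minimal sketch, not part of this repository; the checkpoint path is hypothetical and stands for whatever directory was passed to `--save_model`.

```python
import torch
from transformers import AutoModelForCausalLM

# Hypothetical path: the directory passed to --save_model when pruning.
model = AutoModelForCausalLM.from_pretrained(
    "out/llama2_7b/unstructured/wanda/model", torch_dtype=torch.float16
)

zeros, total = 0, 0
for name, param in model.named_parameters():
    # Only 2-D weights inside the transformer blocks are pruned.
    if "layers" in name and name.endswith("weight") and param.dim() == 2:
        zeros += (param == 0).sum().item()
        total += param.numel()
print(f"overall sparsity: {zeros / total:.4f}")  # ~0.50 for --sparsity_ratio 0.5
```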

### Ablation on OBS weight update
To reproduce the results in Table 6, run the following commands:
To reproduce the weight update analysis, we provide our implementation of this ablation. All commands can be found in [this script](scripts/ablate_weight_update.sh).
```sh
for method in ablate_magnitude ablate_wanda
for method in ablate_mag_seq ablate_wanda_seq ablate_mag_iter ablate_wanda_iter
do
python main.py \
--model decapoda-research/llama-7b-hf \
--prune_method ${method} \
--sparsity_ratio 0.5 \
--sparsity_type unstructured \
--save out/llama_7b/ablate/
CUDA_VISIBLE_DEVICES=0 python main.py \
--model decapoda-research/llama-7b-hf \
--sparsity_ratio 0.5 \
--sparsity_type unstructured \
--prune_method ${method} \
--save out/llama_7b_ablation/unstructured/
done
```
Here `ablate_{mag/wanda}_{seq/iter}` means that we use magnitude pruning or wanda to obtain the pruned mask at each layer, then apply the weight update procedure in either a sequential or an iterative manner, every 128 input channels. For details, please see Section 5 of our [paper](https://arxiv.org/abs/2306.11695).
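As a rough sketch of the difference between the two styles (a simplified illustration, not the code in `lib/ablate.py`; `score_fn` stands for either the magnitude or the wanda score, and the OBS error compensation itself is elided):

```python
import torch

def ablation_weight_update(W, score_fn, sparsity=0.5, blocksize=128, style="seq"):
    """Sketch of the two ablation styles: 'seq' fixes the pruning mask up front,
    'iter' recomputes it inside every block of `blocksize` input channels."""
    W = W.clone()
    if style == "seq":
        scores = score_fn(W)
        thresh = torch.sort(scores.flatten())[0][int(scores.numel() * sparsity)]
        fixed_mask = scores <= thresh
    for i1 in range(0, W.shape[1], blocksize):
        i2 = min(i1 + blocksize, W.shape[1])
        if style == "seq":
            block_mask = fixed_mask[:, i1:i2]
        else:  # "iter": rescore the block using the weights updated so far
            blk = score_fn(W[:, i1:i2])
            thresh = torch.sort(blk.flatten())[0][int(blk.numel() * sparsity)]
            block_mask = blk <= thresh
        W[:, i1:i2][block_mask] = 0
        # The real implementation follows each block with the OBS-style weight
        # update (see lib/ablate.py), which is omitted in this sketch.
    return W
```

In this commit, the `seq` variants precompute the mask via `get_wanda_mask`/`get_mag_mask` and pass it into `fasterprune`, while the `iter` variants pass `mask=None` so that `fasterprune` forms the mask block by block (see `lib/prune.py` and `lib/ablate.py` below).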

For pruning image classifiers, see directory [image_classifiers](image_classifiers) for details.

47 changes: 33 additions & 14 deletions lib/ablate.py
@@ -42,25 +42,38 @@ def add_batch(self, inp, out):
self.H += inp.matmul(inp.t())
self.scaler_row += torch.norm(inp, p=2, dim=1) ** 2 / self.nsamples
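For context, a toy illustration of the statistics accumulated above (not repository code; shapes are made up): `H` is the proxy Hessian used by the OBS update, and `scaler_row` holds the per-input-channel activation norms that enter the wanda score used in `get_wanda_mask` below.

```python
import torch

W = torch.randn(4, 8)      # layer weight: (out_features, in_features)
X = torch.randn(8, 512)    # calibration activations: (in_features, n_tokens)

H = X @ X.t()                           # proxy Hessian accumulated by add_batch
scaler_row = X.norm(p=2, dim=1) ** 2    # per-channel squared activation norm
                                        # (the code keeps a running average over samples)

# The wanda score used below: |W| scaled by the activation norm per input channel
wanda_score = W.abs() * torch.sqrt(scaler_row).reshape(1, -1)
```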

def get_wanda_mask(self, sparsity):
def get_wanda_mask(self, sparsity, prunen, prunem):
W_metric = torch.abs(self.layer.weight.data) * torch.sqrt(self.scaler_row.reshape((1,-1)))
W_mask = (torch.zeros_like(W_metric) == 1) ## initialize a mask to be all False
sort_res = torch.sort(W_metric, dim=-1, stable=True)
indices = sort_res[1][:,:int(W_metric.shape[1]*sparsity)]
W_mask.scatter_(1, indices, True)
if prunen != 0:
for ii in range(W_metric.shape[1]):
if ii % prunem == 0:
tmp = W_metric[:,ii:(ii+prunem)].float()
W_mask.scatter_(1,ii+torch.topk(tmp, prunen,dim=1, largest=False)[1], True)
else:
sort_res = torch.sort(W_metric, dim=-1, stable=True)
indices = sort_res[1][:,:int(W_metric.shape[1]*sparsity)]
W_mask.scatter_(1, indices, True)

return W_mask

def get_mag_mask(self, sparsity):
def get_mag_mask(self, sparsity, prunen, prunem):
W = self.layer.weight.data
W_metric = torch.abs(W)
thresh = torch.sort(W_metric.flatten().cuda())[0][int(W.numel()*sparsity)].cpu()
W_mask = (W_metric<=thresh)
if prunen != 0:
W_mask = (torch.zeros_like(W)==1)
for ii in range(W_metric.shape[1]):
if ii % prunem == 0:
tmp = W_metric[:,ii:(ii+prunem)].float()
W_mask.scatter_(1,ii+torch.topk(tmp, prunen,dim=1, largest=False)[1], True)
else:
thresh = torch.sort(W_metric.flatten().cuda())[0][int(W.numel()*sparsity)].cpu()
W_mask = (W_metric<=thresh)

return W_mask
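A toy example of the N:M branch added above (illustrative names, not repository code): within every group of `prune_m` consecutive input channels, the `prune_n` weights with the smallest score are marked for removal.

```python
import torch

prune_n, prune_m = 2, 4
score = torch.rand(3, 8)                      # e.g. |W| or the wanda score
mask = torch.zeros_like(score, dtype=torch.bool)
for ii in range(0, score.shape[1], prune_m):
    group = score[:, ii:ii + prune_m]
    # prune the prune_n smallest entries within this group of prune_m channels
    mask.scatter_(1, ii + torch.topk(group, prune_n, dim=1, largest=False)[1], True)

# every row now has exactly prune_n pruned entries per group of prune_m channels
assert (mask.sum(dim=1) == score.shape[1] // prune_m * prune_n).all()
```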

def fasterprune(
self, sparsity, mask=None, prune_n=0, prune_m=0, blocksize=128, percdamp=.01
self, args, sparsity, mask=None, prune_n=0, prune_m=0, blocksize=128, percdamp=.01
):
W = self.layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
@@ -87,8 +100,6 @@ def fasterprune(
H = torch.linalg.cholesky(H, upper=True)
Hinv = H

# mask = None

for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
count = i2 - i1
@@ -99,11 +110,15 @@ def fasterprune(
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]

if prune_n == 0:
if prune_n == 0 or mask is not None:
if mask is not None:
mask1 = mask[:, i1:i2]
else:
tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
# tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
if "wanda" in args.prune_method:
tmp = torch.abs(W1) * torch.sqrt(self.scaler_row[i1:i2].reshape((1,-1)))
elif "mag" in args.prune_method:
tmp = torch.abs(W1)
thresh = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
mask1 = tmp <= thresh
else:
@@ -113,8 +128,12 @@ def fasterprune(
w = W1[:, i]
d = Hinv1[i, i]

if prune_n != 0 and i % prune_m == 0:
tmp = W1[:, i:(i + prune_m)] ** 2 / (torch.diag(Hinv1)[i:(i + prune_m)].reshape((1, -1))) ** 2
if prune_n != 0 and i % prune_m == 0 and mask is None:
# tmp = W1[:, i:(i + prune_m)] ** 2 / (torch.diag(Hinv1)[i:(i + prune_m)].reshape((1, -1))) ** 2
if "wanda" in args.prune_method:
tmp = torch.abs(W1[:, i:(i+prune_m)]) * torch.sqrt(self.scaler_row[(i+i1):(i+i1+prune_m)].reshape((1,-1)))
elif "mag" in args.prune_method:
tmp = torch.abs(W1[:, i:(i+prune_m)])
mask1.scatter_(1, i + torch.topk(tmp, prune_n, dim=1, largest=False)[1], True)

q = w.clone()
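For readers unfamiliar with what the weight update in `fasterprune` does once a mask is chosen, here is a condensed sketch of the column-by-column OBS-style compensation from the SparseGPT reference implementation that this file adapts. It is an illustration under simplified assumptions, not the exact code above.

```python
import torch

def obs_block_update(W1, Hinv1, mask1):
    """Zero the masked weights in a block and spread the induced error onto the
    remaining (later) columns using the inverse-Hessian Cholesky factor Hinv1."""
    Err1 = torch.zeros_like(W1)
    for i in range(W1.shape[1]):
        w = W1[:, i].clone()
        d = Hinv1[i, i]
        q = w.clone()
        q[mask1[:, i]] = 0                    # prune the selected weights in this column
        err = (w - q) / d                     # per-row error introduced by pruning
        W1[:, i:] -= err.unsqueeze(1) @ Hinv1[i, i:].unsqueeze(0)  # compensate later columns
        W1[:, i] = q
        Err1[:, i] = err
    return W1, Err1                           # Err1 is propagated to subsequent blocks
```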
7 changes: 3 additions & 4 deletions lib/eval.py
@@ -7,23 +7,22 @@
from .data import get_loaders

# Function to evaluate perplexity (ppl) on a specified model and tokenizer
def eval_ppl(model, tokenizer, device=torch.device("cuda:0")):
def eval_ppl(args, model, tokenizer, device=torch.device("cuda:0")):
# Set dataset
dataset = "wikitext2"

# Print status
print(f"evaluating on {dataset}")

# Get the test loader
trainloader, testloader = get_loaders(
_, testloader = get_loaders(
dataset, seed=0, seqlen=model.seqlen, tokenizer=tokenizer
)

# Evaluate ppl in no grad context to avoid updating the model
with torch.no_grad():
ppl_test = eval_ppl_wikitext(model, testloader, 1, device)
ppl_train = eval_ppl_wikitext_train(model, trainloader, 1, device)
return ppl_train, ppl_test
return ppl_test
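`eval_ppl` delegates the actual computation to `eval_ppl_wikitext` below. Conceptually this is the standard sliding-window perplexity; a condensed sketch (not the repository code, and the chunking details may differ) is:

```python
import torch
import torch.nn.functional as F

def perplexity_sketch(model, input_ids, seqlen, device):
    """Split the tokenized test set into seqlen-long windows, accumulate the
    cross-entropy over all windows, and exponentiate the per-token average."""
    nsamples = input_ids.numel() // seqlen
    nlls = []
    for i in range(nsamples):
        batch = input_ids[:, i * seqlen:(i + 1) * seqlen].to(device)
        with torch.no_grad():
            logits = model(batch).logits
        shift_logits = logits[:, :-1, :]
        shift_labels = batch[:, 1:]
        loss = F.cross_entropy(shift_logits.reshape(-1, shift_logits.size(-1)),
                               shift_labels.reshape(-1))
        nlls.append(loss.float() * seqlen)
    return torch.exp(torch.stack(nlls).sum() / (nsamples * seqlen))
```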

# Function to evaluate perplexity (ppl) specifically on the wikitext dataset
def eval_ppl_wikitext_train(model, trainloader, bs=1, device=None):
14 changes: 8 additions & 6 deletions lib/prune.py
@@ -305,7 +305,7 @@ def tmp(_, inp, out):
def prune_ablate(args, model, tokenizer, dev, prune_n=0, prune_m=0):
## SparseGPT code available at: https://github.com/IST-DASLab/sparsegpt/tree/f5c25005a61f96a0933ca2f95705a963585aafaa
print('Starting ...')
dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=2048,tokenizer=tokenizer)
dataloader, _ = get_loaders("c4",nsamples=args.nsamples,seed=args.seed,seqlen=model.seqlen,tokenizer=tokenizer)

use_cache = model.config.use_cache
model.config.use_cache = False
@@ -376,12 +376,14 @@ def tmp(_, inp, out):
print(i, name)
print('Pruning ...')

if args.prune_method == "ablate_wanda":
prune_mask = gpts[name].get_wanda_mask(args.sparsity_ratio)
elif args.prune_method == "ablate_magnitude":
prune_mask = gpts[name].get_mag_mask(args.sparsity_ratio)
if args.prune_method == "ablate_wanda_seq":
prune_mask = gpts[name].get_wanda_mask(args.sparsity_ratio, prune_n, prune_m)
elif args.prune_method == "ablate_mag_seq":
prune_mask = gpts[name].get_mag_mask(args.sparsity_ratio, prune_n, prune_m)
elif "iter" in args.prune_method:
prune_mask = None

gpts[name].fasterprune(args.sparsity_ratio, mask=prune_mask, prune_n=prune_n, prune_m=prune_m, percdamp=0.01, blocksize=128)
gpts[name].fasterprune(args, args.sparsity_ratio, mask=prune_mask, prune_n=prune_n, prune_m=prune_m, percdamp=0.01, blocksize=128)
gpts[name].free()

for j in range(args.nsamples):
20 changes: 11 additions & 9 deletions main.py
@@ -6,6 +6,7 @@
from importlib.metadata import version

from lib.prune import prune_wanda, prune_magnitude, prune_sparsegpt, prune_ablate, check_sparsity, find_layers
# from lib.search import prune_search
from lib.eval import eval_ppl

print('torch', version('torch'))
@@ -32,11 +33,16 @@ def main():
parser.add_argument('--nsamples', type=int, default=128, help='Number of calibration samples.')
parser.add_argument('--sparsity_ratio', type=float, default=0, help='Sparsity level')
parser.add_argument("--sparsity_type", type=str, choices=["unstructured", "4:8", "2:4"])
parser.add_argument("--prune_method", type=str, choices=["magnitude", "wanda", "sparsegpt", "ablate_magnitude", "ablate_wanda"])
parser.add_argument("--prune_method", type=str, choices=["magnitude", "wanda", "sparsegpt",
"ablate_mag_seq", "ablate_wanda_seq", "ablate_mag_iter", "ablate_wanda_iter", "search"])
parser.add_argument("--cache_dir", default="llm_weights", type=str )
parser.add_argument('--use_variant', action="store_true", help="whether to use the wanda variant described in the appendix")
parser.add_argument('--save', type=str, default=None, help='Path to save results.')
parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')

parser.add_argument('--prune_metric', type=str, help='Pruning metric')
parser.add_argument('--prune_granularity', type=str, help='Pruning granularity')
parser.add_argument('--blocksize', type=int, default=1, help='Block size')
args = parser.parse_args()

# Setting seeds for reproducibility
@@ -77,19 +83,15 @@ def main():
print(f"sparsity sanity check {sparsity_ratio:.4f}")
print("*"*30)
################################################################
ppl_train, ppl_test = eval_ppl(model, tokenizer, device)
print(f"ppl on wikitext_train {ppl_train}, wikitext_test {ppl_test}")
ppl_test = eval_ppl(args, model, tokenizer, device)
print(f"wikitext perplexity {ppl_test}")

if not os.path.exists(args.save):
os.makedirs(args.save)
save_filepath = os.path.join(args.save, f"log_{args.prune_method}.txt")
with open(save_filepath, "w") as f:
if "ablate" in args.prune_method:
print("method\tactual_sparsity\tppl_train\tppl_test", file=f, flush=True)
print(f"{args.prune_method}\t{sparsity_ratio:.4f}\t{ppl_train:.4f}\t{ppl_test:.4f}", file=f, flush=True)
else:
print("method\tactual_sparsity\tppl_test", file=f, flush=True)
print(f"{args.prune_method}\t{sparsity_ratio:.4f}\t{ppl_test:.4f}", file=f, flush=True)
print("method\tactual_sparsity\tppl_test", file=f, flush=True)
print(f"{args.prune_method}\t{sparsity_ratio:.4f}\t{ppl_test:.4f}", file=f, flush=True)

if args.save_model:
model.save_pretrained(args.save_model)
32 changes: 32 additions & 0 deletions scripts/ablate_weight_update.sh
@@ -0,0 +1,32 @@
for method in ablate_mag_seq ablate_wanda_seq ablate_mag_iter ablate_wanda_iter
do
CUDA_VISIBLE_DEVICES=0 python main.py \
--model decapoda-research/llama-7b-hf \
--nsamples 128 \
--sparsity_ratio 0.5 \
--sparsity_type unstructured \
--prune_method ${method} \
--save out/llama_7b_ablation/unstructured/
done

for method in ablate_mag_seq ablate_wanda_seq ablate_mag_iter ablate_wanda_iter
do
CUDA_VISIBLE_DEVICES=0 python main.py \
--model decapoda-research/llama-7b-hf \
--nsamples 128 \
--sparsity_ratio 0.5 \
--sparsity_type 4:8 \
--prune_method ${method} \
--save out/llama_7b_ablation/4:8/
done

for method in ablate_mag_seq ablate_wanda_seq ablate_mag_iter ablate_wanda_iter
do
CUDA_VISIBLE_DEVICES=0 python main.py \
--model decapoda-research/llama-7b-hf \
--nsamples 128 \
--sparsity_ratio 0.5 \
--sparsity_type 2:4 \
--prune_method ${method} \
--save out/llama_7b_ablation/2:4/
done
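Each run writes a tab-separated `log_{method}.txt` into the corresponding `--save` directory (see the logging block in `main.py` above). A small sketch for collecting the results afterwards; the glob pattern simply mirrors the save paths used in this script:

```python
import csv
import glob
import os

# Gather the per-method logs written by main.py under the save directories above.
for path in sorted(glob.glob("out/llama_7b_ablation/*/log_*.txt")):
    with open(path) as f:
        for row in csv.DictReader(f, delimiter="\t"):
            print(os.path.dirname(path), row["method"],
                  row["actual_sparsity"], row["ppl_test"])
```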
