# eval_long_ppl.py (forked from mit-han-lab/streaming-llm)

import torch
from tqdm import tqdm
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from torch.nn import CrossEntropyLoss
from streaming_llm.kv_cache import StartRecentKVCache
from streaming_llm.utils import parse_args, load

device = "cuda"

args = parse_args()

data = load_dataset(args.dataset_name, args.task, split=args.split)
model, tokenizer = load(args.model_name_or_path)

nlls = []
loss_fn = CrossEntropyLoss(reduction="none")
past_key_values = None

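# If enabled, build the StreamingLLM eviction policy: keep `start_size`
# initial "attention sink" tokens plus a sliding window of the `recent_size`
# most recent tokens. Cached keys/values place the sequence dimension
# differently across architectures, hence the per-model k_seq_dim / v_seq_dim.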
if args.enable_start_recent_kv_cache:
    if "llama" in model.config.model_type:
        k_seq_dim = v_seq_dim = 2
    elif "mpt" in model.config.model_type:
        v_seq_dim = 2
        k_seq_dim = 3
    elif "pythia" in model.config.model_type:
        k_seq_dim = v_seq_dim = 2
    elif "falcon" in model.config.model_type:
        v_seq_dim = 1
        k_seq_dim = 1
    else:
        raise ValueError(f"got {model.config.model_type}")
    kv_cache = StartRecentKVCache(
        start_size=args.start_size,
        recent_size=args.recent_size,
        k_seq_dim=k_seq_dim,
        v_seq_dim=v_seq_dim,
    )
else:
    kv_cache = None
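
# Illustrative sketch only (never called below, and not the actual
# StartRecentKVCache implementation from streaming_llm.kv_cache): a minimal
# version of the eviction policy configured above, assuming tuple-style
# (key, value) caches and a shared sequence dimension for keys and values
# (the real class takes separate k_seq_dim / v_seq_dim).
def _sketch_start_recent_eviction(past_key_values, start_size, recent_size, seq_dim=2):
    def keep(t, seq_len):
        # First `start_size` positions (attention sinks) + last `recent_size`.
        return torch.cat(
            [
                t.narrow(seq_dim, 0, start_size),
                t.narrow(seq_dim, seq_len - recent_size, recent_size),
            ],
            dim=seq_dim,
        )

    trimmed = []
    for k, v in past_key_values:
        seq_len = k.size(seq_dim)
        if seq_len <= start_size + recent_size:
            trimmed.append((k, v))  # still fits in the budget; nothing to evict
        else:
            trimmed.append((keep(k, seq_len), keep(v, seq_len)))
    return trimmed
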
# Optionally patch the attention modules so rotary position indices are
# assigned by position within the current KV cache rather than by absolute
# token index, keeping positions in-range after middle tokens are evicted.
# MPT uses ALiBi-style relative biases, so it needs no patch.
if args.enable_pos_shift:
    if "llama" in model.config.model_type:
        from streaming_llm.pos_shift.modify_llama import (
            enable_llama_pos_shift_attention,
        )

        enable_llama_pos_shift_attention(model)
    elif "falcon" in model.config.model_type:
        from streaming_llm.pos_shift.modify_falcon import (
            enable_falcon_pos_shift_attention,
        )

        enable_falcon_pos_shift_attention(model)
    elif "gpt_neox" in model.config.model_type:
        from streaming_llm.pos_shift.modify_gpt_neox import (
            enable_gpt_neox_pos_shift_attention,
        )

        enable_gpt_neox_pos_shift_attention(model)
    elif "mpt" in model.config.model_type:
        pass
    else:
        raise ValueError(f"got {model.config.model_type}")

os.makedirs(args.output_dir, exist_ok=True)
f = open(f"{args.output_dir}/log.txt", "w")
num_eval_tokens = 0
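# Streaming evaluation loop: feed the text one token at a time, score the
# model's prediction of the following token (per-token negative log-likelihood),
# and let the eviction policy truncate the KV cache after every step. Each
# per-token NLL is also appended to log.txt.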
for text in data["text"][: args.num_samples]:
    encodings = tokenizer(text, return_tensors="pt")
    print(encodings.input_ids[:, :10])  # sanity check: first few token ids

    seq_len = encodings.input_ids.size(1)
    print(f"seq_len: {seq_len}")
    pbar = tqdm(range(0, seq_len - 1))

    for idx in pbar:
        # Feed exactly one token and ask the model to predict the next one.
        input_ids = encodings.input_ids[:, idx : idx + 1].to(device)
        with torch.no_grad():
            outputs = model(
                input_ids,
                past_key_values=past_key_values,
                use_cache=True,
            )
            logits = outputs.logits.view(-1, model.config.vocab_size)
            past_key_values = outputs.past_key_values
            label = encodings.input_ids[:, idx + 1 : idx + 2].to(logits.device).view(-1)
            neg_log_likelihood = loss_fn(logits, label)
            if kv_cache is not None:
                # Evict middle tokens: keep attention sinks + recent window.
                past_key_values = kv_cache(past_key_values)
        nlls.append(neg_log_likelihood)
        pbar.set_description(
            f"nll: {neg_log_likelihood.item():.2f}, ppl: {torch.exp(neg_log_likelihood).item():.2f}"
        )
        print(neg_log_likelihood.item(), file=f, flush=True)
        num_eval_tokens += 1
        if args.num_eval_tokens is not None and num_eval_tokens >= args.num_eval_tokens:
            break
    if args.num_eval_tokens is not None and num_eval_tokens >= args.num_eval_tokens:
        break

f.close()
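# Perplexity over all evaluated tokens = exp(mean per-token NLL).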
ppl = torch.exp(torch.stack(nlls).mean())
print(ppl.item())
with open(f"{args.output_dir}/ppl.txt", "w") as f:
f.write(f"{ppl.item()}\n")