forked from pytorch/ao
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_eval.py
364 lines (302 loc) · 10.7 KB
/
_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from torchao.quantization.utils import _lm_eval_available, _MultiInput
from torchao.quantization.GPTQ_MT import MultiTensor
import lm_eval
try: # lm_eval version 0.4
from lm_eval.evaluator import evaluate # pyre-ignore[21]
from lm_eval.models.huggingface import HFLM as eval_wrapper # pyre-ignore[21]
from lm_eval.tasks import get_task_dict # pyre-ignore[21]
except: # lm_eval version 0.3
from lm_eval import base, evaluator, tasks
eval_wrapper = base.BaseLM
get_task_dict = tasks.get_task_dict
evaluate = evaluator.evaluate
import torch
import torch.nn.functional as F
class MultiTensorInputRecorder(eval_wrapper):
    """
    Fake lm_eval evaluation wrapper that records calibration inputs as two
    parallel streams (input tensors and indices) instead of running a model.

    Each sequence lm_eval passes to ``_model_call`` is truncated to
    ``calibration_seq_length``. Shorter sequences are right-padded with
    ``pad_token`` when ``pad_calibration_inputs`` is enabled and skipped
    otherwise. Sequences that contain ``pad_token`` are skipped when padding is
    enabled.

    The padded/truncated tensor is passed through ``input_prep_func``, whose
    first two outputs are recorded. ``get_inputs`` wraps each stream in a
    ``MultiTensor``. Random logits are returned so the evaluation keeps going.

    NOTE(review): the default ``input_prep_func`` returns a 1-tuple, which
    ``add_input`` rejects. Callers need to supply one that also returns the
    indices.
    """

    def __init__(
        self,
        tokenizer,
        calibration_seq_length,
        input_prep_func=None,
        pad_calibration_inputs=False,
        vocab_size=32000,
        pad_token=0,
        device="cpu",
    ):
        try:
            super().__init__()
        except TypeError:
            # lm_eval 0.4.2 removed the default init
            super().__init__("gpt2", device="cpu")
        self.tokenizer = tokenizer
        self._device = torch.device(device)
        self.vocab_size = vocab_size
        self._max_seq_length = calibration_seq_length
        self.calibration_seq_length = calibration_seq_length
        # Maps the padded/truncated token tensor (with a leading batch dim)
        # to the tuple of values that gets recorded.
        self.input_prep_func = (
            input_prep_func if input_prep_func is not None else lambda x: (x,)
        )
        self.pad_calibration_inputs = pad_calibration_inputs
        self.pad_token = pad_token
        # Two parallel lists: recorded input tensors and their indices.
        self.inputs = [[], []]

    @property
    def eot_token_id(self):
        """End-of-text id; accepts tokenizers exposing ``eos_id`` as a method or a value."""
        try:
            return self.tokenizer.eos_id()
        except TypeError:
            # `eos_id` is a plain attribute (not callable) on this tokenizer.
            return self.tokenizer.eos_id

    @property
    def max_length(self):
        return self._max_seq_length

    @property
    def max_gen_toks(self):
        return 50

    @property
    def batch_size(self):
        return 1

    @property
    def device(self):
        return self._device

    def tok_encode(self, string: str, **kwargs):
        """Encode ``string`` to token ids, prepending the BOS id if the tokenizer has one."""
        tokens = self.tokenizer.encode(string)
        if hasattr(self.tokenizer, "bos_id"):
            try:
                bos = self.tokenizer.bos_id()
            except TypeError:
                # `bos_id` is a plain attribute (not callable) on this tokenizer.
                bos = self.tokenizer.bos_id
            tokens = [bos] + tokens
        return tokens

    def tok_decode(self, tokens):
        """Decode token ids back to text with the wrapped tokenizer."""
        return self.tokenizer.decode(tokens)

    def add_input(self, args):
        """
        Record one sample's prepared inputs.

        ``args`` is the tuple produced by ``input_prep_func``. Its first two
        elements (input tensor, indices) are appended to ``self.inputs[0]``
        and ``self.inputs[1]``; any further elements are ignored.

        Raises:
            ValueError: if ``args`` has fewer than two elements. The check runs
                before anything is appended, so the two lists stay aligned.
        """
        if len(args) < 2:
            raise ValueError(
                "MultiTensorInputRecorder expects input_prep_func to return "
                f"(inputs, indices, ...); got {len(args)} element(s)"
            )
        self.inputs[0].append(args[0])
        self.inputs[1].append(args[1])

    def record_inputs(self, calibration_tasks, calibration_limit):
        """Run lm_eval over ``calibration_tasks`` purely to collect inputs; returns ``self``."""
        try:
            lm_eval.tasks.initialize_tasks()
        except AttributeError:
            # initialize_tasks only exists in some lm_eval versions.
            pass
        task_dict = get_task_dict(calibration_tasks)
        print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)
        evaluate(
            self,
            task_dict,
            limit=calibration_limit,
        )
        return self

    def get_inputs(self):
        """Return ``[MultiTensor(inputs), MultiTensor(indices)]`` of everything recorded."""
        return [MultiTensor(self.inputs[0]), MultiTensor(self.inputs[1])]

    def _model_call(self, inps):
        """
        Record ``inps`` (assumed shape (1, T), TODO confirm) for calibration.

        Returns random bfloat16 logits of shape (1, T, vocab_size) so the
        harness can continue.
        """
        inps = inps.squeeze(0)
        T = len(inps)
        if (
            # Can't use inputs that are too short when padding is disabled
            (T < self.calibration_seq_length and not self.pad_calibration_inputs)
            or
            # Can't use inputs that actually use the token we use for padding
            (self.pad_calibration_inputs and self.pad_token in inps)
        ):
            # Give random output
            return torch.randn(
                (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
            )
        # Pad or truncate to the correct size
        if T >= self.calibration_seq_length:
            inps = inps[: self.calibration_seq_length]
        else:
            inps = F.pad(
                inps, (0, self.calibration_seq_length - T), value=self.pad_token
            )
        inps = inps.unsqueeze(0)
        model_in = self.input_prep_func(inps)
        self.add_input(model_in)
        # Output `something` with the correct shape to keep eval going
        return torch.randn(
            (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
        )

    def _model_generate(self, context, max_length, eos_token_id):
        raise NotImplementedError("unimplemented")
class InputRecorder(eval_wrapper):
    """
    Fake lm_eval evaluation wrapper that records the inputs it is asked to
    score so they can be used for calibration (e.g. GPTQ), instead of
    running a model.

    Each sequence passed to ``_model_call`` is truncated to
    ``calibration_seq_length``. Sequences shorter than that are:

    * right-padded with ``pad_token`` when ``pad_calibration_inputs`` is
      enabled (if using padding you should set the embeddings for the
      pad_token to 0 in the model), or
    * skipped when padding is disabled.

    When padding is enabled, sequences that already contain ``pad_token`` are
    skipped as well.

    After padding/truncation, ``input_prep_func`` is called to bring the
    tensor into the form expected by a given model, and that result is
    recorded. Random logits are returned so the evaluation loop keeps going.
    """

    def __init__(
        self,
        tokenizer,
        calibration_seq_length,
        input_prep_func=None,
        pad_calibration_inputs=False,
        vocab_size=32000,
        pad_token=0,
        device="cpu",
    ):
        try:
            super().__init__()
        except TypeError:
            # lm_eval 0.4.2 removed the default init
            super().__init__("gpt2", device="cpu")
        self.tokenizer = tokenizer
        self._device = torch.device(device)
        self.vocab_size = vocab_size
        self._max_seq_length = calibration_seq_length
        self.calibration_seq_length = calibration_seq_length
        # need to take inps and convert to correct input
        # for model
        self.input_prep_func = (
            input_prep_func if input_prep_func is not None
            else lambda x: (x,)
        )
        self.pad_calibration_inputs = pad_calibration_inputs
        self.pad_token = pad_token
        # None until the first sample is recorded, then one _MultiInput per
        # element of input_prep_func's output.
        self.inputs = None

    @property
    def eot_token_id(self):
        """End-of-text id; accepts tokenizers exposing ``eos_id`` as a method or a value."""
        try:
            return self.tokenizer.eos_id()
        except TypeError:
            # `eos_id` is a plain attribute (not callable) on this tokenizer.
            return self.tokenizer.eos_id

    @property
    def max_length(self):
        return self._max_seq_length

    @property
    def max_gen_toks(self):
        return 50

    @property
    def batch_size(self):
        return 1

    @property
    def device(self):
        return self._device

    def tok_encode(self, string: str, **kwargs):
        """Encode ``string`` to token ids, prepending the BOS id if the tokenizer has one."""
        # TODO: verify this for multi-batch as well
        tokens = self.tokenizer.encode(string)
        if hasattr(self.tokenizer, "bos_id"):
            try:
                bos = self.tokenizer.bos_id()
            except TypeError:
                # `bos_id` is a plain attribute (not callable) on this tokenizer.
                bos = self.tokenizer.bos_id
            tokens = [bos] + tokens
        return tokens

    def tok_decode(self, tokens):
        """Decode token ids back to text with the wrapped tokenizer."""
        return self.tokenizer.decode(tokens)

    def add_input(self, args):
        """
        Record one sample's prepared model arguments.

        On the first call a ``_MultiInput`` is created per element of
        ``args``; later calls extend each one positionally. The rebinding
        relies on ``_MultiInput.add_input`` returning the updated container.
        NOTE(review): ``zip`` silently drops elements if a later call has a
        different number of arguments than the first.
        """
        if self.inputs is None:
            self.inputs = [_MultiInput([arg]) for arg in args]
        else:
            self.inputs = [
                multi.add_input(arg) for (multi, arg) in zip(self.inputs, args)
            ]

    def record_inputs(
        self,
        calibration_tasks,
        calibration_limit,
    ):
        """Run lm_eval over ``calibration_tasks`` purely to collect inputs; returns ``self``."""
        try:
            lm_eval.tasks.initialize_tasks()
        except AttributeError:
            # initialize_tasks only exists in some lm_eval versions.
            pass
        task_dict = get_task_dict(calibration_tasks)
        print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)
        evaluate(
            self,
            task_dict,
            limit=calibration_limit,
        )
        return self

    def get_inputs(self):
        """Return the recorded ``_MultiInput`` list, or ``None`` if nothing was recorded."""
        return self.inputs

    def _model_call(self, inps):
        """
        Record ``inps`` (assumed shape (1, T), TODO confirm) for calibration.

        Returns random bfloat16 logits of shape (1, T, vocab_size) so the
        harness can continue.
        """
        inps = inps.squeeze(0)
        T = len(inps)
        if (
            # can't use inputs that are too short when padding disabled
            (T < self.calibration_seq_length and not self.pad_calibration_inputs)
            or
            # can't use inputs that actually use token we use for padding
            (self.pad_calibration_inputs and self.pad_token in inps)
        ):
            # give random output
            return torch.randn(
                (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
            )
        # pad or truncate to the right size
        if T >= self.calibration_seq_length:
            inps = inps[: self.calibration_seq_length]
        else:
            # Right-pad with pad_token. F.pad's tuple holds the (left, right)
            # pad widths; the fill value must go in `value`.
            inps = F.pad(
                inps, (0, self.calibration_seq_length - T), value=self.pad_token
            )
        inps = inps.unsqueeze(0)
        model_in = self.input_prep_func(inps)
        self.add_input(model_in)
        # output `something` with correct shape to keep eval going
        return torch.randn(
            (1, T, self.vocab_size), dtype=torch.bfloat16, device=self._device
        )

    def _model_generate(self, context, max_length, eos_token_id):
        raise NotImplementedError("unimplemented")
class TransformerEvalWrapper(InputRecorder):
    """
    A wrapper class for GPTFast, providing integration with the
    lm-evaluation-harness library.

    Unlike ``InputRecorder``, ``_model_call`` actually runs ``model``. It calls
    ``model.setup_caches`` for the current sequence length and then returns
    the model's logits.
    """

    def __init__(
        self,
        model,
        tokenizer,
        max_seq_length,
        input_prep_func=None,
        device="cuda"
    ):
        # No recording happens here, so calibration_seq_length is None. The
        # parent sets self._device and installs the default input_prep_func.
        super().__init__(
            tokenizer, None, input_prep_func=input_prep_func, device=device
        )
        self._model = model
        # Overrides the None the parent derived from calibration_seq_length.
        self._max_seq_length = max_seq_length

    def _model_call(self, inps):
        """Run the wrapped model on ``inps`` and return its logits."""
        # TODO: make batches work
        model_args = self.input_prep_func(inps)
        max_seq_length = min(max(inps.size()), self.max_length)
        # Allocate cache tensors on the target device.
        with torch.device(self._device):
            self._model.setup_caches(self.batch_size, max_seq_length)
        return self._model(*model_args)

    def _model_generate(self, context, max_length, eos_token_id):
        raise NotImplementedError('unimplemented')

    def run_eval(self, tasks, limit):
        """
        Evaluate the wrapped model on ``tasks`` with lm_eval.

        Runs under ``torch.no_grad()``, prints each task's results, and returns
        the full result dict from ``evaluate``.
        """
        try:
            lm_eval.tasks.initialize_tasks()
        except AttributeError:
            # initialize_tasks only exists in some lm_eval versions.
            pass
        task_dict = get_task_dict(tasks)
        print("Evaluating Model On: ", task_dict)
        with torch.no_grad():
            result = evaluate(
                self,
                task_dict,
                limit=limit,
            )
        for task, res in result["results"].items():
            print(f"{task}: {res}")
        return result