forked from pytorch/torchchat
-
Notifications
You must be signed in to change notification settings - Fork 0
/
quantize.py
1223 lines (1041 loc) · 42.4 KB
/
quantize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import json
from functools import reduce
from math import gcd
from typing import Dict, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
from eval import evaluate, get_task_dict, lm_eval
from GPTQ import GenericGPTQRunner, InputRecorder
except:
pass
##########################################################################
### dtype name to torch.dtype mapping ###
# Module-wide default dtype; read through get_precision() by the quantized
# linear layers in this file when they allocate their scale buffers.
precision = torch.float


def set_precision(dtype):
    """Replace the module-wide default precision with ``dtype``."""
    global precision
    precision = dtype


def get_precision():
    """Return the module-wide default precision currently in effect."""
    # Reading a module global needs no ``global`` declaration.
    return precision
def name_to_dtype(name):
    """Translate a dtype name such as ``"bf16"`` into the matching ``torch.dtype``.

    Raises:
        RuntimeError: if ``name`` is not a key of ``name_to_dtype_dict``.
    """
    # No table value is None, so None unambiguously means "unknown name".
    dtype = name_to_dtype_dict.get(name)
    if dtype is None:
        raise RuntimeError(f"unsupported dtype name {name} specified")
    return dtype


# Accepted spellings for each supported floating-point dtype.
name_to_dtype_dict = {
    "fp32": torch.float,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
    "float": torch.float,
    "half": torch.float16,
    "float32": torch.float,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}
##########################################################################
### process quantization dictionary ###
def quantize_model(model: nn.Module, device, quantize_options):
    """
    Quantize the specified model using the quantizers described by
    a quantization dict of the form:
    {
        'embedding':   {'bitwidth': 8, 'groupsize': 8 },
        'linear:int8': {'bitwidth': 8, 'groupsize': 8},
        'precision':   {'dtype': torch.float16},
    }

    Arguments:
        model: the module to quantize.
        device: forwarded to the quantization handlers.
        quantize_options: the quantization dict described above, or a JSON
            string encoding it.  Entries are applied in dict order.  At most
            one ``linear:*`` entry may be given; ``embedding`` and
            ``precision`` entries may appear before or after it.

    Returns:
        The model after all requested quantizers have been applied.

    Raises:
        AssertionError: if more than one linear quantizer is specified, if an
            unknown quantizer name is given, or if ``linear:a8w4dq`` is
            requested without a ``groupsize``.
    """
    if isinstance(quantize_options, str):
        quantize_options = json.loads(quantize_options)

    linears_quantized = False
    for quantizer, q_kwargs in quantize_options.items():
        is_linear = isinstance(quantizer, str) and quantizer.startswith("linear:")
        # Only a second *linear* quantizer is an error.  Checking this for
        # linear keys only keeps 'precision' (and the unsupported-name
        # diagnostic) usable after a linear quantizer has been applied.
        if is_linear and linears_quantized:
            raise AssertionError("can only specify one linear quantizer")

        if quantizer == "embedding":
            model = EmbeddingOnlyInt8QuantHandler(
                model, device, **q_kwargs
            ).quantized_model()
        elif quantizer == "linear:int8":
            linears_quantized = True
            model = WeightOnlyInt8QuantHandler(
                model, device, **q_kwargs
            ).quantized_model()
        elif quantizer == "linear:int4":
            linears_quantized = True
            model = WeightOnlyInt4QuantHandler(
                model, device, **q_kwargs
            ).quantized_model()
        elif quantizer == "linear:a8w4dq":
            linears_quantized = True
            from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer

            # Note that Int8DynActInt4WeightQuantizer takes precision as
            # arg, which is used to determine the precision/dtype of the output.
            # That is, if dtype=fp32 then this dynamically quantized linear will
            # return an output tensor with fp32 dtype.
            # Ideally we make this dynamic such that the output dtype is determined
            # based on the input dtype, instead of having to instantiate a quantizer
            # that picks the output dtype.
            # Since this requires a change in torchao, we leave the current state
            # as is and use the default precision for
            # Int8DynActInt4WeightQuantizer, which is fp32.
            if "groupsize" not in q_kwargs:
                raise AssertionError(
                    f"a8w4dq quantization option must specify groupsize. Specified options {q_kwargs}"
                )
            model = Int8DynActInt4WeightQuantizer(
                groupsize=q_kwargs["groupsize"]
            ).quantize(model)
        elif quantizer == "linear:gptq":
            linears_quantized = True
            model = WeightOnlyInt4GPTQQuantHandler(
                model, device, **q_kwargs
            ).quantized_model()
        elif quantizer == "linear:hqq":
            linears_quantized = True
            model = WeightOnlyInt4HqqQuantHandler(
                model, device, **q_kwargs
            ).quantized_model()
        elif quantizer == "precision":
            # Copy so the caller's options dict is never mutated.
            to_kwargs = dict(q_kwargs)
            # A JSON document can only carry the dtype as a string; resolve it
            # through the module's name table before handing it to Module.to.
            if isinstance(to_kwargs.get("dtype"), str):
                to_kwargs["dtype"] = name_to_dtype(to_kwargs["dtype"])
            model.to(**to_kwargs)
        else:
            raise AssertionError(f"quantizer {quantizer} not supported")

    # The local ``model`` is rebound by every quantizer above; hand the final
    # object back so the result is not lost if a quantizer returns a new module.
    return model
#########################################################################
##### Quantization Primitives ######
def dynamically_quantize_per_channel(
    x,
    quant_min,
    quant_max,
    target_dtype,
    groupsize: Optional[int] = None,
    *,
    scales_dtype=torch.float16,
    enable_non_multiple_groups=True,
):
    """
    Symmetrically quantize a 2-D weight tensor row by row, optionally in groups.

    Used for quantizing the weights of linear and embedding layers.

    Arguments:
        x: input tensor,
        quant_min: minimum value after quantization,
        quant_max: maximum value after quantization,
        target_dtype: target data type for weights after quantization,
        groupsize: number of elements of a row that share one scale; ``None``
            or ``0`` means one scale per row.

    Keyword arguments:
        scales_dtype: data type of the returned scales,
        enable_non_multiple_groups: if True, allow the row size to not be a
            multiple of group size; the row is zero-padded for the computation
            and the padding is cut off the returned weights again.

    Returns:
        ``(quant, scales, zero_points)``: the quantized weights with the same
        number of columns as ``x``, and the per-group scales and (all-zero,
        int64) zero points, both of shape ``(rows, groups)``.

    Assumptions:
        Symmetric quantization, axis == 0 and a dense memory format.
    """
    # TODO(future): relax the assumptions above as needed
    row_len = x.shape[1]

    if groupsize is None or groupsize == 0:
        # One group spanning the whole row.
        group_len = row_len
    else:
        divisible = (row_len % groupsize) == 0
        assert groupsize > 0, "group size must be positive"
        if divisible or not enable_non_multiple_groups:
            assert (
                divisible
            ), f"weights dimension 1 = {row_len} must be a multiple of group size {groupsize}"
        else:
            print(
                f"row-size of weight matrix {row_len} is not divisible by group size {groupsize}, using nearest neighbor rounding"
            )
            # Cannot fire on this path; kept as a defensive invariant check.
            assert (
                not divisible
            ), f"expected x.shape[1] to not be a multiple of group size {groupsize}, but got {row_len}"
            # Zero-pad the last, short group up to a full group.
            x = F.pad(x, (0, groupsize - (row_len % groupsize)))
        group_len = groupsize

    # Lower bound for the scales so an all-zero group does not divide by zero.
    eps = torch.finfo(torch.float32).eps

    grouped = x.view(x.shape[0], x.shape[1] // group_len, group_len)
    lowest, highest = torch.aminmax(grouped, dim=2)

    # Symmetric range: the larger magnitude of the two extremes, with zero
    # always included in the range.
    neg_part = torch.min(lowest, torch.zeros_like(lowest))
    pos_part = torch.max(highest, torch.zeros_like(highest))
    abs_max = torch.max(-neg_part, pos_part)

    half_range = float(quant_max - quant_min) / 2
    # Keep the scales in the dtype of the input tensor while quantizing.
    scales = torch.clamp(abs_max / half_range, min=eps).to(grouped.dtype)
    zero_points = torch.zeros(
        neg_part.size(), dtype=torch.int64, device=neg_part.device
    )

    # quantize based on qmin/qmax/scales/zp
    shifted = torch.round(grouped / scales.unsqueeze(-1)) + zero_points.unsqueeze(-1)
    quant = (
        torch.clamp(shifted, quant_min, quant_max)
        .to(target_dtype)
        .view(grouped.shape[0], -1)
    )

    # Drop any padding columns added above.
    return quant[:, :row_len], scales.to(dtype=scales_dtype), zero_points
def get_group_qparams(w, n_bit=4, groupsize=128, *, scales_dtype=torch.float):
    """Compute per-group scales and zeros for a 2-D weight tensor.

    Each row of ``w`` is cut into consecutive groups of ``groupsize`` elements.
    For every group the scale is ``(max - min) / (2**n_bit - 1)`` (with the
    range clamped to at least 1e-6) and the zero is
    ``min + scale * 2**(n_bit - 1)``.

    Returns:
        ``(scales, zeros)``, each of shape ``(w.shape[0], groups_per_row)`` and
        dtype ``scales_dtype``.
    """
    row_width = w.shape[-1]
    # needed for GPTQ with padding
    groupsize = min(groupsize, row_width)

    assert groupsize > 1
    assert row_width % groupsize == 0
    assert w.dim() == 2

    groups = w.reshape(-1, groupsize)
    assert torch.isnan(groups).sum() == 0

    group_max = groups.amax(dim=1, keepdim=True)
    group_min = groups.amin(dim=1, keepdim=True)
    levels = 2**n_bit - 1
    scales = (group_max - group_min).clamp(min=1e-6) / levels
    zeros = group_min + scales * (2 ** (n_bit - 1))

    n_rows = w.shape[0]
    return (
        scales.to(scales_dtype).reshape(n_rows, -1),
        zeros.to(scales_dtype).reshape(n_rows, -1),
    )
def pack_scales_and_zeros(scales, zeros, *, scales_dtype=torch.float):
    """Interleave scales and zeros into one ``(groups, rows, 2)`` tensor.

    ``scales`` and ``zeros`` must have the same shape and both be of dtype
    ``scales_dtype``.  Index 0 of the last dimension of the result holds the
    scale, index 1 the zero; the result is contiguous.
    """
    assert scales.shape == zeros.shape
    assert scales.dtype == scales_dtype
    assert zeros.dtype == scales_dtype

    rows, groups = scales.size(0), scales.size(1)
    scale_column = scales.reshape(rows, groups, 1)
    zero_column = zeros.reshape(zeros.size(0), zeros.size(1), 1)
    paired = torch.cat([scale_column, zero_column], 2)
    return paired.transpose(0, 1).contiguous()
def unpack_scales_and_zeros(scales_and_zeros):
    """Split a packed ``(groups, rows, 2)`` float tensor into scales and zeros.

    Returns the tuple produced by splitting the transposed tensor along its
    last dimension into pieces of width 1, i.e. tensors of shape
    ``(rows, groups, 1)``.
    """
    is_packed_layout = (
        len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
    )
    assert is_packed_layout
    assert scales_and_zeros.dtype == torch.float
    rows_first = scales_and_zeros.transpose(0, 1)
    return torch.split(rows_first, 1, 2)
def group_quantize_tensor_from_qparams(w, scales, zeros, n_bit=4, groupsize=128):
    """Quantize a 2-D tensor group-wise with precomputed scales and zeros.

    Every group of ``groupsize`` consecutive elements is mapped to integers in
    ``[0, 2**n_bit - 1]`` via ``round((w - min_val) / scale)`` with
    ``min_val = zero - scale * 2**(n_bit - 1)``.

    Returns:
        An int32 tensor with the shape of ``w``.
    """
    assert groupsize > 1
    # needed for GPTQ single column quantize
    if groupsize > w.shape[-1] and scales.shape[-1] == 1:
        groupsize = w.shape[-1]

    assert w.shape[-1] % groupsize == 0
    assert w.dim() == 2

    groups = w.reshape(-1, groupsize)
    assert torch.isnan(groups).sum() == 0

    # One scale/zero per group, broadcast across the group's elements.
    group_scales = scales.reshape(-1, 1)
    group_zeros = zeros.reshape(-1, 1)
    group_min = group_zeros - group_scales * (2 ** (n_bit - 1))

    lowest_level = 0
    highest_level = 2**n_bit - 1
    levels = torch.round((groups - group_min) / group_scales)
    levels = levels.clamp(lowest_level, highest_level)
    return levels.to(torch.int32).reshape_as(w)
def group_quantize_tensor(w, n_bit=4, groupsize=128):
    """Group-quantize ``w`` and return ``(w_int32, scales_and_zeros)``.

    The qparams come from ``get_group_qparams``; the second element of the
    result is their packed form as built by ``pack_scales_and_zeros``.
    """
    qparams = get_group_qparams(w, n_bit, groupsize)
    quantized = group_quantize_tensor_from_qparams(w, *qparams, n_bit, groupsize)
    return quantized, pack_scales_and_zeros(*qparams)
def group_dequantize_tensor_from_qparams(
    w_int32, scales, zeros, n_bit=4, groupsize=128
):
    """Reconstruct values from group-quantized integers.

    For every group of ``groupsize`` consecutive elements the result is
    ``(q - 2**(n_bit - 1)) * scale + zero``, reshaped to the shape of
    ``w_int32``.
    """
    assert groupsize > 1
    # needed for GPTQ single column dequantize
    if groupsize > w_int32.shape[-1] and scales.shape[-1] == 1:
        groupsize = w_int32.shape[-1]

    assert w_int32.shape[-1] % groupsize == 0
    assert w_int32.dim() == 2

    groups = w_int32.reshape(-1, groupsize)
    # One scale/zero per group, broadcast across the group's elements.
    group_scales = scales.reshape(-1, 1)
    group_zeros = zeros.reshape(-1, 1)

    centered = groups.sub(2 ** (n_bit - 1))
    restored = centered.mul(group_scales).add(group_zeros)
    return restored.reshape_as(w_int32)
def group_dequantize_tensor(w_int32, scales_and_zeros, n_bit=4, groupsize=128):
    """Dequantize ``w_int32`` using packed scales and zeros.

    The packed tensor is split with ``unpack_scales_and_zeros`` and the pieces
    are passed on to ``group_dequantize_tensor_from_qparams``.
    """
    group_scales, group_zeros = unpack_scales_and_zeros(scales_and_zeros)
    restored = group_dequantize_tensor_from_qparams(
        w_int32, group_scales, group_zeros, n_bit, groupsize
    )
    return restored
#########################################################################
### QuantHandler API definition ###
class QuantHandler:
    """Common interface of the quantization handlers in this file.

    A handler wraps a module (``self.mod``).  ``quantized_model`` first asks
    for a quantized state dict, then converts the module for runtime and
    finally loads the state dict into the converted module.
    """

    def __init__(self, mod):
        self.mod = mod

    def create_quantized_state_dict(self) -> Dict:  # "StateDict"
        """Hook for subclasses; the base implementation returns ``None``."""
        return None

    def convert_for_runtime(self) -> nn.Module:
        """Hook for subclasses; the base implementation returns ``None``."""
        return None

    def quantized_model(self) -> nn.Module:
        """Quantize ``self.mod`` in place and return it."""
        # The state dict must be produced before the module is converted.
        quantized_state = self.create_quantized_state_dict()
        self.convert_for_runtime()
        self.mod.load_state_dict(quantized_state)
        return self.mod
#########################################################################
##### Weight-only int8 per-channel quantized code ######
def replace_linear_weight_only_int8_per_channel(
    module, device, node_type, groupsize=None, prefix=""
):
    """Recursively swap selected ``nn.Linear`` children for ``WeightOnlyInt8Linear``.

    Layers are selected with the same rule that
    ``WeightOnlyInt8QuantHandler.create_quantized_state_dict`` uses, so the
    replaced modules line up with the keys of the quantized state dict: the
    fully qualified name of the layer (relative to the module this function
    was first called on) is compared against ``"output"`` / ``"final_proj"``.

    Arguments:
        module: module whose children are inspected; modified in place.
        device: device for the buffers of the replacement layers.
        node_type: ``"*"`` replaces every linear layer, ``"output"`` only the
            output layers named above, ``"!output"`` every other linear layer.
        groupsize: forwarded to ``WeightOnlyInt8Linear``.
        prefix: dotted name of ``module`` itself; used by the recursion to
            build fully qualified names.  Callers normally leave the default.
    """
    output_fqns = ("output", "final_proj")
    # Snapshot the children: setattr below rewrites the mapping being walked.
    for name, child in list(module.named_children()):
        fqn = f"{prefix}.{name}" if prefix else name
        if not isinstance(child, nn.Linear):
            replace_linear_weight_only_int8_per_channel(
                child, device, node_type, groupsize, fqn
            )
            continue

        is_output = fqn in output_fqns
        selected = (
            (node_type == "*")
            or (node_type == "output" and is_output)
            or (node_type == "!output" and not is_output)
        )
        if selected:
            setattr(
                module,
                name,
                WeightOnlyInt8Linear(
                    device, child.in_features, child.out_features, groupsize
                ),
            )
class WeightOnlyInt8QuantHandler(QuantHandler):
    """Weight-only quantization of ``nn.Linear`` layers into int8 storage.

    Keyword arguments:
        node_type: ``"*"`` for every linear layer, ``"output"`` for the layers
            whose fully qualified name is ``output`` or ``final_proj``,
            ``"!output"`` for all other linear layers.
        bitwidth: 4 or 8 (``None`` means 8); selects the quantized value range.
        groupsize: elements per scale group, passed on to
            ``dynamically_quantize_per_channel``.
    """

    def __init__(
        self,
        mod,
        device,
        *,
        node_type: str = "*",
        bitwidth: Optional[int] = None,
        groupsize: Optional[int] = None,
    ):
        self.mod = mod
        self.device = device
        self.groupsize = groupsize
        self.node_type = node_type
        self.bitwidth = 8 if bitwidth is None else bitwidth

    @torch.no_grad()
    def create_quantized_state_dict(self) -> Dict:
        """Return the module's state dict with selected linear layers quantized.

        For each selected layer ``<fqn>.weight`` is replaced by the int8
        weights and ``<fqn>.scales`` is added.

        Raises:
            ValueError: if the bitwidth is neither 4 nor 8.
        """
        state = self.mod.state_dict()

        if self.bitwidth == 4:
            qmin, qmax = -8, 7
        elif self.bitwidth == 8:
            qmin, qmax = -128, 127
        else:
            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")

        output_fqns = ["output", "final_proj"]
        for fqn, layer in self.mod.named_modules():
            if not isinstance(layer, torch.nn.Linear):
                continue

            is_output = fqn in output_fqns
            selected = (
                (self.node_type == "*")
                or (self.node_type == "output" and is_output)
                or (self.node_type == "!output" and not is_output)
            )
            if not selected:
                continue

            # Quantize from a float32 copy; scales keep the layer's own dtype.
            q_weight, q_scales, _ = dynamically_quantize_per_channel(
                layer.weight.float(),
                qmin,
                qmax,
                torch.int8,
                self.groupsize,
                scales_dtype=layer.weight.dtype,
            )
            state[f"{fqn}.weight"] = q_weight.to(device=self.device)
            # squeeze makes groupsize=rowsize unidimensional
            state[f"{fqn}.scales"] = q_scales.to(device=self.device).squeeze(dim=-1)

        return state

    def convert_for_runtime(self) -> nn.Module:
        """Swap the selected linear layers for their int8 counterparts."""
        replace_linear_weight_only_int8_per_channel(
            self.mod, self.device, self.node_type, self.groupsize
        )
        return self.mod

    def quantized_model(self) -> nn.Module:
        """Quantize ``self.mod`` in place and return it."""
        quantized_state = self.create_quantized_state_dict()
        self.convert_for_runtime()
        self.mod.load_state_dict(quantized_state)
        return self.mod
class WeightOnlyInt8Linear(torch.nn.Module):
    """Linear layer whose weight is stored as int8 plus floating-point scales.

    Buffers:
        weight: int8 tensor of shape ``(out_features, in_features)``.
        scales: shape ``(out_features,)`` when there is a single scale per
            output row, otherwise ``(out_features, groups)`` with
            ``groups = ceil(in_features / groupsize)``.

    Arguments:
        device: device on which the buffers are allocated.
        in_features, out_features: dimensions of the layer.
        groupsize: number of consecutive input features sharing one scale;
            ``None`` or ``0`` means one scale per output row.
        bias: accepted for signature compatibility; no bias buffer is created
            and ``forward`` adds no bias.
        dtype: dtype of the ``scales`` buffer; ``None`` selects
            ``get_precision()``.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        device,
        in_features: int,
        out_features: int,
        groupsize: Optional[int] = None,
        bias: bool = True,
        dtype=None,
    ) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Kept so forward() can expand the scales when the last group is short.
        self.groupsize = groupsize
        self.register_buffer(
            "weight",
            torch.empty((out_features, in_features), dtype=torch.int8, device=device),
        )
        if dtype is None:
            dtype = get_precision()

        if groupsize is None or (groupsize == 0):
            groups = 1
        else:
            groups = (in_features + groupsize - 1) // groupsize

        # A single group is stored 1-D: that is the shape the quantizer emits
        # after squeezing, and the shape the channel-wise forward path needs
        # for broadcasting over the output features.
        scales_shape = (out_features, groups) if groups > 1 else (out_features,)
        self.register_buffer(
            "scales", torch.ones(scales_shape, dtype=dtype, device=device)
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the layer: dequantize on the fly and run ``F.linear``."""
        weight = self.weight
        scales = self.scales.view(self.scales.shape[0], -1)
        no_groups = scales.shape[1]

        # need a formulation / custom op for good performance on both eager, CUDA compiled, CPU compiled and ET exported
        # maybe use IR-based rewriting?
        # for now, we special-case channel-wise, because we know how to make that fast (but does not work for groupwise)
        if no_groups == 1:
            return F.linear(input, weight.to(dtype=input.dtype)) * scales.view(-1)

        rows, cols = weight.shape
        typed_weight = weight.to(dtype=input.dtype)
        if cols % self.groupsize == 0:
            # Every group is full: scale through a broadcasted grouped view.
            dequantized = (
                typed_weight.view(rows, no_groups, -1)
                * scales.view(rows, no_groups, -1)
            ).view(rows, -1)
        else:
            # The last group is shorter than groupsize, so a grouped view of
            # the row does not exist; expand the scales per element instead.
            per_element = scales.repeat_interleave(self.groupsize, dim=1)[:, :cols]
            dequantized = typed_weight * per_element
        return F.linear(input, dequantized)
#########################################################################
##### embedding table quantization ######
def replace_embedding_weight_only_grouped_int8_per_channel(
    module, device, bitwidth: int = 8, groupsize: Optional[int] = None, packed=False
):
    """Recursively swap every ``nn.Embedding`` child for ``QuantizedGroupEmbedding``.

    The replacement takes its vocabulary size and embedding dimension from the
    shape of the original weight.  ``bitwidth`` is only passed along through
    the recursion.  ``module`` is modified in place.
    """
    for child_name, child in list(module.named_children()):
        if not isinstance(child, nn.Embedding):
            # Not an embedding: descend and look for embeddings further down.
            replace_embedding_weight_only_grouped_int8_per_channel(
                child, device, bitwidth, groupsize, packed
            )
            continue

        vocab_size, embedding_dim = child.weight.shape[0], child.weight.shape[1]
        quantized = QuantizedGroupEmbedding(
            device=device,
            vocab_size=vocab_size,
            embedding_dim=embedding_dim,
            groupsize=groupsize,
            packed=packed,
        )
        setattr(module, child_name, quantized)
class EmbeddingOnlyInt8QuantHandler(QuantHandler):
    """Quantize the ``nn.Embedding`` tables of a module to 4- or 8-bit values.

    Keyword arguments:
        bitwidth: 4 or 8; selects the quantized value range.
        groupsize: elements per scale group, passed on to
            ``dynamically_quantize_per_channel``.
        packed: store two 4-bit values per byte; a string is accepted and is
            true only when it equals ``"True"``.

    Raises:
        RuntimeError: if ``packed`` is requested with a bitwidth other than 4.
    """

    def __init__(
        self,
        mod,
        device,
        *,
        bitwidth: int = 8,
        groupsize: Optional[int] = None,
        packed=False,
    ):
        if isinstance(packed, str):
            packed = packed == "True"
        self.mod = mod
        self.device = device
        self.groupsize = groupsize
        self.bitwidth = bitwidth
        self.packed = packed
        if packed and (bitwidth != 4):
            raise RuntimeError("pack only works with bitsize 4")

    @torch.no_grad()
    def create_quantized_state_dict(self, packed=False) -> Dict:
        """Return the module's state dict with every embedding table quantized.

        For each embedding ``<fqn>.weight`` is replaced by the quantized
        weights (nibble-packed into uint8 when ``packed``) and
        ``<fqn>.scales`` is added.

        Raises:
            ValueError: if the bitwidth is neither 4 nor 8.
            RuntimeError: if packing is requested for an odd row length.
        """
        state = self.mod.state_dict()

        if self.bitwidth == 4:
            qmin, qmax = -8, 7
        elif self.bitwidth == 8:
            qmin, qmax = -128, 127
        else:
            raise ValueError(f"Unsupported bitwidth {self.bitwidth}")

        for fqn, table in self.mod.named_modules():
            if not isinstance(table, nn.Embedding):
                continue

            # Quantize from a float32 copy; scales keep the table's own dtype.
            weight, scales, _ = dynamically_quantize_per_channel(
                table.weight.float(),
                qmin,
                qmax,
                torch.int8,
                self.groupsize,
                scales_dtype=table.weight.dtype,
            )

            if packed:
                if weight.shape[-1] % 2 != 0:
                    raise RuntimeError("automatic padding not implemented yet")
                # Shift [-8, 7] to [0, 15] so each value fits in one nibble.
                nibbles = weight.add(8).view(torch.uint8)
                pairs = nibbles.view(weight.shape[0], weight.shape[1] // 2, 2)
                high = pairs[:, :, 0] * 16  # left shift 4
                low = pairs[:, :, 1]
                weight = high + low

            # Update state dict
            state[f"{fqn}.weight"] = weight.to(device=self.device)
            # squeeze makes groupsize=rowsize unidimensional
            state[f"{fqn}.scales"] = scales.to(device=self.device).squeeze(dim=-1)

        return state

    def convert_for_runtime(self) -> nn.Module:
        """Swap the embedding modules for their quantized counterparts."""
        replace_embedding_weight_only_grouped_int8_per_channel(
            self.mod, self.device, self.bitwidth, self.groupsize, self.packed
        )
        return self.mod

    def quantized_model(self) -> nn.Module:
        """Quantize ``self.mod`` in place and return it."""
        quantized_state = self.create_quantized_state_dict(self.packed)
        self.convert_for_runtime()
        self.mod.load_state_dict(quantized_state)
        return self.mod
class QuantizedGroupEmbedding(torch.nn.Module):
    """Embedding table stored as quantized integers plus per-group scales.

    Buffers:
        weight: int8 of shape ``(vocab_size, embedding_dim)``, or, when
            ``packed``, uint8 of shape ``(vocab_size, embedding_dim // 2)``
            holding two 4-bit values per byte.
        scales: float16, shape ``(vocab_size,)`` for a single group per row,
            otherwise ``(vocab_size, groups_per_row)`` with
            ``groups_per_row = ceil(embedding_dim / groupsize)``.

    Arguments:
        device: device on which the buffers are allocated.
        vocab_size, embedding_dim: dimensions of the table.
        groupsize: elements per scale group; ``None`` or ``0`` means the whole
            row forms one group.
        dtype: stored as ``self.dtype``; the forward path below does not read
            it (the result has the dtype of the scales).
        packed: whether the weight buffer uses the nibble-packed layout.
    """

    def __init__(
        self,
        device,
        vocab_size: int,
        embedding_dim: int,
        groupsize: Optional[int] = None,
        dtype=torch.half,
        packed=False,
    ) -> None:
        super().__init__()
        if groupsize is None or groupsize == 0:
            groupsize = embedding_dim
        self.groupsize = groupsize
        self.dtype = dtype
        self.packed = packed

        if not packed:
            self.register_buffer(
                "weight",
                torch.empty(
                    (vocab_size, embedding_dim), dtype=torch.int8, device=device
                ),
            )
        else:  # packed
            self.register_buffer(
                "weight",
                torch.empty(
                    (vocab_size, embedding_dim // 2), dtype=torch.uint8, device=device
                ),
            )

        groups_per_row = (embedding_dim + groupsize - 1) // groupsize
        if groups_per_row > 1:
            self.register_buffer(
                "scales",
                torch.ones(
                    (vocab_size, groups_per_row), dtype=torch.float16, device=device
                ),
            )
        else:
            self.register_buffer(
                "scales", torch.ones((vocab_size,), dtype=torch.float16, device=device)
            )

    @torch.no_grad()
    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        """Look up ``indices`` and return the dequantized embedding rows.

        The result has shape ``indices.shape + (row_length,)`` and the dtype
        of the scales.
        """
        # NOTE: an ExecuTorch build would instead call
        # torch.ops.llama_quantized.embedding_byte.dtype(
        #     self.weight, self.scales, None, 0, 0, indices, dtype=self.dtype)
        if self.packed:
            # Undo the nibble packing: high nibble first, then low nibble,
            # and shift [0, 15] back to [-8, 7].
            high = self.weight.div(16, rounding_mode="trunc")
            low = self.weight.remainder(16)
            unpacked = torch.stack((high, low), dim=-1).view(self.weight.shape[0], -1)
            weight = unpacked.view(torch.int8).add(-8)
        else:
            weight = self.weight

        scales = self.scales.view(weight.shape[0], -1)
        groups = scales.shape[1]
        row_len = weight.shape[1]

        result_weights = F.embedding(indices, weight)
        result_scales = F.embedding(indices, scales)
        values = result_weights.to(dtype=result_scales.dtype)

        if groups == 1 or row_len % self.groupsize == 0:
            # Every group is full: scale through a broadcasted grouped view.
            lead = tuple(result_weights.shape[:-1])
            rw_view = values.view(lead + (groups, -1))
            rs_view = result_scales.view(lead + (groups, 1))
            r = rw_view * rs_view
            return r.view(indices.size() + (-1,))

        # The last group is shorter than groupsize, so a grouped view of the
        # row does not exist; expand the scales per element instead.
        per_element = result_scales.repeat_interleave(self.groupsize, dim=-1)
        return values * per_element[..., :row_len]
#########################################################################
##### weight only int4 per channel groupwise quantized code ######
def _int4_prepare_int4_weight_and_scales_and_zeros(
    weight_bf16, groupsize, inner_k_tiles
):
    """Quantize a weight to 4 bits and convert it to the int4pack layout.

    Returns:
        ``(weight_int4pack, scales_and_zeros)``: the output of
        ``torch.ops.aten._convert_weight_to_int4pack`` and the packed qparams
        from ``group_quantize_tensor``.
    """
    quantized, qparams = group_quantize_tensor(
        weight_bf16, n_bit=4, groupsize=groupsize
    )
    packed = torch.ops.aten._convert_weight_to_int4pack(quantized, inner_k_tiles)
    return packed, qparams
def _int4_calc_padded_size(k, groupsize=1, innner_k_tiles=1):
    """Return ``find_multiple(k, 1024)``.

    ``groupsize`` and ``innner_k_tiles`` are accepted but not used.
    """
    from build.model import find_multiple

    padded = find_multiple(k, 1024)
    return padded
def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
    """Multiply ``x`` by an int4-packed weight via ``aten._weight_int4pack_mm``.

    All leading dimensions of ``x`` are flattened for the matmul and restored
    afterwards; the last dimension of the result is ``out_features``.  When
    the device string of ``x`` contains ``"cuda"`` the matmul runs in bfloat16
    and the result is cast back to the dtype of ``x``.
    """
    leading_shape = x.size()[:-1]
    flat = x.reshape(-1, x.size()[-1])

    if "cuda" in str(flat.device):
        product = torch.ops.aten._weight_int4pack_mm(
            flat.to(torch.bfloat16),
            weight_int4pack,
            groupsize,
            scales_and_zeros.to(torch.bfloat16),
        )
        # cast back to the dtype of the input
        product = product.to(flat.dtype)
    else:
        product = torch.ops.aten._weight_int4pack_mm(
            flat,
            weight_int4pack,
            groupsize,
            scales_and_zeros,
        )

    return product.reshape(leading_shape + (out_features,))
def _int4_check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
def replace_linear_int4(
    module,
    device,
    groupsize,
    inner_k_tiles,
    padding_allowed,
):
    """Recursively swap ``nn.Linear`` children for ``WeightOnlyInt4Linear``.

    A linear layer is replaced when its ``in_features`` passes
    ``_int4_check_linear_int4_k`` or when ``padding_allowed`` is true;
    otherwise it is left untouched.  ``module`` is modified in place.
    """
    for child_name, child in list(module.named_children()):
        if not isinstance(child, nn.Linear):
            # Not a linear layer: descend and look for linears further down.
            replace_linear_int4(
                child, device, groupsize, inner_k_tiles, padding_allowed
            )
            continue

        fits = _int4_check_linear_int4_k(child.in_features, groupsize, inner_k_tiles)
        if not (fits or padding_allowed):
            continue

        int4_linear = WeightOnlyInt4Linear(
            device,
            child.in_features,
            child.out_features,
            bias=False,
            groupsize=groupsize,
            inner_k_tiles=inner_k_tiles,
        )
        setattr(module, child_name, int4_linear)
class WeightOnlyInt4QuantHandler(QuantHandler):
    """Weight-only group-wise int4 quantization of ``nn.Linear`` layers.

    Keyword arguments:
        groupsize: elements per quantization group; one of 32, 64, 128, 256.
        inner_k_tiles: tiling parameter of the int4pack layout; one of 2, 4, 8.
        padding_allowed: when a layer's ``in_features`` does not pass
            ``_int4_check_linear_int4_k``, pad it up with ``find_multiple``
            instead of skipping the layer.
    """

    def __init__(
        self, mod, device, *, groupsize=128, inner_k_tiles=8, padding_allowed=True
    ):
        self.mod = mod
        self.device = device
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
        self.padding_allowed = padding_allowed
        assert groupsize in [32, 64, 128, 256]
        assert inner_k_tiles in [2, 4, 8]

    @torch.no_grad()
    def create_quantized_state_dict(self):
        """Return the module's state dict with linear layers quantized to int4.

        For each quantized layer ``<fqn>.weight`` is replaced by the int4pack
        weight and ``<fqn>.scales_and_zeros`` is added.  Layers that do not
        fit and may not be padded are reported and left unchanged.

        Raises:
            AssertionError: if a linear layer has a bias or its
                ``out_features`` is not a multiple of 8.
        """
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if not isinstance(mod, torch.nn.Linear):
                continue

            # Test for the absence of a bias by identity: truth-testing a
            # multi-element tensor raises RuntimeError, and a one-element
            # zero bias would wrongly read as "no bias".
            assert mod.bias is None, f"int4 quantization requires {fqn} to have no bias"
            out_features = mod.out_features
            in_features = mod.in_features
            assert out_features % 8 == 0, "require out_features % 8 == 0"

            weight = mod.weight.data
            if not _int4_check_linear_int4_k(
                in_features, self.groupsize, self.inner_k_tiles
            ):
                if not self.padding_allowed:
                    print(
                        f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, "
                        + "and that groupsize and inner_k_tiles*16 evenly divide into it"
                    )
                    continue

                from build.model import find_multiple

                print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
                padded_in_features = find_multiple(in_features, 1024)
                # Zero-pad the input dimension on the right.
                weight = F.pad(weight, pad=(0, padded_in_features - in_features))

            weight_int4pack, scales_and_zeros = (
                _int4_prepare_int4_weight_and_scales_and_zeros(
                    weight.to(torch.float), self.groupsize, self.inner_k_tiles
                )
            )
            cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(device=self.device)
            cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(
                device=self.device
            )

        return cur_state_dict

    def convert_for_runtime(self):
        """Swap the linear layers for their int4 counterparts."""
        replace_linear_int4(
            self.mod,
            self.device,
            self.groupsize,
            self.inner_k_tiles,
            self.padding_allowed,
        )
        return self.mod

    def quantized_model(self) -> nn.Module:
        """Quantize ``self.mod`` in place and return it."""
        model_updated_state_dict = self.create_quantized_state_dict()
        self.convert_for_runtime()
        self.mod.load_state_dict(model_updated_state_dict)
        return self.mod
class WeightOnlyInt4Linear(torch.nn.Module):
    """Linear layer backed by an int4-packed weight and packed scales/zeros.

    When ``in_features`` does not pass ``_int4_check_linear_int4_k`` the layer
    works on a padded input width (``find_multiple(in_features, 1024)``) and
    remembers the original width in ``origin_in_features``.

    Buffers:
        weight: int32 of shape ``(out_features // 8,
            in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2)``.
        scales_and_zeros: shape ``(in_features // groupsize, out_features, 2)``
            in the dtype returned by ``get_precision()``.

    ``bias`` must be passed as a false value; ``dtype`` is accepted but not
    used by the constructor.
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(
        self,
        device: str,
        in_features: int,
        out_features: int,
        bias=True,
        dtype=None,
        groupsize: int = 128,
        inner_k_tiles: int = 8,
    ) -> None:
        super().__init__()
        fits = _int4_check_linear_int4_k(in_features, groupsize, inner_k_tiles)
        self.padding = not fits
        if self.padding:
            from build.model import find_multiple

            self.origin_in_features = in_features
            in_features = find_multiple(in_features, 1024)

        self.in_features = in_features
        self.out_features = out_features
        assert not bias, "require bias=False"
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles

        assert out_features % 8 == 0, "require out_features % 8 == 0"
        assert (
            in_features % (inner_k_tiles * 16) == 0
        ), "require in_features % (innerKTiles * 16) == 0"

        packed_shape = (
            out_features // 8,
            in_features // (inner_k_tiles * 16),
            32,
            inner_k_tiles // 2,
        )
        self.register_buffer(
            "weight", torch.empty(packed_shape, dtype=torch.int32, device=device)
        )

        qparams_shape = (in_features // groupsize, out_features, 2)
        self.register_buffer(
            "scales_and_zeros",
            torch.empty(qparams_shape, dtype=get_precision(), device=device),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Pad the input if the layer was padded, then run the int4 matmul."""
        if self.padding:
            extra = self.in_features - self.origin_in_features
            input = F.pad(input, pad=(0, extra))
        return linear_forward_int4(
            input, self.weight, self.scales_and_zeros, self.out_features, self.groupsize
        )
#########################################################################
##### GPTQ #####
def _check_linear_int4_k(k, groupsize=1):
return k % groupsize == 0
class GPTQQuantHandler(QuantHandler):
"""
This class implements a GPTQ QuantHandler that can be used to apply GPTQ to a model in concert with the GenericGPTQRunner class.
Unlike the base QuantHandler class, the user does not need to implement the create_quantized_state_dict, instead they have to reimplement
__init__ such that it defines the functions for the quantization mode. User is expected to reimplement convert_for_runtime.
The following functions (which must be defined in __init__) are used to define the quantization mode for both GPTQ and
create_quantized_state_dict. Here is a description of each function.
get_qparams_func:
A function that calculates the quantization qparams for an input tensor.
Args:
weight: A 2d weight tensor with non-integer dtype.
Returns:
qparams: it can have any format but will need to be handled by the other defined functions below.
quantize_func:
A function that applies quantization to an input tensor. It should be noted
that this function needs to be able to handle quantizing the entire weight tensor, a single group,
or a single column.
Args:
weight: A 2d weight tensor with non-integer dtype.
qparams: the output from get_qparams_func
Returns:
quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
dequantize_func:
A function that dequantizes an input quantized weight tensor. It should be noted
that this function needs to be able to handle dequantizing the entire weight tensor, a single group,
or a single column.
Args:
quantized_weight: A 2d quantized weight tensor (generally with an integer dtype)
qparams: the output from get_qparams_func
Returns:
weight: A 2d weight tensor with non-integer dtype.