# fake_quant.py (forked from mit-han-lab/smoothquant)
import torch
from torch import nn
from functools import partial


@torch.no_grad()
def quantize_weight_per_channel_absmax(w, n_bits=8):
    # w: (out_features, in_features)
    # Symmetric per-output-channel fake quantization: each row is scaled by its
    # own absolute maximum, rounded to the n_bits integer grid, then rescaled.
    scales = w.abs().max(dim=-1, keepdim=True)[0]
    q_max = 2 ** (n_bits - 1) - 1
    scales.clamp_(min=1e-5).div_(q_max)
    w.div_(scales).round_().mul_(scales)
    return w


@torch.no_grad()
def quantize_weight_per_tensor_absmax(w, n_bits=8):
    # w: (out_features, in_features)
    # Same as above, but with a single scale shared by the whole weight tensor.
    scales = w.abs().max()
    q_max = 2 ** (n_bits - 1) - 1
    scales.clamp_(min=1e-5).div_(q_max)
    w.div_(scales).round_().mul_(scales)
    return w


@torch.no_grad()
def quantize_activation_per_token_absmax(t, n_bits=8):
    t_shape = t.shape
    # Note: the result of this view is unused; the per-token scales below are
    # taken along the last (hidden) dimension of t directly.
    t.view(-1, t_shape[-1])
    scales = t.abs().max(dim=-1, keepdim=True)[0]
    q_max = 2 ** (n_bits - 1) - 1
    scales.clamp_(min=1e-5).div_(q_max)
    t.div_(scales).round_().mul_(scales)
    return t


@torch.no_grad()
def quantize_activation_per_tensor_absmax(t, n_bits=8):
    t_shape = t.shape
    t.view(-1, t_shape[-1])
    scales = t.abs().max()
    q_max = 2 ** (n_bits - 1) - 1
    scales.clamp_(min=1e-5).div_(q_max)
    t.div_(scales).round_().mul_(scales)
    return t
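

# All four helpers above perform symmetric absmax fake quantization. With
# n_bits=8 the integer grid is [-127, 127], so scale = absmax / 127 and each
# entry becomes round(x / scale) * scale. For example, if a weight row has
# absolute maximum 0.5, its scale is 0.5 / 127, and the entry 0.25 is snapped
# to 64 * (0.5 / 127) ≈ 0.2520.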


class W8A8Linear(nn.Module):
    def __init__(self, in_features, out_features, bias=True,
                 act_quant='per_token', quantize_output=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Weight and bias are plain buffers rather than Parameters: the module
        # is inference-only, and from_float overwrites them with the quantized
        # values of the original floating-point layer.
        self.register_buffer('weight', torch.randn(
            self.out_features, self.in_features,
            dtype=torch.float16, requires_grad=False))
        if bias:
            self.register_buffer('bias', torch.zeros(
                (1, self.out_features), dtype=torch.float16, requires_grad=False))
        else:
            self.register_buffer('bias', None)

        if act_quant == 'per_token':
            self.act_quant_name = 'per_token'
            self.act_quant = partial(
                quantize_activation_per_token_absmax, n_bits=8)
        elif act_quant == 'per_tensor':
            self.act_quant_name = 'per_tensor'
            self.act_quant = partial(
                quantize_activation_per_tensor_absmax, n_bits=8)
        else:
            raise ValueError(f'Invalid act_quant: {act_quant}')

        if quantize_output:
            # Optionally fake-quantize the layer output as well.
            self.output_quant_name = self.act_quant_name
            self.output_quant = self.act_quant
        else:
            self.output_quant_name = 'None'
            self.output_quant = lambda x: x

    def to(self, *args, **kwargs):
        super(W8A8Linear, self).to(*args, **kwargs)
        self.weight = self.weight.to(*args, **kwargs)
        if self.bias is not None:
            self.bias = self.bias.to(*args, **kwargs)
        return self

    @torch.no_grad()
    def forward(self, x):
        q_x = self.act_quant(x)
        y = torch.nn.functional.linear(q_x, self.weight, self.bias)
        q_y = self.output_quant(y)
        return q_y

    @staticmethod
    def from_float(module, weight_quant='per_channel',
                   act_quant='per_token', quantize_output=False):
        assert isinstance(module, torch.nn.Linear)
        new_module = W8A8Linear(
            module.in_features, module.out_features, module.bias is not None,
            act_quant=act_quant, quantize_output=quantize_output)
        if weight_quant == 'per_channel':
            new_module.weight = quantize_weight_per_channel_absmax(
                module.weight, n_bits=8)  # use 8-bit integer for weight
        elif weight_quant == 'per_tensor':
            new_module.weight = quantize_weight_per_tensor_absmax(
                module.weight, n_bits=8)
        else:
            raise ValueError(f'Invalid weight_quant: {weight_quant}')
        new_module.weight_quant_name = weight_quant
        if module.bias is not None:
            new_module.bias = module.bias
        return new_module

    def __repr__(self):
        return (f'W8A8Linear({self.in_features}, {self.out_features}, '
                f'bias={self.bias is not None}, '
                f'weight_quant={self.weight_quant_name}, '
                f'act_quant={self.act_quant_name}, '
                f'output_quant={self.output_quant_name})')
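

# Minimal usage sketch, kept here only for illustration; the layer sizes below
# are made up and not part of the original file. In SmoothQuant this class is
# swapped in for a model's fp16 nn.Linear layers via from_float; a small fp32
# layer is used here so the sketch also runs on CPU.
if __name__ == '__main__':
    fp_linear = nn.Linear(16, 32)
    q_linear = W8A8Linear.from_float(
        fp_linear, weight_quant='per_channel', act_quant='per_token')
    x = torch.randn(4, 16)
    y = q_linear(x)
    print(q_linear)
    print(y.shape)  # torch.Size([4, 32])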