adaptive_rounding.py
import torch
from torch import nn
import logging

from qdiff.quant_layer import UniformAffineQuantizer, round_ste, floor_ste

logger = logging.getLogger(__name__)


class AdaRoundQuantizer(nn.Module):
    """
    Adaptive Rounding Quantizer, used to optimize the rounding policy
    by reconstructing the intermediate output.

    Based on
    Up or Down? Adaptive Rounding for Post-Training Quantization: https://arxiv.org/abs/2004.10568

    :param uaq: UniformAffineQuantizer, used to initialize quantization parameters in this quantizer
    :param round_mode: controls the forward pass in this quantizer
    :param weight_tensor: initialize alpha
    """
    def __init__(self, uaq: UniformAffineQuantizer, weight_tensor: torch.Tensor, round_mode='learned_round_sigmoid'):
        super(AdaRoundQuantizer, self).__init__()
        # copying all attributes from UniformAffineQuantizer
        self.n_bits = uaq.n_bits
        self.sym = uaq.sym
        self.delta = nn.Parameter(uaq.delta)
        self.zero_point = nn.Parameter(uaq.zero_point)
        self.n_levels = uaq.n_levels

        self.round_mode = round_mode
        self.alpha = None
        self.soft_targets = False

        # params for sigmoid function
        self.gamma, self.zeta = -0.1, 1.1
        self.beta = 2 / 3
        self.init_alpha(x=weight_tensor.clone())
    def forward(self, x):
        if self.round_mode == 'nearest':
            x_int = round_ste(x / self.delta)
        elif self.round_mode == 'nearest_ste':
            x_int = round_ste(x / self.delta)
        elif self.round_mode == 'stochastic':
            # x_floor = torch.floor(x / self.delta)
            x_floor = floor_ste(x / self.delta)
            rest = (x / self.delta) - x_floor  # rest of rounding
            x_int = x_floor + torch.bernoulli(rest)
            logger.info('Draw stochastic sample')
        elif self.round_mode == 'learned_hard_sigmoid':
            # x_floor = torch.floor(x / self.delta)
            x_floor = floor_ste(x / self.delta)
            if self.soft_targets:
                x_int = x_floor + self.get_soft_targets()
            else:
                x_int = x_floor + (self.alpha >= 0).float()
        else:
            raise ValueError('Wrong rounding mode')

        x_quant = torch.clamp(x_int + self.zero_point, 0, self.n_levels - 1)
        x_float_q = (x_quant - self.zero_point) * self.delta

        return x_float_q
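    # Soft rounding target from the AdaRound paper: a "rectified sigmoid"
    # h(alpha) = clip(sigmoid(alpha) * (zeta - gamma) + gamma, 0, 1), which
    # stretches the sigmoid slightly past [0, 1] so the reconstruction loss
    # can drive each entry all the way to a hard 0 or 1.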
    def get_soft_targets(self):
        return torch.clamp(torch.sigmoid(self.alpha) * (self.zeta - self.gamma) + self.gamma, 0, 1)
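    # alpha is initialized by inverting the rectified sigmoid so that the soft
    # target starts out equal to the floating-point rounding residual, i.e.
    # sigmoid(alpha) * (zeta - gamma) + gamma == rest before optimization.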
    def init_alpha(self, x: torch.Tensor):
        x_floor = torch.floor(x / self.delta)
        if self.round_mode == 'learned_hard_sigmoid':
            # logger.info('Init alpha to be FP32')
            rest = (x / self.delta) - x_floor  # rest of rounding [0, 1)
            alpha = -torch.log((self.zeta - self.gamma) / (rest - self.gamma) - 1)  # => sigmoid(alpha) = rest
            self.alpha = nn.Parameter(alpha)
        else:
            raise NotImplementedError
    def extra_repr(self):
        s = 'bit={n_bits}, symmetric={sym}, round_mode={round_mode}'
        return s.format(**self.__dict__)
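
# --- Usage sketch (not part of the original file) ---------------------------
# A minimal sketch of how this quantizer is typically wired up, assuming a
# BRECQ-style interface for UniformAffineQuantizer from qdiff.quant_layer
# (n_bits / symmetric / channel_wise keyword arguments, with delta and
# zero_point populated on its first forward call). The constructor arguments,
# learning rate, loss weighting, and regularization exponent below are
# illustrative assumptions, not values taken from this repository.
if __name__ == '__main__':
    weight = torch.randn(64, 64)

    # Calibrate step size and zero point on the weight tensor first
    # (assumed to happen on the first forward pass of the quantizer).
    uaq = UniformAffineQuantizer(n_bits=4, symmetric=False, channel_wise=False)
    uaq(weight)

    quantizer = AdaRoundQuantizer(uaq, weight_tensor=weight, round_mode='learned_hard_sigmoid')

    # During reconstruction, soft targets are enabled and alpha is optimized
    # against a reconstruction loss plus a rounding regularizer that pushes
    # h(alpha) towards 0 or 1 (see the AdaRound paper).
    quantizer.soft_targets = True
    optimizer = torch.optim.Adam([quantizer.alpha], lr=1e-3)

    for _ in range(10):
        w_q = quantizer(weight)
        b = 2.0  # regularization exponent; the paper anneals this during optimization
        round_loss = (1 - (2 * quantizer.get_soft_targets() - 1).abs().pow(b)).sum()
        loss = torch.nn.functional.mse_loss(w_q, weight) + 1e-2 * round_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # At inference time, fall back to hard rounding.
    quantizer.soft_targets = False
    w_q_final = quantizer(weight)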