#pragma once

#include "./FbgemmBuild.h"
#include "./QuantUtilsAvx2.h"
#include "./Types.h"
#include "./Utils.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <limits>
#include <type_traits>

namespace fbgemm {

FBGEMM_API TensorQuantizationParams ChooseQuantizationParams(
    float min,
    float max,
    std::int32_t qmin,
    std::int32_t qmax,
    bool preserve_sparsity = false,
    bool force_scale_power_of_two = false);

FBGEMM_API void ChooseRequantizationMultiplier(
    float real_multiplier,
    std::int32_t* quantized_multiplier,
    int* right_shift,
    int requantization_multiplier_precision = 32);
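
// Example (illustrative sketch, not part of the API): choosing quantization
// parameters for an fp32 tensor whose observed range is [-1.0f, 1.0f],
// targeting unsigned 8-bit values in [0, 255]. The resulting qparams describe
// the affine mapping real_value ~= scale * (quantized_value - zero_point) and
// can be passed to the Quantize()/Dequantize() overloads declared below.
//
//   TensorQuantizationParams qparams =
//       ChooseQuantizationParams(/*min=*/-1.0f, /*max=*/1.0f,
//                                /*qmin=*/0, /*qmax=*/255);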

////////////////////////////////////////////////////////////////////////////////
// Utility functions

// Clamp src in T1 to the desired precision and convert it to T2
// TODO: T26263653 fix signed-integer-overflow undefined behavior
template <typename T1, typename T2 = std::uint8_t>
NO_SANITIZE("signed-integer-overflow")
T2 clamp(T1 src, int precision, bool is_signed = false) {
  std::int32_t min = is_signed ? -(1LL << (precision - 1)) : 0;
  std::int32_t max =
      is_signed ? ((1LL << (precision - 1)) - 1) : (1LL << precision) - 1;

  // Make sure T1 and T2 can represent the precision
  assert(min >= std::numeric_limits<T1>::lowest());
  assert(min >= std::numeric_limits<T2>::lowest());
  assert(max <= std::numeric_limits<T1>::max());
  assert(max <= std::numeric_limits<T2>::max());

  return std::min<T1>(std::max<T1>(src, min), max);
}

/// Quantize src using zero_point and scale, clamp to the specified precision,
/// and convert it to type T
template <typename T, bool LEGACY = true>
T Quantize(
    float src,
    std::int32_t zero_point,
    float scale,
    int result_precision,
    bool result_is_signed = std::is_signed<T>::value) {
  // Note: We want to multiply src with inv_scale instead of dividing src by
  // scale. The same is done in the vector code and at other places.
  //
  // Example:
  // With scale = 0.00214854861f, zero_point = 0 and src = 0.273939937f,
  // transformed_val is 127.5 for src * inv_scale while
  // transformed_val is 127.499992 for src / scale.
  // Eventually 127.5 gets rounded to 128 while 127.499992 gets rounded to 127.
  float inv_scale = 1.0f / scale;
  float transformed_val = src * inv_scale;
  // nearbyint here performs round-to-nearest-ties-to-even with the default
  // rounding mode. For example, nearbyint(1.4) is 1.0, nearbyint(1.5) is 2.0,
  // and nearbyint(2.5) is 2.0.
  // Adding zero_point before or after rounding can make a difference
  // in exactly halfway cases.
  if (LEGACY) {
    transformed_val = std::nearbyint(zero_point + transformed_val);
  } else {
    transformed_val = zero_point + std::nearbyint(transformed_val);
  }
  // Please note the use of double. Unlike float, a double can represent
  // all int32 values exactly. Using a float would let a value > INT32_MAX
  // reach the int32 conversion in the clamp function and hence trigger a
  // UBSAN error.
  return clamp<double, T>(transformed_val, result_precision, result_is_signed);
}

template <typename T, bool LEGACY = true>
T Quantize(float src, const TensorQuantizationParams& qparams) {
  return Quantize<T, LEGACY>(
      src, qparams.zero_point, qparams.scale, qparams.precision);
}

template <typename T, bool LEGACY = true>
FBGEMM_API void Quantize(
    const float* src,
    T* dst,
    int len,
    const TensorQuantizationParams& qparams,
    int thread_id = 0,
    int num_threads = 1);
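
// Example (illustrative sketch): quantizing a small buffer with precomputed
// qparams. The numeric values are made up for illustration; precision = 8
// clamps the results to the 8-bit unsigned range.
//
//   TensorQuantizationParams qparams;
//   qparams.scale = 0.05f;
//   qparams.zero_point = 128;
//   qparams.precision = 8;
//   std::vector<float> src = {-0.3f, 0.0f, 0.7f};
//   std::vector<std::uint8_t> dst(src.size());
//   Quantize<std::uint8_t>(
//       src.data(), dst.data(), static_cast<int>(src.size()), qparams);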

/*
 * @brief Quantize floating point data in src to type T
 *
 * @tparam T output quantized data type (int8_t, uint8_t, and int32_t are
 *           supported)
 *
 * @tparam LAYOUT layout of the input tensor in src (KCX and KXC are supported)
 *                KCX corresponds to KCRS or KCTRS (for weight tensors with
 *                a time dimension)
 *                KXC corresponds to KRSC or KTRSC (for weight tensors with
 *                a time dimension)
 *
 * @param K Output channels for weight tensors
 * @param C Number of channels
 * @param X R*S or T*R*S
 * @param G Groups (if G == C the function performs channelwise quantization;
 *          if 1 < G < C the function performs groupwise quantization;
 *          if G == 1 the function performs per-tensor quantization)
 * @param scales floating point scales.
 *               Size should be equal to G
 * @param zero_points zero points (should be representable in type T).
 *                    Size should be equal to G
 */
template <typename T, layout_t LAYOUT = layout_t::KCX>
FBGEMM_API void QuantizeGroupwise(
    const float* src,
    int K,
    int C,
    int X,
    int G,
    const float* scales,
    const std::int32_t* zero_points,
    T* dst);
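
// Example (illustrative sketch, hypothetical shapes): groupwise quantization of
// a KCX-layout weight tensor with K = 8 output channels, C = 4 channels,
// X = R*S = 9 spatial positions, and G = 2 groups, so each group uses its own
// scale/zero_point pair as described in the comment above.
//
//   std::vector<float> w(8 * 4 * 9);
//   std::vector<std::int8_t> wq(w.size());
//   float scales[2] = {0.02f, 0.03f};
//   std::int32_t zero_points[2] = {0, 0};
//   QuantizeGroupwise<std::int8_t, layout_t::KCX>(
//       w.data(), /*K=*/8, /*C=*/4, /*X=*/9, /*G=*/2,
//       scales, zero_points, wq.data());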

/// Dequantize a single quantized value back to fp32 using the affine mapping
/// scale * (src - zero_point).
template <typename T>
float Dequantize(T src, const TensorQuantizationParams& qparams) {
  return qparams.scale * (src - qparams.zero_point);
}

/// Dequantize len values from src into dst. The range is partitioned across
/// num_threads and this call handles the thread_id-th slice.
template <typename T>
void Dequantize(
    const T* src,
    float* dst,
    int len,
    const TensorQuantizationParams& qparams,
    int thread_id = 0,
    int num_threads = 1) {
  int i_begin, i_end;
  fbgemmPartition1D(thread_id, num_threads, len, i_begin, i_end);
  for (auto i = i_begin; i < i_end; i++) {
    dst[i] = Dequantize(src[i], qparams);
  }
}

template <typename T>
float FusedQuantizeDequantize(
    float src,
    const TensorQuantizationParams& qparams) {
  T q = Quantize<T, false>(
      src, qparams.zero_point, qparams.scale, qparams.precision);
  return Dequantize<T>(q, qparams);
}

/*
 * Fused integer quantization-dequantization kernel to accelerate
 * quantization-aware training. Quantizes fp32 values in src to (u)int8 using
 * the provided qparams, then dequantizes the quantized integer values back
 * into fp32.
 */
template <typename T>
FBGEMM_API void FusedQuantizeDequantize(
    const float* src,
    float* dst,
    int len,
    const TensorQuantizationParams& qparams,
    int thread_id = 0,
    int num_threads = 1,
    float noise_ratio = 0.0f);
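
// Example (illustrative sketch): simulating quantization error for
// quantization-aware training. The result is src snapped to the nearest value
// representable under qparams; the qparams values are made up.
//
//   TensorQuantizationParams qparams;
//   qparams.scale = 0.1f;
//   qparams.zero_point = 0;
//   qparams.precision = 8;
//   float fake_quantized = FusedQuantizeDequantize<std::int8_t>(0.234f, qparams);
//   // fake_quantized is approximately 0.2f (= 0.1f * nearbyint(0.234f / 0.1f)).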

////////////////////////////////////////////////////////////////////////////////
// Requantization (pure fixed-point)

FBGEMM_API std::int64_t
SaturatingRoundingMulWithShift(std::int32_t a, std::int32_t b, int right_shift);

template <typename T>
T Requantize(
    std::int32_t src, // int32 input before requantization
    std::int32_t zero_point,
    std::int32_t multiplier,
    int right_shift,
    int result_precision,
    bool result_is_signed = false) {
  std::int64_t quantized_down =
      zero_point + SaturatingRoundingMulWithShift(src, multiplier, right_shift);
  return clamp<std::int64_t, T>(
      quantized_down, result_precision, result_is_signed);
}

template <typename T>
T RequantizeFixedPoint(
    std::int32_t src, // int32 input before requantization
    const RequantizationParams& params) {
  return Requantize<T>(
      src,
      params.target_qparams.zero_point,
      params.multiplier,
      params.right_shift,
      params.target_qparams.precision);
}

template <typename T>
FBGEMM_API void RequantizeFixedPoint(
    const std::int32_t* src,
    T* dst,
    int len,
    const RequantizationParams& params,
    int thread_id = 0,
    int num_threads = 1);
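
// Example (illustrative sketch): converting a float requantization multiplier
// into a fixed-point (multiplier, right_shift) pair and applying it to an
// int32 accumulator. real_multiplier would typically be
// input_scale * weight_scale / output_scale; the value here is made up.
//
//   float real_multiplier = 0.0015f;
//   std::int32_t multiplier;
//   int right_shift;
//   ChooseRequantizationMultiplier(real_multiplier, &multiplier, &right_shift);
//   std::int32_t acc = 12345; // int32 accumulator, e.g. from a GEMM
//   std::uint8_t out = Requantize<std::uint8_t>(
//       acc, /*zero_point=*/128, multiplier, right_shift,
//       /*result_precision=*/8);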

////////////////////////////////////////////////////////////////////////////////
// Requantization (with floats)

template <typename T>
T Requantize(
    std::int32_t src, // int32 input before requantization
    std::int32_t zero_point,
    float multiplier,
    int result_precision,
    bool result_is_signed = false) {
  long quantized_down = zero_point + std::lrintf(src * multiplier);
  return clamp<long, T>(quantized_down, result_precision, result_is_signed);
}

template <typename T>
T Requantize(
    std::int32_t src, // int32 input before requantization
    const RequantizationParams& params) {
  return Requantize<T>(
      src,
      params.target_qparams.zero_point,
      params.real_multiplier,
      params.target_qparams.precision);
}

template <typename T>
FBGEMM_API void Requantize(
    const std::int32_t* src,
    T* dst,
    int len,
    const RequantizationParams& params,
    int thread_id = 0,
    int num_threads = 1);

/**
 * Convert float (fp32 or fp16) inputs to rowwise quantized outputs.
 * bit_rate specifies the number of bits in the quantized output.
 * Scale and bias are in fp16. Each row's scale and bias are stored at the end
 * of the row itself (fused).
 *
 * @param bit_rate can be 2, 4, or 8
 */
template <typename InputType>
FBGEMM_API void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf(
    int bit_rate,
    const InputType* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output);
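
// Example (illustrative sketch): 4-bit rowwise quantization of a 2 x 64 fp32
// matrix. The output row width below assumes the fused layout implied by the
// comment above, i.e. ceil(input_columns * bit_rate / 8) data bytes followed by
// an fp16 scale and an fp16 bias (2 bytes each); consult the implementation for
// the authoritative sizing.
//
//   const int bit_rate = 4;
//   const size_t rows = 2;
//   const int cols = 64;
//   std::vector<float> input(rows * cols);
//   const size_t out_row_bytes = (cols * bit_rate + 7) / 8 + 2 * 2;
//   std::vector<std::uint8_t> output(rows * out_row_bytes);
//   FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
//       bit_rate, input.data(), rows, cols, output.data());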

/**
 * Convert fused rowwise quantized inputs to float (fp32 or fp16).
 * bit_rate specifies the number of bits in the quantized input.
 * Scale and bias are in fp16. Each row's scale and bias are stored at the end
 * of the row itself (fused).
 *
 * @param bit_rate can be 2, 4, or 8
 */
template <typename OutputType>
FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
    int bit_rate,
    const uint8_t* input,
    size_t input_rows,
    int input_columns,
    OutputType* output);

/**
 * Convert float or half inputs to rowwise quantized (8-bit) outputs.
 * Scale and bias are in float. Each row's scale and bias are stored at the end
 * of the row itself (fused).
 *
 * This version intentionally supports only 8-bit because we want to discourage
 * the use of float scale and bias with the 2- and 4-bit cases, as that
 * diminishes the overall memory savings.
 */
template <typename InputType>
FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat(
    const InputType* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output);

/**
 * Convert fused rowwise quantized (8-bit) inputs to float or half outputs.
 * Scale and bias are in float. Each row's scale and bias are stored at the end
 * of the row itself (fused).
 *
 * This version intentionally supports only 8-bit because
 * the corresponding quantize version only supports 8-bit.
 */
template <typename OutputType>
FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
    const uint8_t* input,
    size_t input_rows,
    int input_columns,
    OutputType* output);
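
// Example (illustrative sketch): 8-bit fused rowwise quantize/dequantize round
// trip. The sizing below assumes the layout implied by the comments above
// (input_columns data bytes per row followed by a float scale and a float
// bias) and that the dequantize call takes the fused width, in bytes, as its
// column argument; consult the implementation for the authoritative layout.
//
//   const size_t rows = 4;
//   const int cols = 16;
//   std::vector<float> src(rows * cols, 0.5f);
//   const int fused_cols = cols + 2 * static_cast<int>(sizeof(float));
//   std::vector<std::uint8_t> fused(rows * fused_cols);
//   std::vector<float> roundtrip(rows * cols);
//   FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<float>(
//       src.data(), rows, cols, fused.data());
//   Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<float>(
//       fused.data(), rows, fused_cols, roundtrip.data());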

/**
 * Same as FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf but unoptimized.
 * This should not be called directly except in testing.
 */
template <typename InputType>
FBGEMM_API void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef(
    int bit_rate,
    const InputType* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output);

/**
 * Same as FloatOrHalfToFused8BitRowwiseQuantizedSBFloat but unoptimized.
 * This should not be called directly except in testing.
 */
template <typename InputType>
FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef(
    const InputType* input,
    size_t input_rows,
    int input_columns,
    std::uint8_t* output);

/**
 * Same as FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf but unoptimized.
 * This should not be called directly except in testing.
 */
template <typename OutputType>
FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
    int bit_rate,
    const uint8_t* input,
    size_t input_rows,
    int input_columns,
    OutputType* output);

/**
 * Same as Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf but unoptimized.
 * This should not be called directly except in testing.
 */
template <typename OutputType>
FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef(
    const uint8_t* input,
    size_t input_rows,
    int input_columns,
    OutputType* output);

} // namespace fbgemm