forked from pytorch/FBGEMM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathQuantUtilsAvx2.h
157 lines (136 loc) · 3.93 KB
/
QuantUtilsAvx2.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#pragma once
#include <cstddef>
#include <cstdint>
#include "./FbgemmBuild.h"
#include "./UtilsAvx2.h"
namespace fbgemm {
// Structs from gemmlowp
//
// A structure to hold quantization parameters 'scale' and 'zero_point'.
// The meaning of these values is as the constants in the quantization equation
//
// real_value = scale * (quantized_value - zero_point)
//
// In other words, 'zero_point' is the quantized value that corresponds
// to the real value 0, and 'scale' is the difference of real values
// corresponding to consecutive quantized values.
struct FBGEMM_API TensorQuantizationParams {
float scale;
std::int32_t zero_point;
int precision;
float Min() const;
float Max() const;
};
// Parameters when we scale from int32 intermediate matrix multiplication
// results to 8-bit integers
struct FBGEMM_API RequantizationParams {
// For floating-point requantization
float real_multiplier;
// For fixed-point requantization
std::int32_t multiplier;
int right_shift;
TensorQuantizationParams target_qparams;
};
////////////////////////////////////////////////////////////////////////////////
// Utility functions
template <typename T = std::uint8_t, bool LEGACY = true>
void QuantizeAvx2(
const float* src,
T* dst,
int len,
const TensorQuantizationParams& qparams);
template <typename T = std::uint8_t>
void FusedQuantizeDequantizeAvx2(
const float* src,
float* dst,
int len,
const TensorQuantizationParams& qparams,
float noise_ratio = 0.0f);
/*
* Random number generator in [0, 9]: https://www.jstatsoft.org/v08/i14/paper
*/
uint32_t FBGEMM_API Xor128(void);
/**
* @brief Find the min and max value in a float matrix.
*/
void FBGEMM_API FindMinMax(const float* m, float* min, float* max, int len);
void RequantizeFixedPointAvx2(
const std::int32_t* src,
std::uint8_t* dst,
int len,
const RequantizationParams& params);
void RequantizeAvx2(
const std::int32_t* src,
std::uint8_t* dst,
int len,
const RequantizationParams& params);
/**
* @brief Requantize with avx2 and bias is fused.
*/
template <
bool A_SYMMETRIC,
bool B_SYMMETRIC,
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
bool FUSE_RELU,
typename BIAS_TYPE = std::int32_t,
bool DIRECT = false>
FBGEMM_API void requantizeOutputProcessingAvx2(
std::uint8_t* out,
const std::int32_t* inp,
const block_type_t& block,
int ld_out,
int ld_in,
const requantizationParams_t<BIAS_TYPE>& r);
template <
bool A_SYMMETRIC,
bool B_SYMMETRIC,
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
bool FUSE_RELU,
int C_PER_G,
typename BIAS_TYPE = std::int32_t>
FBGEMM_API void requantizeOutputProcessingGConvAvx2(
std::uint8_t* out,
const std::int32_t* inp,
const block_type_t& block,
int ld_out,
int ld_in,
const requantizationParams_t<BIAS_TYPE>& r);
template <
bool A_SYMMETRIC,
bool B_SYMMETRIC,
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
bool FUSE_RELU>
FBGEMM_API void requantizeForFloatAvx2(
float* out,
const std::int32_t* inp,
const block_type_t& block,
int ld_out,
int ld_in,
const requantizationForFloatParams_t& r);
template <typename InputType, int BIT_RATE>
void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2(
const InputType* input,
size_t input_rows,
int input_columns,
std::uint8_t* output);
template <typename InputType>
void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2(
const InputType* input,
size_t input_rows,
int input_columns,
std::uint8_t* output);
template <typename OutputType, int BIT_RATE>
void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2(
const std::uint8_t* input,
size_t input_rows,
int input_columns,
OutputType* output);
template <typename OutputType>
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(
const std::uint8_t* input,
size_t input_rows,
int input_columns,
OutputType* output);
} // namespace fbgemm