forked from pytorch/FBGEMM
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathTypes.h
265 lines (227 loc) · 7.69 KB
/
Types.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <cstdint>
#include <cstdlib>
#include <cstring>
#ifndef __is_identifier
#define __is_identifier(x) 1
#endif
#define __has_keyword(__x) !(__is_identifier(__x))
// TODO: we're disabling native fp16 on Windows to workaround test failures
// due to "undefined symbol __gnu_h2f_ieee" error. We should follup on this
// later.
#if __has_keyword(__fp16) && !defined(_WIN32)
#define HAS_NATIVE_FP16_TYPE
typedef __fp16 native_fp16_t;
#elif __has_keyword(_Float16) && !defined(_WIN32)
#define HAS_NATIVE_FP16_TYPE
typedef _Float16 native_fp16_t;
#else
typedef void native_fp16_t;
#endif
namespace fbgemm {
using float16 = std::uint16_t;
using bfloat16 = std::uint16_t;
// The IEEE754 standard species a binary16 as having the following format:
// SEEEEEMMMMMMMMMM
// 0432109876543210
// That is:
// * 1 sign bit
// * 5 exponent bits
// * 10 mantissa/significand bits (an 11th bit is implicit)
constexpr uint32_t f16_num_bits = 16;
constexpr uint32_t f16_num_exponent_bits = 5;
constexpr uint32_t f16_num_mantissa_bits = 10;
constexpr uint32_t f16_num_non_sign_bits =
f16_num_exponent_bits + f16_num_mantissa_bits;
constexpr uint32_t f16_exponent_mask = 0b1'1111; // 5 bits
constexpr uint32_t f16_sign_bit = 1u
<< (f16_num_exponent_bits + f16_num_mantissa_bits);
constexpr uint32_t f16_exponent_bits = f16_exponent_mask
<< f16_num_mantissa_bits;
constexpr uint32_t f16_mantissa_mask = 0b11'1111'1111; // 10 bits
constexpr uint32_t f16_exponent_bias = 15;
constexpr uint32_t f16_nan = 0x7F'FF;
// The IEEE754 standard specifies a binary32 as having:
// SEEEEEEEEMMMMMMMMMMMMMMMMMMMMMMM
// That is:
// * 1 sign bit
// * 8 exponent bits
// * 23 mantissa/significand bits (a 24th bit is implicit)
constexpr uint32_t f32_num_exponent_bits = 8;
constexpr uint32_t f32_num_mantissa_bits = 23;
constexpr uint32_t f32_exponent_mask = 0b1111'1111; // 8 bits
constexpr uint32_t f32_mantissa_mask = 0x7F'FF'FF; // 23 bits
constexpr uint32_t f32_exponent_bias = 127;
constexpr uint32_t f32_all_non_sign_mask = 0x7F'FF'FF'FF; // 31 bits
constexpr uint32_t f32_most_significant_bit = 1u << 22; // Turn on 23rd bit
constexpr uint32_t f32_num_non_sign_bits =
f32_num_exponent_bits + f32_num_mantissa_bits;
// Round to nearest even
static inline float16 cpu_float2half_rn(float f) {
static_assert(
sizeof(uint32_t) == sizeof(float),
"Programming error sizeof(uint32_t) != sizeof(float)");
uint32_t* xp = reinterpret_cast<uint32_t*>(&f);
uint32_t x = *xp;
uint32_t u = (x & f32_all_non_sign_mask);
// Get rid of +NaN/-NaN case first.
if (u > 0x7f800000) {
return static_cast<float16>(f16_nan);
}
uint32_t sign = ((x >> f16_num_bits) & f16_sign_bit);
// Get rid of +Inf/-Inf, +0/-0.
if (u > 0x477fefff) {
return static_cast<float16>(sign | f16_exponent_bits);
}
if (u < 0x33000001) {
return static_cast<float16>(sign | 0x0000);
}
uint32_t exponent = ((u >> f32_num_mantissa_bits) & f32_exponent_mask);
uint32_t mantissa = (u & f32_mantissa_mask);
uint32_t shift;
if (exponent > f32_exponent_bias - f16_exponent_bias) {
shift = f32_num_mantissa_bits - f16_num_mantissa_bits;
exponent -= f32_exponent_bias - f16_exponent_bias;
} else {
shift = (f32_exponent_bias - 1) - exponent;
exponent = 0;
mantissa |=
(1u
<< f32_num_mantissa_bits); // Bump the least significant exponent bit
}
const uint32_t lsb = (1u << shift);
const uint32_t lsb_s1 = (lsb >> 1);
const uint32_t lsb_m1 = (lsb - 1);
// Round to nearest even.
const uint32_t remainder = (mantissa & lsb_m1);
mantissa >>= shift;
if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
++mantissa;
if (!(mantissa & f16_mantissa_mask)) {
++exponent;
mantissa = 0;
}
}
return static_cast<float16>(
sign | (exponent << f16_num_mantissa_bits) | mantissa);
}
// Round to zero
static inline float16 cpu_float2half_rz(float f) {
static_assert(
sizeof(uint32_t) == sizeof(float),
"Programming error sizeof(uint32_t) != sizeof(float)");
const uint32_t* xp = reinterpret_cast<uint32_t*>(&f);
const uint32_t x = *xp;
const uint32_t u = (x & f32_all_non_sign_mask);
// Get rid of +NaN/-NaN case first.
if (u > 0x7f800000) {
return static_cast<float16>(f16_nan);
}
uint32_t sign = ((x >> f16_num_bits) & f16_sign_bit);
// Get rid of +Inf/-Inf, +0/-0.
if (u > 0x477fefff) {
return static_cast<float16>(sign | f16_exponent_bits);
}
if (u < 0x33000001) {
return static_cast<float16>(sign | 0x0000);
}
uint32_t exponent = ((u >> f32_num_mantissa_bits) & f32_exponent_mask);
uint32_t mantissa = (u & f32_mantissa_mask);
uint32_t shift;
if (exponent > f32_exponent_bias - f16_exponent_bias) {
shift = f32_num_mantissa_bits - f16_num_mantissa_bits;
exponent -= f32_exponent_bias - f16_exponent_bias;
} else {
shift = (f32_exponent_bias - 1) - exponent;
exponent = 0;
mantissa |=
(1u
<< f32_num_mantissa_bits); // Bump the least significant exponent bit
}
// Round to zero.
mantissa >>= shift;
return static_cast<float16>(
sign | (exponent << f16_num_mantissa_bits) | mantissa);
}
// Converts a 16-bit unsigned integer representation of a IEEE754 half-precision
// float into an IEEE754 32-bit single-precision float
inline float cpu_half2float_ref(const float16 h) {
// Get sign and exponent alone by themselves
uint32_t sign_bit = (h >> f16_num_non_sign_bits) & 1;
uint32_t exponent = (h >> f16_num_mantissa_bits) & f16_exponent_mask;
// Shift mantissa so that it fills the most significant bits of a float32
uint32_t mantissa = (h & f16_mantissa_mask)
<< (f32_num_mantissa_bits - f16_num_mantissa_bits);
if (exponent == f16_exponent_mask) { // NaN or Inf
if (mantissa) {
mantissa = f32_mantissa_mask;
sign_bit = 0;
}
exponent = f32_exponent_mask;
} else if (!exponent) { // Denorm or Zero
if (mantissa) {
uint32_t msb;
exponent = f32_exponent_bias - f16_exponent_bias + 1;
do {
msb = mantissa & f32_most_significant_bit;
mantissa <<= 1; // normalize
--exponent;
} while (!msb);
mantissa &= f32_mantissa_mask; // 1.mantissa is implicit
}
} else {
exponent += f32_exponent_bias - f16_exponent_bias;
}
const uint32_t i = (sign_bit << f32_num_non_sign_bits) |
(exponent << f32_num_mantissa_bits) | mantissa;
float ret;
std::memcpy(&ret, &i, sizeof(float));
return ret;
}
// Same as the previous function, but use the built-in fp16 to fp32
// conversion provided by the compiler
inline float cpu_half2float(const float16 h) {
#ifdef HAS_NATIVE_FP16_TYPE
__fp16 h_fp16;
std::memcpy(&h_fp16, &h, sizeof(__fp16));
return h_fp16;
#else
return cpu_half2float_ref(h);
#endif
}
inline float16 cpu_float2half(const float f) {
#ifdef HAS_NATIVE_FP16_TYPE
__fp16 h = f;
float16 res;
std::memcpy(&res, &h, sizeof(__fp16));
return res;
#else
return cpu_float2half_rn(f);
#endif
}
static inline float cpu_bf162float(bfloat16 src) {
float ret;
uint32_t val_fp32 =
static_cast<uint32_t>(reinterpret_cast<const uint16_t*>(&src)[0]) << 16;
memcpy(&ret, &val_fp32, sizeof(float));
return ret;
}
static inline bfloat16 cpu_float2bfloat16(float src) {
uint32_t temp;
memcpy(&temp, &src, sizeof(uint32_t));
return (temp + (1u << 15)) >> 16;
}
inline int64_t round_up(int64_t val, int64_t unit) {
return (val + unit - 1) / unit * unit;
}
inline int64_t div_up(int64_t val, int64_t unit) {
return (val + unit - 1) / unit;
}
} // namespace fbgemm