Skip to content

Commit

Permalink
Make all fp16 routines work with fp32 as input instead of fp64, since…
Browse files Browse the repository at this point in the history
… that is what hardware supports anyway.
  • Loading branch information
sesse committed Jan 16, 2016
1 parent 0830ff0 commit 35ab975
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 201 deletions.
4 changes: 2 additions & 2 deletions fft_input.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ void FFTInput::set_gl_state(GLuint glsl_program_num, const string& prefix, unsig
// Convert to fp16.
fp16_int_t *kernel = new fp16_int_t[fft_width * fft_height * 2];
for (int i = 0; i < fft_width * fft_height; ++i) {
kernel[i * 2 + 0] = fp64_to_fp16(out[i][0]);
kernel[i * 2 + 1] = fp64_to_fp16(out[i][1]);
kernel[i * 2 + 0] = fp32_to_fp16(out[i][0]);
kernel[i * 2 + 1] = fp32_to_fp16(out[i][1]);
}

// (Re-)upload the texture.
Expand Down
8 changes: 4 additions & 4 deletions fft_pass_effect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,10 @@ void FFTPassEffect::generate_support_texture()
support_texture_index = subfft_size - support_texture_index - 1;
sign = -1.0;
}
tmp[support_texture_index * 4 + 0] = fp64_to_fp16(sign * (src1 - i * stride) / double(input_size));
tmp[support_texture_index * 4 + 1] = fp64_to_fp16(sign * (src2 - i * stride) / double(input_size));
tmp[support_texture_index * 4 + 2] = fp64_to_fp16(twiddle_real);
tmp[support_texture_index * 4 + 3] = fp64_to_fp16(twiddle_imag);
tmp[support_texture_index * 4 + 0] = fp32_to_fp16(sign * (src1 - i * stride) / double(input_size));
tmp[support_texture_index * 4 + 1] = fp32_to_fp16(sign * (src2 - i * stride) / double(input_size));
tmp[support_texture_index * 4 + 2] = fp32_to_fp16(twiddle_real);
tmp[support_texture_index * 4 + 3] = fp32_to_fp16(twiddle_imag);
}

// Supposedly FFTs are very sensitive to inaccuracies in the twiddle factors,
Expand Down
134 changes: 60 additions & 74 deletions fp16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,46 +3,46 @@
namespace movit {
namespace {

union fp64 {
double f;
unsigned long long ll;
union fp32 {
float f;
unsigned int u;
};

template<class FP16_INT_T,
int FP16_BIAS, int FP16_MANTISSA_BITS, int FP16_EXPONENT_BITS, int FP16_MAX_EXPONENT,
int FP64_BIAS, int FP64_MANTISSA_BITS, int FP64_EXPONENT_BITS, int FP64_MAX_EXPONENT>
inline double fp_upconvert(FP16_INT_T x)
int FP32_BIAS, int FP32_MANTISSA_BITS, int FP32_EXPONENT_BITS, int FP32_MAX_EXPONENT>
inline float fp_upconvert(FP16_INT_T x)
{
int sign = x.val >> (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS);
int exponent = (x.val & ((1ULL << (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS)) - 1)) >> FP16_MANTISSA_BITS;
unsigned long long mantissa = x.val & ((1ULL << FP16_MANTISSA_BITS) - 1);
int exponent = (x.val & ((1U << (FP16_MANTISSA_BITS + FP16_EXPONENT_BITS)) - 1)) >> FP16_MANTISSA_BITS;
unsigned int mantissa = x.val & ((1U << FP16_MANTISSA_BITS) - 1);

int sign64;
int exponent64;
unsigned long long mantissa64;
int sign32;
int exponent32;
unsigned int mantissa32;

if (exponent == 0) {
/*
* Denormals, or zero. Zero is still zero, denormals become
* ordinary numbers.
*/
if (mantissa == 0) {
sign64 = sign;
exponent64 = 0;
mantissa64 = 0;
sign32 = sign;
exponent32 = 0;
mantissa32 = 0;
} else {
sign64 = sign;
exponent64 = FP64_BIAS - FP16_BIAS;
mantissa64 = mantissa << (FP64_MANTISSA_BITS - FP16_MANTISSA_BITS + 1);
sign32 = sign;
exponent32 = FP32_BIAS - FP16_BIAS;
mantissa32 = mantissa << (FP32_MANTISSA_BITS - FP16_MANTISSA_BITS + 1);

/* Normalize the number. */
while ((mantissa64 & (1ULL << FP64_MANTISSA_BITS)) == 0) {
--exponent64;
mantissa64 <<= 1;
while ((mantissa32 & (1U << FP32_MANTISSA_BITS)) == 0) {
--exponent32;
mantissa32 <<= 1;
}

/* Clear the now-implicit one-bit. */
mantissa64 &= ~(1ULL << FP64_MANTISSA_BITS);
mantissa32 &= ~(1U << FP32_MANTISSA_BITS);
}
} else if (exponent == FP16_MAX_EXPONENT) {
/*
Expand All @@ -51,44 +51,44 @@ inline double fp_upconvert(FP16_INT_T x)
* keep the first bit (which signals signalling/non-signalling
* in many implementations).
*/
sign64 = sign;
exponent64 = FP64_MAX_EXPONENT;
mantissa64 = mantissa << (FP64_MANTISSA_BITS - FP16_MANTISSA_BITS);
sign32 = sign;
exponent32 = FP32_MAX_EXPONENT;
mantissa32 = mantissa << (FP32_MANTISSA_BITS - FP16_MANTISSA_BITS);
} else {
sign64 = sign;
sign32 = sign;

/* Up-conversion is simple. Just re-bias the exponent... */
exponent64 = exponent + FP64_BIAS - FP16_BIAS;
exponent32 = exponent + FP32_BIAS - FP16_BIAS;

/* ...and convert the mantissa. */
mantissa64 = mantissa << (FP64_MANTISSA_BITS - FP16_MANTISSA_BITS);
mantissa32 = mantissa << (FP32_MANTISSA_BITS - FP16_MANTISSA_BITS);
}

union fp64 nx;
nx.ll = ((unsigned long long)sign64 << (FP64_MANTISSA_BITS + FP64_EXPONENT_BITS))
| ((unsigned long long)exponent64 << FP64_MANTISSA_BITS)
| mantissa64;
union fp32 nx;
nx.u = ((unsigned int)sign32 << (FP32_MANTISSA_BITS + FP32_EXPONENT_BITS))
| ((unsigned int)exponent32 << FP32_MANTISSA_BITS)
| mantissa32;
return nx.f;
}
unsigned long long shift_right_with_round(unsigned long long x, unsigned shift)

unsigned int shift_right_with_round(unsigned int x, unsigned shift)
{
/* shifts >= 64 need to be special-cased */
if (shift > 64) {
/* shifts >= 32 need to be special-cased */
if (shift > 32) {
return 0;
} else if (shift == 64) {
if (x > (1ULL << 63)) {
} else if (shift == 32) {
if (x > (1U << 31)) {
return 1;
} else {
return 0;
}
}

unsigned long long round_part = x & ((1ULL << shift) - 1);
if (round_part < (1ULL << (shift - 1))) {
unsigned int round_part = x & ((1U << shift) - 1);
if (round_part < (1U << (shift - 1))) {
/* round down */
x >>= shift;
} else if (round_part > (1ULL << (shift - 1))) {
} else if (round_part > (1U << (shift - 1))) {
/* round up */
x >>= shift;
++x;
Expand All @@ -104,31 +104,31 @@ unsigned long long shift_right_with_round(unsigned long long x, unsigned shift)

template<class FP16_INT_T,
int FP16_BIAS, int FP16_MANTISSA_BITS, int FP16_EXPONENT_BITS, int FP16_MAX_EXPONENT,
int FP64_BIAS, int FP64_MANTISSA_BITS, int FP64_EXPONENT_BITS, int FP64_MAX_EXPONENT>
inline FP16_INT_T fp_downconvert(double x)
int FP32_BIAS, int FP32_MANTISSA_BITS, int FP32_EXPONENT_BITS, int FP32_MAX_EXPONENT>
inline FP16_INT_T fp_downconvert(float x)
{
union fp64 nx;
union fp32 nx;
nx.f = x;
unsigned long long f = nx.ll;
int sign = f >> (FP64_MANTISSA_BITS + FP64_EXPONENT_BITS);
int exponent = (f & ((1ULL << (FP64_MANTISSA_BITS + FP64_EXPONENT_BITS)) - 1)) >> FP64_MANTISSA_BITS;
unsigned long long mantissa = f & ((1ULL << FP64_MANTISSA_BITS) - 1);
unsigned int f = nx.u;
int sign = f >> (FP32_MANTISSA_BITS + FP32_EXPONENT_BITS);
int exponent = (f & ((1U << (FP32_MANTISSA_BITS + FP32_EXPONENT_BITS)) - 1)) >> FP32_MANTISSA_BITS;
unsigned int mantissa = f & ((1U << FP32_MANTISSA_BITS) - 1);

int sign16;
int exponent16;
unsigned long long mantissa16;
unsigned int mantissa16;

if (exponent == 0) {
/*
* Denormals, or zero. The largest possible 64-bit
* Denormals, or zero. The largest possible 32-bit
* denormal is about +- 2^-1022, and the smallest possible
* 16-bit denormal is +- 2^-24. Thus, we can safely
* just set all of these to zero (but keep the sign bit).
*/
sign16 = sign;
exponent16 = 0;
mantissa16 = 0;
} else if (exponent == FP64_MAX_EXPONENT) {
} else if (exponent == FP32_MAX_EXPONENT) {
/*
* Infinities or NaN (mantissa=0 => infinity, otherwise NaN).
* We don't care much about NaNs, so let us just keep the first
Expand All @@ -142,25 +142,25 @@ inline FP16_INT_T fp_downconvert(double x)
} else {
sign16 = sign; /* undefined */
exponent16 = FP16_MAX_EXPONENT;
mantissa16 = mantissa >> (FP64_MANTISSA_BITS - FP16_MANTISSA_BITS);
mantissa16 = mantissa >> (FP32_MANTISSA_BITS - FP16_MANTISSA_BITS);
if (mantissa16 == 0) {
mantissa16 = 1;
}
}
} else {
/* Re-bias the exponent, and check if we will create a denormal. */
exponent16 = exponent + FP16_BIAS - FP64_BIAS;
exponent16 = exponent + FP16_BIAS - FP32_BIAS;
if (exponent16 <= 0) {
int shift_amount = FP64_MANTISSA_BITS - FP16_MANTISSA_BITS - exponent16 + 1;
int shift_amount = FP32_MANTISSA_BITS - FP16_MANTISSA_BITS - exponent16 + 1;
sign16 = sign;
exponent16 = 0;
mantissa16 = shift_right_with_round(mantissa | (1ULL << FP64_MANTISSA_BITS), shift_amount);
mantissa16 = shift_right_with_round(mantissa | (1U << FP32_MANTISSA_BITS), shift_amount);

/*
* We could actually have rounded back into the lowest possible non-denormal
* here, so check for that.
*/
if (mantissa16 == (1ULL << FP16_MANTISSA_BITS)) {
if (mantissa16 == (1U << FP16_MANTISSA_BITS)) {
exponent16 = 1;
mantissa16 = 0;
}
Expand All @@ -171,10 +171,10 @@ inline FP16_INT_T fp_downconvert(double x)
* mode.
*/
sign16 = sign;
mantissa16 = shift_right_with_round(mantissa, FP64_MANTISSA_BITS - FP16_MANTISSA_BITS);
mantissa16 = shift_right_with_round(mantissa, FP32_MANTISSA_BITS - FP16_MANTISSA_BITS);

/* Check if we overflowed and need to increase the exponent. */
if (mantissa16 == (1ULL << FP16_MANTISSA_BITS)) {
if (mantissa16 == (1U << FP16_MANTISSA_BITS)) {
++exponent16;
mantissa16 = 0;
}
Expand Down Expand Up @@ -213,34 +213,20 @@ const int FP16_MAX_EXPONENT = (1 << FP16_EXPONENT_BITS) - 1;

#ifndef __F16C__

double fp16_to_fp64(fp16_int_t x)
float fp16_to_fp32(fp16_int_t x)
{
return fp_upconvert<fp16_int_t,
FP16_BIAS, FP16_MANTISSA_BITS, FP16_EXPONENT_BITS, FP16_MAX_EXPONENT,
FP64_BIAS, FP64_MANTISSA_BITS, FP64_EXPONENT_BITS, FP64_MAX_EXPONENT>(x);
FP32_BIAS, FP32_MANTISSA_BITS, FP32_EXPONENT_BITS, FP32_MAX_EXPONENT>(x);
}

fp16_int_t fp64_to_fp16(double x)
fp16_int_t fp32_to_fp16(float x)
{
return fp_downconvert<fp16_int_t,
FP16_BIAS, FP16_MANTISSA_BITS, FP16_EXPONENT_BITS, FP16_MAX_EXPONENT,
FP64_BIAS, FP64_MANTISSA_BITS, FP64_EXPONENT_BITS, FP64_MAX_EXPONENT>(x);
FP32_BIAS, FP32_MANTISSA_BITS, FP32_EXPONENT_BITS, FP32_MAX_EXPONENT>(x);
}

#endif

double fp32_to_fp64(fp32_int_t x)
{
return fp_upconvert<fp32_int_t,
FP32_BIAS, FP32_MANTISSA_BITS, FP32_EXPONENT_BITS, FP32_MAX_EXPONENT,
FP64_BIAS, FP64_MANTISSA_BITS, FP64_EXPONENT_BITS, FP64_MAX_EXPONENT>(x);
}

fp32_int_t fp64_to_fp32(double x)
{
return fp_downconvert<fp32_int_t,
FP32_BIAS, FP32_MANTISSA_BITS, FP32_EXPONENT_BITS, FP32_MAX_EXPONENT,
FP64_BIAS, FP64_MANTISSA_BITS, FP64_EXPONENT_BITS, FP64_MAX_EXPONENT>(x);
}

} // namespace
32 changes: 12 additions & 20 deletions fp16.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,45 +26,37 @@ struct fp16_int_t {

// Use the f16c instructions from Haswell if available (and we know that they
// are at compile time).
static inline double fp16_to_fp64(fp16_int_t x)
static inline float fp16_to_fp32(fp16_int_t x)
{
return _cvtsh_ss(x.val);
}

static inline fp16_int_t fp64_to_fp16(double x)
static inline fp16_int_t fp32_to_fp16(float x)
{
// NOTE: Strictly speaking, there are some select values where this isn't correct,
// since we first round to fp32 and then to fp16.
fp16_int_t ret;
ret.val = _cvtss_sh(x, 0);
return ret;
}

#else

double fp16_to_fp64(fp16_int_t x);
fp16_int_t fp64_to_fp16(double x);
float fp16_to_fp32(fp16_int_t x);
fp16_int_t fp32_to_fp16(float x);

#endif

// These are not very useful by themselves, but are implemented using the same
// code as the fp16 ones (just with different constants), so they are useful
// for verifying against the FPU in unit tests.
double fp32_to_fp64(fp32_int_t x);
fp32_int_t fp64_to_fp32(double x);

// Overloads for use in templates.
static inline double to_fp64(double x) { return x; }
static inline double to_fp64(float x) { return x; }
static inline double to_fp64(fp16_int_t x) { return fp16_to_fp64(x); }
static inline float to_fp32(double x) { return x; }
static inline float to_fp32(float x) { return x; }
static inline float to_fp32(fp16_int_t x) { return fp16_to_fp32(x); }

template<class T> inline T from_fp64(double x);
template<> inline double from_fp64<double>(double x) { return x; }
template<> inline float from_fp64<float>(double x) { return x; }
template<> inline fp16_int_t from_fp64<fp16_int_t>(double x) { return fp64_to_fp16(x); }
template<class T> inline T from_fp32(float x);
template<> inline double from_fp32<double>(float x) { return x; }
template<> inline float from_fp32<float>(float x) { return x; }
template<> inline fp16_int_t from_fp32<fp16_int_t>(float x) { return fp32_to_fp16(x); }

template<class From, class To>
inline To convert_float(From x) { return from_fp64<To>(to_fp64(x)); }
inline To convert_float(From x) { return from_fp32<To>(to_fp32(x)); }

template<class Same>
inline Same convert_float(Same x) { return x; }
Expand Down
Loading

0 comments on commit 35ab975

Please sign in to comment.