Skip to content

Commit

Permalink
Fix convert indexing bugs (pytorch#367)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#367

The type conversion code uses 32-bit indexing, leading to corruption of
large tensors. Caffe2 operators used in production such as FloatToHalf used for
embedding table quantization are affected. To make things worse, the trainer
corrupts the data silently without causing visible crashes.

Reviewed By: jianyuh

Differential Revision: D21372206

fbshipit-source-id: fcbd55f806fd72bc77dbb51ec8386e9f1c37f568
  • Loading branch information
Pawel Garbacki authored and facebook-github-bot committed May 3, 2020
1 parent e6a4c5a commit fb6b95d
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 51 deletions.
34 changes: 17 additions & 17 deletions include/fbgemm/FbgemmConvert.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,54 +19,54 @@ typedef uint16_t bfloat16;
* implementation.
*
*/
FBGEMM_API void FloatToBfloat16_ref(const float* src, bfloat16* dst, int size);
FBGEMM_API void FloatToBfloat16_ref(const float* src, bfloat16* dst, size_t size);

/**
* @ Transform all entries in a matrix from bfloat16 to fp32: reference
* implementation.
*
*/
FBGEMM_API void Bfloat16ToFloat_ref(const bfloat16* src, float* dst, int size);
FBGEMM_API void Bfloat16ToFloat_ref(const bfloat16* src, float* dst, size_t size);

/**
* @ Transform all entries in a matrix from fp32 to bfloat16: simd
* implementation.
*
*/
FBGEMM_API void FloatToBfloat16_simd(const float* src, bfloat16* dst, int size);
FBGEMM_API void FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size);

/**
* @ Transform all entries in a matrix from bfloat16 to fp32: simd
* implementation.
*
*/
FBGEMM_API void Bfloat16ToFloat_simd(const bfloat16* src, float* dst, int size);
FBGEMM_API void Bfloat16ToFloat_simd(const bfloat16* src, float* dst, size_t size);

/**
* @brief AVX2 implementation to convert fp32 numbers to bf16 numbers.
*
*/
FBGEMM_API void FloatToBfloat16_avx2(const float* src, bfloat16* dst, int size);
FBGEMM_API void FloatToBfloat16_avx2(const float* src, bfloat16* dst, size_t size);

/**
* @brief AVX512 implementation to convert fp32 numbers to bf16 numbers.
*
*/
FBGEMM_API void
FloatToBfloat16_avx512(const float* src, bfloat16* dst, int size);
FloatToBfloat16_avx512(const float* src, bfloat16* dst, size_t size);

/**
* @brief AVX2 implementation to convert bf16 numbers to fp32 numbers.
*
*/
FBGEMM_API void Bfloat16ToFloat_avx2(const bfloat16* src, float* dst, int size);
FBGEMM_API void Bfloat16ToFloat_avx2(const bfloat16* src, float* dst, size_t size);

/**
* @brief AVX512 implementation to convert bf16 numbers to fp32 numbers.
*
*/
FBGEMM_API void
Bfloat16ToFloat_avx512(const bfloat16* src, float* dst, int size);
Bfloat16ToFloat_avx512(const bfloat16* src, float* dst, size_t size);

/**
* @ Transform all entries in a matrix from fp32 to float16: reference
Expand All @@ -76,15 +76,15 @@ Bfloat16ToFloat_avx512(const bfloat16* src, float* dst, int size);
FBGEMM_API void FloatToFloat16_ref(
const float* src,
float16* dst,
int size,
size_t size,
bool do_clip = false);

/**
* @ Transform all entries in a matrix from float16 to fp32: reference
* implementation.
*
*/
FBGEMM_API void Float16ToFloat_ref(const float16* src, float* dst, int size);
FBGEMM_API void Float16ToFloat_ref(const float16* src, float* dst, size_t size);

/**
* @ Transform all entries in a matrix from fp32 to float16: simd
Expand All @@ -94,15 +94,15 @@ FBGEMM_API void Float16ToFloat_ref(const float16* src, float* dst, int size);
FBGEMM_API void FloatToFloat16_simd(
const float* src,
float16* dst,
int size,
size_t size,
bool do_clip = false);

/**
* @ Transform all entries in a matrix from float16 to fp32: simd
* implementation.
*
*/
FBGEMM_API void Float16ToFloat_simd(const float16* src, float* dst, int size);
FBGEMM_API void Float16ToFloat_simd(const float16* src, float* dst, size_t size);

/**
* @brief AVX2 implementation to convert fp32 numbers to fp16 numbers.
Expand All @@ -111,7 +111,7 @@ FBGEMM_API void Float16ToFloat_simd(const float16* src, float* dst, int size);
FBGEMM_API void FloatToFloat16_avx2(
const float* src,
float16* dst,
int size,
size_t size,
bool do_clip = false);

/**
Expand All @@ -121,20 +121,20 @@ FBGEMM_API void FloatToFloat16_avx2(
FBGEMM_API void FloatToFloat16_avx512(
const float* src,
float16* dst,
int size,
size_t size,
bool do_clip = false);

/**
* @brief AVX2 implementation to convert fp16 numbers to fp32 numbers.
*
*/
FBGEMM_API void Float16ToFloat_avx2(const float16* src, float* dst, int size);
FBGEMM_API void Float16ToFloat_avx2(const float16* src, float* dst, size_t size);

/**
* @brief AVX512 implementation to convert fp16 numbers to fp32 numbers.
*
*/
FBGEMM_API void Float16ToFloat_avx512(const float16* src, float* dst, int size);
FBGEMM_API void Float16ToFloat_avx512(const float16* src, float* dst, size_t size);

/**
* @brief Transform all entries in a matrix from fp32 to float16 and back to
Expand All @@ -143,7 +143,7 @@ FBGEMM_API void Float16ToFloat_avx512(const float16* src, float* dst, int size);
FBGEMM_API void RoundToFloat16(
const float* input,
float* output,
int len,
size_t size,
bool clamp = false,
bool clamp_denorms = false);

Expand Down
4 changes: 2 additions & 2 deletions src/FbgemmBfloat16Convert.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ using namespace std;

namespace fbgemm {

void FloatToBfloat16_simd(const float* src, bfloat16* dst, int size) {
void FloatToBfloat16_simd(const float* src, bfloat16* dst, size_t size) {
// Run time CPU detection
if (cpuinfo_initialize()) {
if (fbgemmHasAvx512Support()) {
Expand All @@ -58,7 +58,7 @@ void FloatToBfloat16_simd(const float* src, bfloat16* dst, int size) {
}
}

void Bfloat16ToFloat_simd(const bfloat16* src, float* dst, int size) {
void Bfloat16ToFloat_simd(const bfloat16* src, float* dst, size_t size) {
// Run time CPU detection
if (cpuinfo_initialize()) {
if (fbgemmHasAvx512Support()) {
Expand Down
8 changes: 4 additions & 4 deletions src/FbgemmBfloat16ConvertAvx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,16 @@ inline void Bfloat16ToFloatKernelAvx2(const bfloat16* src, float* dst) {

} // namespace

void FloatToBfloat16_avx2(const float* src, bfloat16* dst, int size) {
int i = 0;
void FloatToBfloat16_avx2(const float* src, bfloat16* dst, size_t size) {
size_t i = 0;
for (i = 0; i + 8 * 2 <= size; i += 8 * 2) {
FloatToBfloat16KernelAvx2(src + i, dst + i);
}
FloatToBfloat16_ref(src + i, dst + i, size - i);
}

void Bfloat16ToFloat_avx2(const bfloat16* src, float* dst, int size) {
int i = 0;
void Bfloat16ToFloat_avx2(const bfloat16* src, float* dst, size_t size) {
size_t i = 0;
for (i = 0; i + 8 <= size; i += 8) {
Bfloat16ToFloatKernelAvx2(src + i, dst + i);
}
Expand Down
8 changes: 4 additions & 4 deletions src/FbgemmBfloat16ConvertAvx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,16 @@ inline void Bfloat16ToFloatKernelAvx512(const bfloat16* src, float* dst) {

} // namespace

void FloatToBfloat16_avx512(const float* src, bfloat16* dst, int size) {
int i = 0;
void FloatToBfloat16_avx512(const float* src, bfloat16* dst, size_t size) {
size_t i = 0;
for (i = 0; i + 16 <= size; i += 16) {
FloatToBfloat16KernelAvx512(src + i, dst + i);
}
FloatToBfloat16_avx2(src + i, dst + i, size - i);
}

void Bfloat16ToFloat_avx512(const bfloat16* src, float* dst, int size) {
int i = 0;
void Bfloat16ToFloat_avx512(const bfloat16* src, float* dst, size_t size) {
size_t i = 0;
for (i = 0; i + 16 <= size; i += 16) {
Bfloat16ToFloatKernelAvx512(src + i, dst + i);
}
Expand Down
10 changes: 5 additions & 5 deletions src/FbgemmFloat16Convert.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace fbgemm {
void FloatToFloat16_simd(
const float* src,
float16* dst,
int size,
size_t size,
bool do_clip) {
// Run time CPU detection
if (cpuinfo_initialize()) {
Expand All @@ -52,7 +52,7 @@ void FloatToFloat16_simd(
}
}

void Float16ToFloat_simd(const float16* src, float* dst, int size) {
void Float16ToFloat_simd(const float16* src, float* dst, size_t size) {
// Run time CPU detection
if (cpuinfo_initialize()) {
if (fbgemmHasAvx512Support()) {
Expand All @@ -71,7 +71,7 @@ void Float16ToFloat_simd(const float16* src, float* dst, int size) {
void RoundToFloat16(
const float* input,
float* output,
int size,
size_t size,
bool clamp,
bool clamp_denorms) {
std::vector<fbgemm::float16> data_fp16(size);
Expand All @@ -80,7 +80,7 @@ void RoundToFloat16(

if (clamp) {
// TODO: Use intrinsics to optimize clamping performance.
for (int i = 0; i < size; ++i) {
for (size_t i = 0; i < size; ++i) {
output[i] = std::max(std::min(output[i], 65504.0f), -65504.0f);
}
}
Expand All @@ -94,7 +94,7 @@ void RoundToFloat16(
union epsilon_t epsilon;
epsilon.i = 0x38800000u; // 1 / 16384

for (int i = 0; i < size; ++i) {
for (size_t i = 0; i < size; ++i) {
if (std::abs(output[i]) < epsilon.f) {
output[i] = 0.0;
}
Expand Down
10 changes: 5 additions & 5 deletions src/FbgemmFloat16ConvertAvx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,25 @@ inline void Float16ToFloatKernelAvx2(const float16* src, float* dst) {
void FloatToFloat16_avx2(
const float* src,
float16* dst,
int size,
size_t size,
bool do_clip) {
if (do_clip) {
int i = 0;
size_t i = 0;
for (i = 0; i + 8 <= size; i += 8) {
FloatToFloat16KernelAvx2WithClip(src + i, dst + i);
}
FloatToFloat16_ref(src + i, dst + i, size - i, do_clip);
} else {
int i = 0;
size_t i = 0;
for (i = 0; i + 8 <= size; i += 8) {
FloatToFloat16KernelAvx2(src + i, dst + i);
}
FloatToFloat16_ref(src + i, dst + i, size - i);
}
}

void Float16ToFloat_avx2(const float16* src, float* dst, int size) {
int i = 0;
void Float16ToFloat_avx2(const float16* src, float* dst, size_t size) {
size_t i = 0;
for (i = 0; i + 8 <= size; i += 8) {
Float16ToFloatKernelAvx2(src + i, dst + i);
}
Expand Down
10 changes: 5 additions & 5 deletions src/FbgemmFloat16ConvertAvx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,25 +45,25 @@ inline void Float16ToFloatKernelAvx512(const float16* src, float* dst) {
void FloatToFloat16_avx512(
const float* src,
float16* dst,
int size,
size_t size,
bool do_clip) {
if (do_clip) {
int i = 0;
size_t i = 0;
for (i = 0; i + 16 <= size; i += 16) {
FloatToFloat16KernelAvx512WithClip(src + i, dst + i);
}
FloatToFloat16_avx2(src + i, dst + i, size - i, do_clip);
} else {
int i = 0;
size_t i = 0;
for (i = 0; i + 16 <= size; i += 16) {
FloatToFloat16KernelAvx512(src + i, dst + i);
}
FloatToFloat16_avx2(src + i, dst + i, size - i);
}
}

void Float16ToFloat_avx512(const float16* src, float* dst, int size) {
int i = 0;
void Float16ToFloat_avx512(const float16* src, float* dst, size_t size) {
size_t i = 0;
for (i = 0; i + 16 <= size; i += 16) {
Float16ToFloatKernelAvx512(src + i, dst + i);
}
Expand Down
18 changes: 9 additions & 9 deletions src/RefImplementations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,36 +24,36 @@ namespace fbgemm {
/// Scalar reference fp32 -> fp16 conversion (round-to-nearest via
/// cpu_float2half_rn). When `do_clip` is set, inputs are first
/// saturated to +/-65504 (the largest finite fp16 value). `size` is
/// 64-bit so tensors with more than 2^31 elements convert correctly.
void FloatToFloat16_ref(
    const float* src,
    float16* dst,
    size_t size,
    bool do_clip) {
  // Largest finite fp16 magnitude.
  constexpr float FP16_MAX = 65504.f;
  if (do_clip) {
    for (size_t idx = 0; idx < size; ++idx) {
      // Operand order matters for NaN inputs: min/max propagate their
      // first argument on unordered compares, so a NaN clamps to
      // -FP16_MAX here — preserved from the original implementation.
      const float clipped = std::max(-FP16_MAX, std::min(src[idx], FP16_MAX));
      dst[idx] = cpu_float2half_rn(clipped);
    }
  } else {
    for (size_t idx = 0; idx < size; ++idx) {
      dst[idx] = cpu_float2half_rn(src[idx]);
    }
  }
}

void Float16ToFloat_ref(const float16* src, float* dst, int size) {
for (int i = 0; i < size; i++) {
void Float16ToFloat_ref(const float16* src, float* dst, size_t size) {
for (size_t i = 0; i < size; i++) {
dst[i] = cpu_half2float(src[i]);
}
}

void FloatToBfloat16_ref(const float* src, bfloat16* dst, int size) {
for (int i = 0; i < size; i++) {
void FloatToBfloat16_ref(const float* src, bfloat16* dst, size_t size) {
for (size_t i = 0; i < size; i++) {
// Add 2^15 and right shift 16 to do round-nearest
dst[i] = (*reinterpret_cast<const uint32_t*>(src + i) + (1 << 15)) >> 16;
}
}

void Bfloat16ToFloat_ref(const bfloat16* src, float* dst, int size) {
for (int i = 0; i < size; i++) {
void Bfloat16ToFloat_ref(const bfloat16* src, float* dst, size_t size) {
for (size_t i = 0; i < size; i++) {
uint32_t val_fp32 =
static_cast<uint32_t>(reinterpret_cast<const uint16_t*>(src)[i]) << 16;
reinterpret_cast<uint32_t*>(dst)[i] = val_fp32;
Expand Down

0 comments on commit fb6b95d

Please sign in to comment.