diff --git a/bench/EmbeddingQuantizeBenchmark.cc b/bench/EmbeddingQuantizeBenchmark.cc
new file mode 100644
index 0000000000..efbcdbcb8f
--- /dev/null
+++ b/bench/EmbeddingQuantizeBenchmark.cc
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#include <algorithm>
+#include <iomanip>
+#include <iostream>
+#include <vector>
+
+#ifdef _OPENMP
+#include <omp.h>
+#endif
+
+#include "./BenchUtils.h"
+#include "fbgemm/QuantUtils.h"
+#include "fbgemm/Types.h"
+
+using namespace std;
+using namespace fbgemm;
+
+// Benchmarks FloatToFusedNBitRowwiseQuantizedSBHalf over a grid of
+// bit rates and matrix shapes, reporting elements/usec and GB/s.
+void performance_test() {
+  constexpr int NWARMUP = 4;
+  constexpr int NITER = 256;
+
+  cout << setw(8) << "bit_rate"
+       << ", " << setw(6) << "rows"
+       << "," << setw(6) << "cols"
+       << "," << setw(16) << "elems_per_usec"
+       << "," << setw(10) << "GB/Sec" << endl;
+  for (int bit_rate : {2, 4, 8}) {
+    for (int rowSize : {100, 120, 1000}) {
+      for (int colSize : {16, 64, 128, 256, 512, 1024, 2048}) {
+        aligned_vector<float> inpVec(rowSize * colSize);
+        randFill(inpVec, -10.0f, 10.0f);
+
+        int elements_per_byte = 8 / bit_rate;
+        // Quantized row = packed values plus fused fp16 scale and bias.
+        int out_emb_cols =
+            (colSize + elements_per_byte - 1) / elements_per_byte;
+        int outVecSize = rowSize * (out_emb_cols + 2 * sizeof(float16));
+        aligned_vector<uint8_t> outVec(outVecSize);
+
+        double duration = 0.0;
+
+        duration = measureWithWarmup(
+            [&]() {
+              FloatToFusedNBitRowwiseQuantizedSBHalf(
+                  bit_rate, inpVec.data(), rowSize, colSize, outVec.data());
+            },
+            NWARMUP,
+            NITER,
+            [&]() {
+              // Flush caches between iterations so we measure cold reads.
+              cache_evict(inpVec);
+              cache_evict(outVec);
+            });
+
+        float elements_per_usec = rowSize * colSize / (duration * 1e6);
+
+        duration *= 1e9; // convert to ns
+        long bytes_read = rowSize * colSize * sizeof(float);
+        float gigabytes_per_sec = bytes_read / duration;
+
+        cout << setw(8) << bit_rate << "," << setw(6) << rowSize << ", "
+             << setw(6) << colSize << ",";
+        cout << setw(16) << std::fixed << std::setprecision(2)
+             << elements_per_usec << ", ";
+        cout << setw(10) << std::fixed << std::setprecision(2)
+             << gigabytes_per_sec << endl;
+      } // for each cols
+    } // for each rows
+  } // for each bit_rate
+} // performance_test
+
+int main() {
+#ifdef _OPENMP
+  // Use 1 thread unless OMP_NUM_THREADS is explicitly set.
+  const char* val = getenv("OMP_NUM_THREADS");
+  if (val == nullptr || !*val) {
+    omp_set_num_threads(1);
+  }
+#endif
+  performance_test();
+  return 0;
+}
diff --git a/include/fbgemm/QuantUtils.h b/include/fbgemm/QuantUtils.h
index b10a1252f0..4a8abb9a81 100644
--- a/include/fbgemm/QuantUtils.h
+++ b/include/fbgemm/QuantUtils.h
@@ -155,14 +155,15 @@ void Dequantize(
 template <typename T>
 T FusedQuantizeDequantize(float src, const TensorQuantizationParams& qparams) {
-  T q = Quantize<T>(src, qparams.zero_point, qparams.scale, qparams.precision);
+  T q = Quantize<T>(
+      src, qparams.zero_point, qparams.scale, qparams.precision);
   return Dequantize<T>(q, qparams);
 }
 
 /*
-Fused integer quantization dequantization kernel to accelerate quantization-aware training.
-Quantize fp32 values in src to (u)int8 using the provided qparams, and dequantize quantized
-integer values back into fp32.
+Fused integer quantization dequantization kernel to accelerate
+quantization-aware training. Quantize fp32 values in src to (u)int8 using the
+provided qparams, and dequantize quantized integer values back into fp32.
 */
 template <typename T>
 FBGEMM_API void FusedQuantizeDequantize(
@@ -248,4 +249,29 @@ FBGEMM_API void Requantize(
     int thread_id = 0,
     int num_threads = 1);
 
+/**
+ * Convert float inputs to rowwise quantized outputs.
+ * bitrate specifies the number of bits in quantized output.
+ * Scale and Bias are in fp16. Each row's Scale and Bias are stored in
+ * the row itself (fused) at the end.
+ *
+ * @param bit_rate can be 2, 4, or 8
+ */
+FBGEMM_API void FloatToFusedNBitRowwiseQuantizedSBHalf(
+    int bit_rate,
+    const float* input,
+    int input_rows,
+    int input_columns,
+    std::uint8_t* output);
+
+/**
+ * Same as FloatToFusedNBitRowwiseQuantizedSBHalf but unoptimized.
+ * This should not be called directly except in testing.
+ */
+FBGEMM_API void FloatToFusedNBitRowwiseQuantizedSBHalfRef(
+    int bit_rate,
+    const float* input,
+    int input_rows,
+    int input_columns,
+    std::uint8_t* output);
 } // namespace fbgemm
diff --git a/include/fbgemm/QuantUtilsAvx2.h b/include/fbgemm/QuantUtilsAvx2.h
index 6ed81a65c9..62320d743b 100644
--- a/include/fbgemm/QuantUtilsAvx2.h
+++ b/include/fbgemm/QuantUtilsAvx2.h
@@ -119,4 +119,11 @@ FBGEMM_API void requantizeForFloatAvx2(
     int ld_in,
     const requantizationForFloatParams_t& r);
 
+template <int BIT_RATE>
+void FloatToFusedNBitRowwiseQuantizedSBHalfAvx2(
+    const float* input,
+    int input_rows,
+    int input_columns,
+    std::uint8_t* output);
+
 } // namespace fbgemm
diff --git a/src/QuantUtils.cc b/src/QuantUtils.cc
index f70f5bdcda..99f9d43e36 100644
--- a/src/QuantUtils.cc
+++ b/src/QuantUtils.cc
@@ -6,6 +6,8 @@
 
 #include "fbgemm/Fbgemm.h"
 
+#include "fbgemm/Types.h"
+
 namespace fbgemm {
 
 using namespace std;
@@ -464,4 +466,93 @@ FBGEMM_API void RequantizeFixedPoint(
   }
 }
 
+// Reference (scalar) implementation: quantize each row to bit_rate bits,
+// packing num_elem_per_byte values per output byte and appending the row's
+// fp16 scale and fp16 bias (minimum) at the end of the row.
+void FloatToFusedNBitRowwiseQuantizedSBHalfRef(
+    int bit_rate,
+    const float* input,
+    int input_rows,
+    int input_columns,
+    std::uint8_t* output) {
+  int num_elem_per_byte = 8 / bit_rate;
+  int output_columns =
+      (input_columns + num_elem_per_byte - 1) / num_elem_per_byte +
+      2 * sizeof(float16);
+  for (std::size_t row = 0; row < input_rows; ++row) {
+    const float* input_row = input + row * input_columns;
+    std::uint8_t* output_row = output + row * output_columns;
+    float16* output_row_scale_bias = reinterpret_cast<float16*>(
+        output_row +
+        (input_columns + num_elem_per_byte - 1) / num_elem_per_byte);
+
+    float minimum_element =
+        *std::min_element(input_row, input_row + input_columns);
+    float maximum_element =
+        *std::max_element(input_row, input_row + input_columns);
+
+    // Round the bias through fp16 so dequantization uses the stored value.
+    float16 minimum_element_fp16 = cpu_float2half_rn(minimum_element);
+    minimum_element = cpu_half2float(minimum_element_fp16);
+    const float range = maximum_element - minimum_element;
+
+    float scale = range == 0 ? 1.0f : range / ((1 << bit_rate) - 1);
+    float16 scale_fp16 = cpu_float2half_rn(scale);
+    scale = cpu_half2float(scale_fp16);
+    if (scale == 0) {
+      // Corner case handling when maximum_element == minimum_element
+      // Any scale would work because X - minimum_element will be 0 for all X
+      scale = 1.0f;
+    }
+    float inverse_scale = 1.0f / scale;
+    if (std::isinf(inverse_scale)) {
+      scale = 1.0f;
+      inverse_scale = 1.0f;
+    }
+
+    output_row_scale_bias[0] = cpu_float2half_rn(scale);
+    output_row_scale_bias[1] = minimum_element_fp16;
+    for (std::size_t col = 0; col < input_columns; ++col) {
+      float X = input_row[col];
+      std::uint8_t quantized = std::max(
+          0,
+          std::min<int>(
+              std::lrintf((X - minimum_element) * inverse_scale),
+              (1 << bit_rate) - 1));
+      if (col % num_elem_per_byte == 0) {
+        output_row[col / num_elem_per_byte] = quantized;
+      } else {
+        output_row[col / num_elem_per_byte] |=
+            (quantized << ((col % num_elem_per_byte) * bit_rate));
+      }
+    }
+  }
+}
+
+// Dispatcher: uses the AVX2 kernel for the supported bit rates when the CPU
+// has AVX2, otherwise falls back to the reference implementation.
+void FloatToFusedNBitRowwiseQuantizedSBHalf(
+    int bit_rate,
+    const float* input,
+    int input_rows,
+    int input_columns,
+    std::uint8_t* output) {
+  if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
+    switch (bit_rate) {
+      case 2:
+        FloatToFusedNBitRowwiseQuantizedSBHalfAvx2<2>(
+            input, input_rows, input_columns, output);
+        break;
+      case 4:
+        FloatToFusedNBitRowwiseQuantizedSBHalfAvx2<4>(
+            input, input_rows, input_columns, output);
+        break;
+      case 8:
+        FloatToFusedNBitRowwiseQuantizedSBHalfAvx2<8>(
+            input, input_rows, input_columns, output);
+        break;
+      default:
+        FloatToFusedNBitRowwiseQuantizedSBHalfRef(
+            bit_rate, input, input_rows, input_columns, output);
+    }
+  } else {
+    FloatToFusedNBitRowwiseQuantizedSBHalfRef(
+        bit_rate, input, input_rows, input_columns, output);
+  }
+}
+
 } // namespace fbgemm
diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc
index ad009ba5b6..eacb0a2717 100644
--- a/src/QuantUtilsAvx2.cc
+++ b/src/QuantUtilsAvx2.cc
@@ -11,6 +11,7 @@
 #include <cmath> //for nearbyint
 #include <limits> //for numeric_limits
 #include <cassert> //for assert
+#include <cfloat> // for FLT_MAX
 #include <cstring> //for memcpy
 
 #include "./MaskAvx2.h"
@@ -1430,4 +1431,195 @@ INSTANTIATE_BIAS(false)
 #undef INSTANTIATE_Q_GRANS
 #undef INSTANTIATE_BIAS
 
+static inline uint16_t floatToHalf(float val) {
+#ifdef _MSC_VER
+  // Use _mm256_cvtps_ph/_mm256_cvtph_ps because _cvtsh_ss/_cvtss_sh don't
+  // exist in MSVC.
+  __m256 val_v = _mm256_set1_ps(val);
+  __m128i val_half_v =
+      _mm256_cvtps_ph(val_v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
+  return static_cast<uint16_t>(_mm_cvtsi128_si32(val_half_v));
+#else
+  return _cvtss_sh(val, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
+#endif
+}
+static inline float halfToFloat(uint16_t val) {
+#ifdef _MSC_VER
+  return _mm256_cvtss_f32(_mm256_cvtph_ps(_mm_cvtsi32_si128(val)));
+#else
+  return _cvtsh_ss(val);
+#endif
+}
+
+template <int BIT_RATE>
+void FloatToFusedNBitRowwiseQuantizedSBHalfAvx2(
+    const float* input,
+    int input_rows,
+    int input_columns,
+    std::uint8_t* output) {
+  __m256i permute_mask1_v =
+      _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
+
+  constexpr int VLEN = 8;
+  constexpr int NUM_ELEM_PER_BYTE = 8 / BIT_RATE;
+  int output_columns =
+      (input_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE +
+      2 * sizeof(std::uint16_t);
+  for (std::size_t row = 0; row < input_rows; ++row) {
+    const float* input_row = input + row * input_columns;
+    std::uint8_t* output_row = output + row * output_columns;
+    std::uint16_t* output_row_scale_bias = reinterpret_cast<std::uint16_t*>(
+        output_row +
+        (input_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE);
+
+    // Vectorized min/max reduction over the row, with scalar remainder.
+    float minimum_element = FLT_MAX;
+    float maximum_element = -FLT_MAX;
+    __m256 min_v = _mm256_set1_ps(minimum_element);
+    __m256 max_v = _mm256_set1_ps(maximum_element);
+    std::size_t col;
+    for (col = 0; col < input_columns / VLEN * VLEN; col += VLEN) {
+      __m256 in_v = _mm256_loadu_ps(input_row + col);
+      min_v = _mm256_min_ps(min_v, in_v);
+      max_v = _mm256_max_ps(max_v, in_v);
+    }
+    alignas(64) float min_buf[VLEN], max_buf[VLEN];
+    _mm256_store_ps(min_buf, min_v);
+    _mm256_store_ps(max_buf, max_v);
+    for (int i = 0; i < VLEN; ++i) {
+      minimum_element = std::min(minimum_element, min_buf[i]);
+      maximum_element = std::max(maximum_element, max_buf[i]);
+    }
+    for (; col < input_columns; ++col) {
+      minimum_element = std::min(minimum_element, input_row[col]);
+      maximum_element = std::max(maximum_element, input_row[col]);
+    }
+
+    output_row_scale_bias[1] = floatToHalf(minimum_element);
+    minimum_element = halfToFloat(output_row_scale_bias[1]);
+    const float range = maximum_element - minimum_element;
+
+    float scale = range == 0 ? 1.0f : range / ((1 << BIT_RATE) - 1);
+    std::uint16_t scale_fp16 = floatToHalf(scale);
+    scale = halfToFloat(scale_fp16);
+    if (scale == 0) {
+      // Corner case handling when maximum_element == minimum_element
+      // Any scale would work because maximum_element - minimum_element will be
+      // 0 for all X
+      scale = 1.0f;
+    }
+    float inverse_scale = 1.0f / scale;
+    if (std::isinf(inverse_scale)) {
+      scale = 1.0f;
+      inverse_scale = 1.0f;
+    }
+
+    output_row_scale_bias[0] = floatToHalf(scale);
+
+    __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale);
+    min_v = _mm256_set1_ps(minimum_element);
+
+    col = 0;
+
+    if (BIT_RATE == 2 || BIT_RATE == 4) {
+      for (; col + 4 * VLEN <= input_columns; col += 4 * VLEN) {
+        __m256i x_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
+            _mm256_sub_ps(_mm256_loadu_ps(input_row + col), min_v),
+            inverse_scale_v));
+        __m256i y_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
+            _mm256_sub_ps(_mm256_loadu_ps(input_row + col + VLEN), min_v),
+            inverse_scale_v));
+        __m256i z_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
+            _mm256_sub_ps(_mm256_loadu_ps(input_row + col + 2 * VLEN), min_v),
+            inverse_scale_v));
+        __m256i w_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
+            _mm256_sub_ps(_mm256_loadu_ps(input_row + col + 3 * VLEN), min_v),
+            inverse_scale_v));
+
+        // An instruction sequence to save 32 32-bit integers as 8-bit integers
+        __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v);
+        __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v);
+        __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
+        xyzw_packed_v =
+            _mm256_permutevar8x32_epi32(xyzw_packed_v, permute_mask1_v);
+
+        // saturate to BIT_RATE
+        xyzw_packed_v = _mm256_min_epu8(
+            xyzw_packed_v,
+            _mm256_set1_epi8(static_cast<char>((1 << BIT_RATE) - 1)));
+
+        if (BIT_RATE == 4) {
+          // pack into lower 8-bit of each 16-bit
+          xyzw_packed_v = _mm256_and_si256(
+              _mm256_or_si256(
+                  xyzw_packed_v, _mm256_srli_epi16(xyzw_packed_v, 4)),
+              _mm256_set1_epi16(0x00ff));
+        } else {
+          // pack into lower 8-bit of each 32-bit
+          xyzw_packed_v = _mm256_and_si256(
+              _mm256_or_si256(
+                  _mm256_or_si256(
+                      xyzw_packed_v, _mm256_srli_epi32(xyzw_packed_v, 6)),
+                  _mm256_or_si256(
+                      _mm256_srli_epi32(xyzw_packed_v, 8 + 4),
+                      _mm256_srli_epi32(xyzw_packed_v, 2 * 8 + 2))),
+              _mm256_set1_epi32(0x00ff));
+        }
+
+        __m128i out_v;
+        if (BIT_RATE == 4) {
+          // avx2 doesn't have _mm256_cvtepi16_epi8
+          out_v = _mm_packus_epi16(
+              _mm256_castsi256_si128(xyzw_packed_v),
+              _mm256_extractf128_si256(xyzw_packed_v, 1));
+          _mm_storeu_si128(
+              reinterpret_cast<__m128i*>(output_row + col / NUM_ELEM_PER_BYTE),
+              out_v);
+        } else {
+          // avx2 doesn't have _mm256_cvtepi32_epi8
+          out_v = _mm_packus_epi32(
+              _mm256_castsi256_si128(xyzw_packed_v),
+              _mm256_extractf128_si256(xyzw_packed_v, 1));
+          out_v = _mm_packus_epi16(out_v, out_v);
+          _mm_storel_epi64(
+              reinterpret_cast<__m128i*>(output_row + col / NUM_ELEM_PER_BYTE),
+              out_v);
+        }
+      }
+    }
+
+    // Scalar tail (and the whole row for BIT_RATE == 8).
+    for (; col < input_columns; ++col) {
+      float X = input_row[col];
+      std::uint8_t quantized = std::max(
+          0,
+          std::min<int>(
+              std::lrintf((X - minimum_element) * inverse_scale),
+              (1 << BIT_RATE) - 1));
+      if (col % NUM_ELEM_PER_BYTE == 0) {
+        output_row[col / NUM_ELEM_PER_BYTE] = quantized;
+      } else {
+        output_row[col / NUM_ELEM_PER_BYTE] |=
+            (quantized << ((col % NUM_ELEM_PER_BYTE) * BIT_RATE));
+      }
+    }
+  }
+}
+
+template void FloatToFusedNBitRowwiseQuantizedSBHalfAvx2<2>(
+    const float* input,
+    int input_rows,
+    int input_columns,
+    std::uint8_t* output);
+
+template void FloatToFusedNBitRowwiseQuantizedSBHalfAvx2<4>(
+    const float* input,
+    int input_rows,
+    int input_columns,
+    std::uint8_t* output);
+
+template void FloatToFusedNBitRowwiseQuantizedSBHalfAvx2<8>(
+    const float* input,
+    int input_rows,
+    int input_columns,
+    std::uint8_t* output);
+
 } // namespace fbgemm
diff --git a/test/QuantUtilsTest.cc b/test/QuantUtilsTest.cc
index 764dbc2913..23d57e22f5 100644
--- a/test/QuantUtilsTest.cc
+++ b/test/QuantUtilsTest.cc
@@ -9,10 +9,13 @@
 #include <algorithm>
 #include <climits>
 #include <cmath>
+#include <random>
+#include <sstream>
 #include <vector>
 
 #include "fbgemm/QuantUtils.h"
+#include "fbgemm/Types.h"
 #include "fbgemm/Utils.h"
 
 using namespace std;
@@ -26,6 +29,11 @@ class QuantizeGroupwiseTest
 class QuantizeTest : public testing::TestWithParam<int> {};
 class FusedQuantizeDequantizeTest : public testing::TestWithParam<int> {};
 
+// Parameter are bit_rate (i.e., the number of bits in quantized values),
+// input rows, and input columns
+class EmbeddingQuantizeTest
+    : public testing::TestWithParam<tuple<int, int, int>> {};
+
 INSTANTIATE_TEST_CASE_P(
     InstantiationName,
     QuantizeGroupwiseTest,
@@ -41,6 +49,14 @@ INSTANTIATE_TEST_CASE_P(
     QuantizeTest,
     ::testing::Values(1, 2, 5, 8, 9, 16, 20, 28, 32, 33));
 
+INSTANTIATE_TEST_CASE_P(
+    InstantiationName,
+    EmbeddingQuantizeTest,
+    ::testing::Combine(
+        ::testing::ValuesIn({2, 4, 8}),
+        ::testing::ValuesIn({1, 2, 3}),
+        ::testing::ValuesIn({1, 2, 5, 8, 9, 16, 20, 28, 32, 33, 64, 65})));
+
 template <typename T>
 void ref_impl(
     const vector<float>& src,
@@ -131,6 +147,69 @@ bool floatEqualAll(vector<float>& a, vector<float>& b) {
   return true;
 }
 
+// Compares a quantized-embedding buffer against the reference buffer:
+// packed values must match exactly; fused per-row scale/bias (stored as T
+// at the end of each row) must match within float epsilon.
+template <typename T>
+::testing::AssertionResult isQEmbeddingClose(
+    const vector<uint8_t>& res,
+    const vector<uint8_t>& res_ref,
+    int out_rows,
+    int out_emb_cols) {
+  bool match = true;
+  std::stringstream ss;
+  int ld = out_emb_cols + 2 * sizeof(T);
+  if (res.size() == res_ref.size()) {
+    for (int i = 0; i < out_rows; ++i) {
+      if (!match) {
+        break;
+      }
+      // compare embedding values
+      for (int j = 0; j < out_emb_cols; ++j) {
+        if (res[i * ld + j] != res_ref[i * ld + j]) {
+          match = false;
+          ss << " mismatch at (" << i << ", " << j << ") ";
+          ss << "ref: " << static_cast<int>(res_ref[i * ld + j])
+             << ", test: " << static_cast<int>(res[i * ld + j]) << "\n";
+          break;
+        }
+      }
+      // compare scale/bias (zero-init so non-float16 T doesn't read garbage)
+      float scaleTest = 0, scaleRef = 0, biasTest = 0, biasRef = 0;
+      if (is_same<T, float16>::value) {
+        // half scale and bias
+        scaleTest = cpu_half2float(reinterpret_cast<const float16*>(
+            res.data() + i * ld + out_emb_cols)[0]);
+        biasTest = cpu_half2float(reinterpret_cast<const float16*>(
+            res.data() + i * ld + out_emb_cols)[1]);
+        scaleRef = cpu_half2float(reinterpret_cast<const float16*>(
+            res_ref.data() + i * ld + out_emb_cols)[0]);
+        biasRef = cpu_half2float(reinterpret_cast<const float16*>(
+            res_ref.data() + i * ld + out_emb_cols)[1]);
+      } else {
+        // float scale and bias
+        // TODO:
+      }
+      if (fabs(scaleTest - scaleRef) > std::numeric_limits<float>::epsilon()) {
+        ss << " scale mismatch for row:" << i;
+        ss << " ref: " << scaleRef << ", test: " << scaleTest << "\n";
+        match = false;
+      }
+      if (fabs(biasTest - biasRef) > std::numeric_limits<float>::epsilon()) {
+        ss << " bias mismatch for row:" << i;
+        ss << " ref: " << biasRef << ", test: " << biasTest << "\n";
+        match = false;
+      }
+    }
+  } else {
+    ss << " size mismatch ";
+    match = false;
+  }
+
+  if (match)
+    return ::testing::AssertionSuccess();
+  else
+    return ::testing::AssertionFailure()
+        << " Quantized Embeddings do not match." << ss.str();
+}
+
 /**
  * Test for QuantizeGroupwise
  */
@@ -411,7 +490,6 @@ TEST(FusedQuantizeDequantizeTest, cornerCases) {
       src1.data(), dst_int8.data(), src1.size(), qparams);
 
   EXPECT_TRUE(floatEqualAll(dst_int8, ref));
-
   // Tests vectorized and remainder paths
   vector<float> src2 = {
       3.40282e+38,
      -2.16845e+38,
@@ -440,3 +518,33 @@
 
   EXPECT_TRUE(floatEqualAll(dst_uint8, ref2));
 }
+
+TEST_P(EmbeddingQuantizeTest, embeddingHalfTest) {
+  int bit_rate, rows, cols;
+  tie(bit_rate, rows, cols) = GetParam();
+
+  random_device rd;
+  mt19937 gen(rd());
+
+  uniform_real_distribution<float> disFP(-10.0f, 10.0f);
+
+  vector<float> inpVec(rows * cols);
+
+  generate(
+      inpVec.begin(), inpVec.end(), [&]() { return disFP(gen); });
+
+  int elements_per_byte = 8 / bit_rate;
+
+  int out_emb_cols = (cols + elements_per_byte - 1) / elements_per_byte;
+  int outVecSize = rows * (out_emb_cols + 2 * sizeof(float16));
+
+  vector<uint8_t> outVecRef(outVecSize);
+  vector<uint8_t> outVecTest(outVecSize);
+
+  FloatToFusedNBitRowwiseQuantizedSBHalfRef(
+      bit_rate, inpVec.data(), rows, cols, outVecRef.data());
+  FloatToFusedNBitRowwiseQuantizedSBHalf(
+      bit_rate, inpVec.data(), rows, cols, outVecTest.data());
+
+  EXPECT_TRUE(
+      isQEmbeddingClose<float16>(outVecTest, outVecRef, rows, out_emb_cols));
+}