From a9397bf7215a53b06d60a966b3d7803beedfa226 Mon Sep 17 00:00:00 2001 From: Daya Khudia Date: Mon, 31 Aug 2020 16:41:27 -0700 Subject: [PATCH] Move embedding quantization kernels to fbgemm for better sharing between C2/PT (#419) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/419 Moves the kernel to fbgemm that quantizes embedding to 8, 4 or 2 bits with half scale/bias. ``` bit_rate, rows, cols, elems_per_usec, GB/Sec 2, 100, 16, 94.28, 0.38 2, 100, 64, 593.92, 2.38 2, 100, 128, 777.68, 3.11 2, 100, 256, 976.79, 3.91 2, 100, 512, 1146.61, 4.59 2, 100, 1024, 1314.89, 5.26 2, 100, 2048, 1312.45, 5.25 2, 120, 16, 94.59, 0.38 2, 120, 64, 584.89, 2.34 2, 120, 128, 809.44, 3.24 2, 120, 256, 1019.32, 4.08 2, 120, 512, 1137.77, 4.55 2, 120, 1024, 1156.94, 4.63 2, 120, 2048, 1321.16, 5.28 2, 1000, 16, 96.18, 0.38 2, 1000, 64, 612.17, 2.45 2, 1000, 128, 776.04, 3.10 2, 1000, 256, 981.44, 3.93 2, 1000, 512, 1117.79, 4.47 2, 1000, 1024, 1262.67, 5.05 2, 1000, 2048, 1397.23, 5.59 4, 100, 16, 93.90, 0.38 4, 100, 64, 607.12, 2.43 4, 100, 128, 817.09, 3.27 4, 100, 256, 1034.85, 4.14 4, 100, 512, 1251.44, 5.01 4, 100, 1024, 1414.89, 5.66 4, 100, 2048, 1413.13, 5.65 4, 120, 16, 93.74, 0.37 4, 120, 64, 612.77, 2.45 4, 120, 128, 847.48, 3.39 4, 120, 256, 1110.61, 4.44 4, 120, 512, 1265.75, 5.06 4, 120, 1024, 1239.66, 4.96 4, 120, 2048, 1433.69, 5.73 4, 1000, 16, 93.94, 0.38 4, 1000, 64, 664.46, 2.66 4, 1000, 128, 819.37, 3.28 4, 1000, 256, 1055.88, 4.22 4, 1000, 512, 1217.72, 4.87 4, 1000, 1024, 1374.96, 5.50 4, 1000, 2048, 1522.50, 6.09 8, 100, 16, 98.71, 0.39 8, 100, 64, 169.60, 0.68 8, 100, 128, 196.60, 0.79 8, 100, 256, 212.97, 0.85 8, 100, 512, 223.10, 0.89 8, 100, 1024, 226.54, 0.91 8, 100, 2048, 231.72, 0.93 8, 120, 16, 99.22, 0.40 8, 120, 64, 161.74, 0.65 8, 120, 128, 188.89, 0.76 8, 120, 256, 213.85, 0.86 8, 120, 512, 224.18, 0.90 8, 120, 1024, 226.87, 0.91 8, 120, 2048, 231.60, 0.93 8, 1000, 16, 97.97, 0.39 8, 1000, 64, 172.09, 0.69 8, 1000, 128, 196.11, 0.78 8, 1000, 256, 213.29, 0.85 8, 1000, 512, 222.82, 0.89 8, 1000, 1024, 225.41, 0.90 8, 1000, 2048, 231.66, 0.93 ``` Reviewed By: supriyar Differential Revision: D23320018 fbshipit-source-id: a12ec2bec80f77b7c3e15aadee6da94a2b7182fd --- bench/EmbeddingQuantizeBenchmark.cc | 83 ++++++++++++ include/fbgemm/QuantUtils.h | 34 ++++- include/fbgemm/QuantUtilsAvx2.h | 7 + src/QuantUtils.cc | 91 +++++++++++++ src/QuantUtilsAvx2.cc | 192 ++++++++++++++++++++++++++++ test/QuantUtilsTest.cc | 110 +++++++++++++++- 6 files changed, 512 insertions(+), 5 deletions(-) create mode 100644 bench/EmbeddingQuantizeBenchmark.cc diff --git a/bench/EmbeddingQuantizeBenchmark.cc b/bench/EmbeddingQuantizeBenchmark.cc new file mode 100644 index 0000000000..efbcdbcb8f --- /dev/null +++ b/bench/EmbeddingQuantizeBenchmark.cc @@ -0,0 +1,83 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include +#include +#include + +#ifdef _OPENMP +#include +#endif + +#include "./BenchUtils.h" +#include "fbgemm/QuantUtils.h" +#include "fbgemm/Types.h" + +using namespace std; +using namespace fbgemm; + +void performance_test() { + constexpr int NWARMUP = 4; + constexpr int NITER = 256; + + cout << setw(8) << "bit_rate" + << ", " << setw(6) << "rows" + << "," << setw(6) << "cols" + << "," << setw(16) << "elems_per_usec" + << "," << setw(10) << "GB/Sec" << endl; + for (int bit_rate : {2, 4, 8}) { + for (int rowSize : {100, 120, 1000}) { + for (int colSize : {16, 64, 128, 256, 512, 1024, 2048}) { + aligned_vector inpVec(rowSize * colSize); + randFill(inpVec, -10.0f, 10.0f); + + int elements_per_byte = 8 / bit_rate; + int out_emb_cols = + (colSize + elements_per_byte - 1) / elements_per_byte; + int outVecSize = rowSize * (out_emb_cols + 2 * sizeof(float16)); + aligned_vector outVec(outVecSize); + + double duration = 0.0f; + + duration = measureWithWarmup( + [&]() { + FloatToFusedNBitRowwiseQuantizedSBHalf( + bit_rate, inpVec.data(), rowSize, colSize, outVec.data()); + }, + NWARMUP, + NITER, + [&]() { + cache_evict(inpVec); + cache_evict(outVec); + }); + + float elements_per_usec = rowSize * colSize / (duration * 1e6); + + duration *= 1e9; // convert to ns + long bytes_read = rowSize * colSize * sizeof(float); + float gigabyes_per_sec = bytes_read / duration; + + cout << setw(8) << bit_rate << "," << setw(6) << rowSize << ", " + << setw(6) << colSize << ","; + cout << setw(16) << std::fixed << std::setprecision(2) << elements_per_usec << ", "; + cout << setw(10) << std::fixed << std::setprecision(2) << gigabyes_per_sec << endl; + } // for each cols + } // for each rows + } // for each bit_rate +} // performance_test + +int main() { +#ifdef _OPENMP + // Use 1 thread unless OMP_NUM_THREADS is explicit set. + const char* val = getenv("OMP_NUM_THREADS"); + if (val == nullptr || !*val) { + omp_set_num_threads(1); + } +#endif + performance_test(); + return 0; +} diff --git a/include/fbgemm/QuantUtils.h b/include/fbgemm/QuantUtils.h index b10a1252f0..4a8abb9a81 100644 --- a/include/fbgemm/QuantUtils.h +++ b/include/fbgemm/QuantUtils.h @@ -155,14 +155,15 @@ void Dequantize( template T FusedQuantizeDequantize(float src, const TensorQuantizationParams& qparams) { - T q = Quantize(src, qparams.zero_point, qparams.scale, qparams.precision); + T q = Quantize( + src, qparams.zero_point, qparams.scale, qparams.precision); return Dequantize(q, qparams); } /* -Fused integer quantization dequantization kernel to accelerate quantization-aware training. -Quantize fp32 values in src to (u)int8 using the provided qparams, and dequantize quantized -integer values back into fp32. +Fused integer quantization dequantization kernel to accelerate +quantization-aware training. Quantize fp32 values in src to (u)int8 using the +provided qparams, and dequantize quantized integer values back into fp32. */ template FBGEMM_API void FusedQuantizeDequantize( @@ -248,4 +249,29 @@ FBGEMM_API void Requantize( int thread_id = 0, int num_threads = 1); +/** + * Convert float inputs to rowwise quantized outputs. + * bitrate specifies the number of bits in quantized output. + * Scale and Bias are in fp16. Each row's Scale and Bias are stored in + * the row itself (fused) at the end. + * + * @param bit_rate can be 2, 4, or 8 + */ +FBGEMM_API void FloatToFusedNBitRowwiseQuantizedSBHalf( + int bit_rate, + const float* input, + int input_rows, + int input_columns, + std::uint8_t* output); + +/** + * Same as FloatToFusedNBitRowwiseQuantizedSBHalf but unoptimized. + * This should not be called directly except in testing. + */ +FBGEMM_API void FloatToFusedNBitRowwiseQuantizedSBHalfRef( + int bit_rate, + const float* input, + int input_rows, + int input_columns, + std::uint8_t* output); } // namespace fbgemm diff --git a/include/fbgemm/QuantUtilsAvx2.h b/include/fbgemm/QuantUtilsAvx2.h index 6ed81a65c9..62320d743b 100644 --- a/include/fbgemm/QuantUtilsAvx2.h +++ b/include/fbgemm/QuantUtilsAvx2.h @@ -119,4 +119,11 @@ FBGEMM_API void requantizeForFloatAvx2( int ld_in, const requantizationForFloatParams_t& r); +template +void FloatToFusedNBitRowwiseQuantizedSBHalfAvx2( + const float* input, + int input_rows, + int input_columns, + std::uint8_t* output); + } // namespace fbgemm diff --git a/src/QuantUtils.cc b/src/QuantUtils.cc index f70f5bdcda..99f9d43e36 100644 --- a/src/QuantUtils.cc +++ b/src/QuantUtils.cc @@ -6,6 +6,8 @@ #include "fbgemm/Fbgemm.h" +#include "fbgemm/Types.h" + namespace fbgemm { using namespace std; @@ -464,4 +466,93 @@ FBGEMM_API void RequantizeFixedPoint( } } +void FloatToFusedNBitRowwiseQuantizedSBHalfRef( + int bit_rate, + const float* input, + int input_rows, + int input_columns, + std::uint8_t* output) { + int num_elem_per_byte = 8 / bit_rate; + int output_columns = + (input_columns + num_elem_per_byte - 1) / num_elem_per_byte + + 2 * sizeof(float16); + for (std::size_t row = 0; row < input_rows; ++row) { + const float* input_row = input + row * input_columns; + std::uint8_t* output_row = output + row * output_columns; + float16* output_row_scale_bias = reinterpret_cast( + output_row + + (input_columns + num_elem_per_byte - 1) / num_elem_per_byte); + + float minimum_element = + *std::min_element(input_row, input_row + input_columns); + float maximum_element = + *std::max_element(input_row, input_row + input_columns); + + float16 minimum_element_fp16 = cpu_float2half_rn(minimum_element); + minimum_element = cpu_half2float(minimum_element_fp16); + const float range = maximum_element - minimum_element; + + float scale = range == 0 ? 1.0f : range / ((1 << bit_rate) - 1); + float16 scale_fp16 = cpu_float2half_rn(scale); + scale = cpu_half2float(scale_fp16); + if (scale == 0) { + // Corner case handling when maximum_element == minimum_element + // Any scale would work because X - minimum_element will be 0 for all X + scale = 1.0f; + } + float inverse_scale = 1.0f / scale; + if (std::isinf(inverse_scale)) { + scale = 1.0f; + inverse_scale = 1.0f; + } + + output_row_scale_bias[0] = cpu_float2half_rn(scale); + output_row_scale_bias[1] = minimum_element_fp16; + for (std::size_t col = 0; col < input_columns; ++col) { + float X = input_row[col]; + std::uint8_t quantized = std::max( + 0, + std::min( + std::lrintf((X - minimum_element) * inverse_scale), + (1 << bit_rate) - 1)); + if (col % num_elem_per_byte == 0) { + output_row[col / num_elem_per_byte] = quantized; + } else { + output_row[col / num_elem_per_byte] |= + (quantized << ((col % num_elem_per_byte) * bit_rate)); + } + } + } +} + +void FloatToFusedNBitRowwiseQuantizedSBHalf( + int bit_rate, + const float* input, + int input_rows, + int input_columns, + std::uint8_t* output) { + if (cpuinfo_initialize() && fbgemmHasAvx2Support()) { + switch (bit_rate) { + case 2: + FloatToFusedNBitRowwiseQuantizedSBHalfAvx2<2>( + input, input_rows, input_columns, output); + break; + case 4: + FloatToFusedNBitRowwiseQuantizedSBHalfAvx2<4>( + input, input_rows, input_columns, output); + break; + case 8: + FloatToFusedNBitRowwiseQuantizedSBHalfAvx2<8>( + input, input_rows, input_columns, output); + break; + default: + FloatToFusedNBitRowwiseQuantizedSBHalfRef( + bit_rate, input, input_rows, input_columns, output); + } + } else { + FloatToFusedNBitRowwiseQuantizedSBHalfRef( + bit_rate, input, input_rows, input_columns, output); + } +} + } // namespace fbgemm diff --git a/src/QuantUtilsAvx2.cc b/src/QuantUtilsAvx2.cc index ad009ba5b6..eacb0a2717 100644 --- a/src/QuantUtilsAvx2.cc +++ b/src/QuantUtilsAvx2.cc @@ -11,6 +11,7 @@ #include //for nearbyint #include //for numeric_limits #include //for assert +#include // for FLT_MAX #include //for memcpy #include "./MaskAvx2.h" @@ -1430,4 +1431,195 @@ INSTANTIATE_BIAS(false) #undef INSTANTIATE_Q_GRANS #undef INSTANTIATE_BIAS +static inline uint16_t floatToHalf(float val) { +#ifdef _MSC_VER + // Use _mm256_cvtps_ph/_mm256_cvtph_ps because _cvtsh_ss/_cvtss_sh don't + // exist in MSVC. + __m256 val_v = _mm256_set1_ps(val); + __m128i val_half_v = + _mm256_cvtps_ph(val_v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + return static_cast(_mm_cvtsi128_si32(val_half_v)); +#else + return _cvtss_sh(val, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); +#endif +} +static inline float halfToFloat(uint16_t val) { +#ifdef _MSC_VER + return _mm256_cvtss_f32(_mm256_cvtph_ps(_mm_cvtsi32_si128(val))); +#else + return _cvtsh_ss(val); +#endif +} + +template +void FloatToFusedNBitRowwiseQuantizedSBHalfAvx2( + const float* input, + int input_rows, + int input_columns, + std::uint8_t* output) { + __m256i permute_mask1_v = + _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00); + + constexpr int VLEN = 8; + constexpr int NUM_ELEM_PER_BYTE = 8 / BIT_RATE; + int output_columns = + (input_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE + + 2 * sizeof(std::uint16_t); + for (std::size_t row = 0; row < input_rows; ++row) { + const float* input_row = input + row * input_columns; + std::uint8_t* output_row = output + row * output_columns; + std::uint16_t* output_row_scale_bias = reinterpret_cast( + output_row + + (input_columns + NUM_ELEM_PER_BYTE - 1) / NUM_ELEM_PER_BYTE); + + float minimum_element = FLT_MAX; + float maximum_element = -FLT_MAX; + __m256 min_v = _mm256_set1_ps(minimum_element); + __m256 max_v = _mm256_set1_ps(maximum_element); + std::size_t col; + for (col = 0; col < input_columns / VLEN * VLEN; col += VLEN) { + __m256 in_v = _mm256_loadu_ps(input_row + col); + min_v = _mm256_min_ps(min_v, in_v); + max_v = _mm256_max_ps(max_v, in_v); + } + alignas(64) float min_buf[VLEN], max_buf[VLEN]; + _mm256_store_ps(min_buf, min_v); + _mm256_store_ps(max_buf, max_v); + for (int i = 0; i < VLEN; ++i) { + minimum_element = std::min(minimum_element, min_buf[i]); + maximum_element = std::max(maximum_element, max_buf[i]); + } + for (; col < input_columns; ++col) { + minimum_element = std::min(minimum_element, input_row[col]); + maximum_element = std::max(maximum_element, input_row[col]); + } + + output_row_scale_bias[1] = floatToHalf(minimum_element); + minimum_element = halfToFloat(output_row_scale_bias[1]); + const float range = maximum_element - minimum_element; + + float scale = range == 0 ? 1.0f : range / ((1 << BIT_RATE) - 1); + std::uint16_t scale_fp16 = floatToHalf(scale); + scale = halfToFloat(scale_fp16); + if (scale == 0) { + // Corner case handling when maximum_element == minimum_element + // Any scale would work because maximum_element - minimum_element will be + // 0 for all X + scale = 1.0f; + } + float inverse_scale = 1.0f / scale; + if (std::isinf(inverse_scale)) { + scale = 1.0f; + inverse_scale = 1.0f; + } + + output_row_scale_bias[0] = floatToHalf(scale); + + __m256 inverse_scale_v = _mm256_set1_ps(inverse_scale); + min_v = _mm256_set1_ps(minimum_element); + + col = 0; + + if (BIT_RATE == 2 || BIT_RATE == 4) { + for (; col + 4 * VLEN <= input_columns; col += 4 * VLEN) { + __m256i x_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps( + _mm256_sub_ps(_mm256_loadu_ps(input_row + col), min_v), + inverse_scale_v)); + __m256i y_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps( + _mm256_sub_ps(_mm256_loadu_ps(input_row + col + VLEN), min_v), + inverse_scale_v)); + __m256i z_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps( + _mm256_sub_ps(_mm256_loadu_ps(input_row + col + 2 * VLEN), min_v), + inverse_scale_v)); + __m256i w_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps( + _mm256_sub_ps(_mm256_loadu_ps(input_row + col + 3 * VLEN), min_v), + inverse_scale_v)); + + // An instruction sequence to save 32 32-bit integers as 8-bit integers + __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v); + __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v); + __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v); + xyzw_packed_v = + _mm256_permutevar8x32_epi32(xyzw_packed_v, permute_mask1_v); + + // saturate to BIT_RATE + xyzw_packed_v = _mm256_min_epu8( + xyzw_packed_v, + _mm256_set1_epi8(static_cast((1 << BIT_RATE) - 1))); + + if (BIT_RATE == 4) { + // pack into lower 8-bit of each 16-bit + xyzw_packed_v = _mm256_and_si256( + _mm256_or_si256( + xyzw_packed_v, _mm256_srli_epi16(xyzw_packed_v, 4)), + _mm256_set1_epi16(0x00ff)); + } else { + // pack into lower 8-bit of each 32-bit + xyzw_packed_v = _mm256_and_si256( + _mm256_or_si256( + _mm256_or_si256( + xyzw_packed_v, _mm256_srli_epi32(xyzw_packed_v, 6)), + _mm256_or_si256( + _mm256_srli_epi32(xyzw_packed_v, 8 + 4), + _mm256_srli_epi32(xyzw_packed_v, 2 * 8 + 2))), + _mm256_set1_epi32(0x00ff)); + } + + __m128i out_v; + if (BIT_RATE == 4) { + // avx2 doesn't have _mm256_cvtepi16_epi8 + out_v = _mm_packus_epi16( + _mm256_castsi256_si128(xyzw_packed_v), + _mm256_extractf128_si256(xyzw_packed_v, 1)); + _mm_storeu_si128( + reinterpret_cast<__m128i*>(output_row + col / NUM_ELEM_PER_BYTE), + out_v); + } else { + // avx2 doesn't have _mm256_cvtepi32_epi8 + out_v = _mm_packus_epi32( + _mm256_castsi256_si128(xyzw_packed_v), + _mm256_extractf128_si256(xyzw_packed_v, 1)); + out_v = _mm_packus_epi16(out_v, out_v); + _mm_storel_epi64( + reinterpret_cast<__m128i*>(output_row + col / NUM_ELEM_PER_BYTE), + out_v); + } + } + } + + for (; col < input_columns; ++col) { + float X = input_row[col]; + std::uint8_t quantized = std::max( + 0, + std::min( + std::lrintf((X - minimum_element) * inverse_scale), + (1 << BIT_RATE) - 1)); + if (col % NUM_ELEM_PER_BYTE == 0) { + output_row[col / NUM_ELEM_PER_BYTE] = quantized; + } else { + output_row[col / NUM_ELEM_PER_BYTE] |= + (quantized << ((col % NUM_ELEM_PER_BYTE) * BIT_RATE)); + } + } + } +} + +template void FloatToFusedNBitRowwiseQuantizedSBHalfAvx2<2>( + const float* input, + int input_rows, + int input_columns, + std::uint8_t* output); + +template void FloatToFusedNBitRowwiseQuantizedSBHalfAvx2<4>( + const float* input, + int input_rows, + int input_columns, + std::uint8_t* output); + +template void FloatToFusedNBitRowwiseQuantizedSBHalfAvx2<8>( + const float* input, + int input_rows, + int input_columns, + std::uint8_t* output); + } // namespace fbgemm diff --git a/test/QuantUtilsTest.cc b/test/QuantUtilsTest.cc index 764dbc2913..23d57e22f5 100644 --- a/test/QuantUtilsTest.cc +++ b/test/QuantUtilsTest.cc @@ -9,10 +9,13 @@ #include #include #include +#include +#include #include #include "fbgemm/QuantUtils.h" +#include "fbgemm/Types.h" #include "fbgemm/Utils.h" using namespace std; @@ -26,6 +29,11 @@ class QuantizeGroupwiseTest class QuantizeTest : public testing::TestWithParam {}; class FusedQuantizeDequantizeTest : public testing::TestWithParam {}; +// Parameter are bit_rate (i.e., the number of bits in quantized values), +// input rows, and input columns +class EmbeddingQuantizeTest + : public testing::TestWithParam> {}; + INSTANTIATE_TEST_CASE_P( InstantiationName, QuantizeGroupwiseTest, @@ -41,6 +49,14 @@ INSTANTIATE_TEST_CASE_P( QuantizeTest, ::testing::Values(1, 2, 5, 8, 9, 16, 20, 28, 32, 33)); +INSTANTIATE_TEST_CASE_P( + InstantiationName, + EmbeddingQuantizeTest, + ::testing::Combine( + ::testing::ValuesIn({2, 4, 8}), + ::testing::ValuesIn({1, 2, 3}), + ::testing::ValuesIn({1, 2, 5, 8, 9, 16, 20, 28, 32, 33, 64, 65}))); + template void ref_impl( const vector& src, @@ -131,6 +147,69 @@ bool floatEqualAll(vector& a, vector& b) { return true; } +template +::testing::AssertionResult isQEmbeddingClose( + const vector& res, + const vector& res_ref, + int out_rows, + int out_emb_cols) { + bool match = true; + std::stringstream ss; + int ld = out_emb_cols + 2 * sizeof(T); + if (res.size() == res_ref.size()) { + for (int i = 0; i < out_rows; ++i) { + if (!match) { + break; + } + // compare embedding values + for (int j = 0; j < out_emb_cols; ++j) { + if (res[i * ld + j] != res_ref[i * ld + j]) { + match = false; + ss << " mismatch at (" << i << ", " << j << ") "; + ss << "ref: " << static_cast(res_ref[i * ld + j]) + << ", test: " << static_cast(res[i * ld + j]) << "\n"; + break; + } + } + // compare scale/bias + float scaleTest, scaleRef, biasTest, biasRef; + if (is_same::value) { + // half scale and bias + scaleTest = cpu_half2float(reinterpret_cast( + res.data() + i * ld + out_emb_cols)[0]); + biasTest = cpu_half2float(reinterpret_cast( + res.data() + i * ld + out_emb_cols)[1]); + scaleRef = cpu_half2float(reinterpret_cast( + res_ref.data() + i * ld + out_emb_cols)[0]); + biasRef = cpu_half2float(reinterpret_cast( + res_ref.data() + i * ld + out_emb_cols)[1]); + } else { + // float scale and bias + // TODO: + } + if (fabs(scaleTest - scaleRef) > std::numeric_limits::epsilon()) { + ss << " scale mismatch for row:" << i; + ss << " ref: " << scaleRef << ", test: " << scaleTest << "\n"; + match = false; + } + if (fabs(biasTest - biasRef) > std::numeric_limits::epsilon()) { + ss << " bias mismatch for row:" << i; + ss << " ref: " << biasRef << ", test: " << biasTest << "\n"; + match = false; + } + } + } else { + ss << " size mismatch "; + match = false; + } + + if (match) + return ::testing::AssertionSuccess(); + else + return ::testing::AssertionFailure() + << " Quantized Embeddings do not match." << ss.str(); +} + /** * Test for QuantizeGroupwise */ @@ -411,7 +490,6 @@ TEST(FusedQuantizeDequantizeTest, cornerCases) { src1.data(), dst_int8.data(), src1.size(), qparams); EXPECT_TRUE(floatEqualAll(dst_int8, ref)); - // Tests vectorized and remainder paths vector src2 = {3.40282e+38, -2.16845e+38, @@ -440,3 +518,33 @@ TEST(FusedQuantizeDequantizeTest, cornerCases) { EXPECT_TRUE(floatEqualAll(dst_uint8, ref2)); } + +TEST_P(EmbeddingQuantizeTest, embeddingHalfTest) { + int bit_rate, rows, cols; + tie(bit_rate, rows, cols) = GetParam(); + + random_device rd; + mt19937 gen(rd()); + + uniform_real_distribution disFP(-10.0f, 10.0f); + + vector inpVec(rows * cols); + + generate(inpVec.begin(), inpVec.end(), [&, disFP]() mutable { return disFP(gen); }); + + int elements_per_byte = 8 / bit_rate; + + int out_emb_cols = + (cols + elements_per_byte - 1) / elements_per_byte; + int outVecSize = rows * (out_emb_cols + 2 * sizeof(float16)); + + vector outVecRef(outVecSize); + vector outVecTest(outVecSize); + + FloatToFusedNBitRowwiseQuantizedSBHalfRef( + bit_rate, inpVec.data(), rows, cols, outVecRef.data()); + FloatToFusedNBitRowwiseQuantizedSBHalf( + bit_rate, inpVec.data(), rows, cols, outVecTest.data()); + + EXPECT_TRUE(isQEmbeddingClose(outVecTest, outVecRef, rows, out_emb_cols)); +}