Skip to content

Commit

Permalink
part2: Move embedding quantization kernels to fbgemm for better shari…
Browse files Browse the repository at this point in the history
…ng between C2/PT (pytorch#425)

Summary:
Pull Request resolved: pytorch#425

Adds the 8-bit rowwise embedding quantization kernel with float scale and bias.

Test and benchmark added.

```
With scale and bias as float
bit_rate,   rows,  cols,  elems_per_usec,    GB/Sec
       8,   100,     16,          556.20,       2.22
       8,   100,     64,         1022.51,       4.09
       8,   100,    128,         1121.43,       4.49
       8,   100,    256,         1292.61,       5.17
       8,   100,    512,         1526.69,       6.11
       8,   100,   1024,         1407.09,       5.63
       8,   100,   2048,         1620.34,       6.48
       8,   120,     16,          562.60,       2.25
       8,   120,     64,         1058.52,       4.23
       8,   120,    128,         1082.74,       4.33
       8,   120,    256,         1382.87,       5.53
       8,   120,    512,         1513.15,       6.05
       8,   120,   1024,         1441.19,       5.76
       8,   120,   2048,         1634.99,       6.54
       8,  1000,     16,          598.05,       2.39
       8,  1000,     64,         1151.16,       4.60
       8,  1000,    128,         1071.58,       4.29
       8,  1000,    256,         1278.66,       5.11
       8,  1000,    512,         1441.13,       5.76
       8,  1000,   1024,         1605.48,       6.42
       8,  1000,   2048,         1764.24,       7.06
```

Reviewed By: supriyar

Differential Revision: D23455486

fbshipit-source-id: e0dea307c42d614747302544a7179fa40194dad6
  • Loading branch information
dskhudia authored and facebook-github-bot committed Sep 4, 2020
1 parent 1289a1f commit 7e401c0
Show file tree
Hide file tree
Showing 6 changed files with 264 additions and 10 deletions.
46 changes: 37 additions & 9 deletions bench/EmbeddingQuantizeBenchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <initializer_list>
#include <iomanip>
#include <iostream>
#include <vector>

#ifdef _OPENMP
#include <omp.h>
Expand All @@ -20,33 +21,57 @@
using namespace std;
using namespace fbgemm;

// T is the type of scale and bias
template <typename T>
void performance_test() {
constexpr int NWARMUP = 4;
constexpr int NITER = 256;

if (is_same<T, float16>::value) {
cout << "With scale and bias as float16" << endl;
} else {
cout << "With scale and bias as float" << endl;
}
cout << setw(8) << "bit_rate"
<< ", " << setw(6) << "rows"
<< "," << setw(6) << "cols"
<< "," << setw(16) << "elems_per_usec"
<< "," << setw(10) << "GB/Sec" << endl;
for (int bit_rate : {2, 4, 8}) {
std::vector<int> bit_rates;
if (is_same<T, float16>::value) {
bit_rates = {2, 4, 8};
} else {
// float
bit_rates = {8};
}
for (int bit_rate : bit_rates) {
for (int rowSize : {100, 120, 1000}) {
for (int colSize : {16, 64, 128, 256, 512, 1024, 2048}) {
aligned_vector<float> inpVec(rowSize * colSize);
randFill<float>(inpVec, -10.0f, 10.0f);

int elements_per_byte = 8 / bit_rate;
int out_emb_cols =
(colSize + elements_per_byte - 1) / elements_per_byte;
int out_emb_cols = colSize;

if (is_same<T, float16>::value) {
int elements_per_byte = 8 / bit_rate;
out_emb_cols = (colSize + elements_per_byte - 1) / elements_per_byte;
}
int outVecSize = rowSize * (out_emb_cols + 2 * sizeof(float16));
aligned_vector<uint8_t> outVec(outVecSize);

double duration = 0.0f;

duration = measureWithWarmup(
[&]() {
FloatToFusedNBitRowwiseQuantizedSBHalf(
bit_rate, inpVec.data(), rowSize, colSize, outVec.data());
is_same<T, float16>::value
? FloatToFusedNBitRowwiseQuantizedSBHalf(
bit_rate,
inpVec.data(),
rowSize,
colSize,
outVec.data())
: FloatToFused8BitRowwiseQuantizedSBFloat(
inpVec.data(), rowSize, colSize, outVec.data());
},
NWARMUP,
NITER,
Expand All @@ -63,8 +88,10 @@ void performance_test() {

cout << setw(8) << bit_rate << "," << setw(6) << rowSize << ", "
<< setw(6) << colSize << ",";
cout << setw(16) << std::fixed << std::setprecision(2) << elements_per_usec << ", ";
cout << setw(10) << std::fixed << std::setprecision(2) << gigabyes_per_sec << endl;
cout << setw(16) << std::fixed << std::setprecision(2)
<< elements_per_usec << ", ";
cout << setw(10) << std::fixed << std::setprecision(2)
<< gigabyes_per_sec << endl;
} // for each cols
} // for each rows
} // for each bit_rate
Expand All @@ -78,6 +105,7 @@ int main() {
omp_set_num_threads(1);
}
#endif
performance_test();
performance_test<float16>();
performance_test<float>();
return 0;
}
27 changes: 27 additions & 0 deletions include/fbgemm/QuantUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,31 @@ FBGEMM_API void FloatToFusedNBitRowwiseQuantizedSBHalfRef(
int input_rows,
int input_columns,
std::uint8_t* output);

/**
 * Convert float inputs to rowwise quantized (8-bit) outputs.
 * Scale and Bias are in float. Each row's Scale and Bias are stored in
 * the row itself (fused) at the end.
 *
 * Each output row consists of input_columns quantized bytes followed by a
 * float scale ((max - min) / 255 of that row) and a float bias (the row
 * minimum), in that order.
 *
 * This version intentionally supports only 8-bit because we want to discourage
 * the usage of float scale and bias with 2 and 4 bit cases as that diminishes
 * the overall memory savings.
 *
 * @param input row-major float matrix with input_rows * input_columns
 *              elements.
 * @param input_rows number of rows in input.
 * @param input_columns number of columns in input.
 * @param output destination buffer; must hold at least
 *               input_rows * (input_columns + 2 * sizeof(float)) bytes.
 */
FBGEMM_API void FloatToFused8BitRowwiseQuantizedSBFloat(
const float* input,
int input_rows,
int input_columns,
std::uint8_t* output);

/**
 * Same as FloatToFused8BitRowwiseQuantizedSBFloat but unoptimized (plain
 * scalar loops, no vector intrinsics).
 * This should not be called directly except in testing.
 *
 * The parameters and the output layout are identical to
 * FloatToFused8BitRowwiseQuantizedSBFloat: output must hold at least
 * input_rows * (input_columns + 2 * sizeof(float)) bytes.
 */
FBGEMM_API void FloatToFused8BitRowwiseQuantizedSBFloatRef(
const float* input,
int input_rows,
int input_columns,
std::uint8_t* output);

} // namespace fbgemm
6 changes: 6 additions & 0 deletions include/fbgemm/QuantUtilsAvx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,10 @@ void FloatToFusedNBitRowwiseQuantizedSBHalfAvx2(
int input_columns,
std::uint8_t* output);

// AVX2 implementation of fused 8-bit rowwise quantization with float scale
// and bias. Each output row holds input_columns quantized bytes followed by a
// float scale and a float bias, so output must provide
// input_rows * (input_columns + 2 * sizeof(float)) bytes.
// NOTE(review): presumably callers are expected to check for AVX2 support
// before calling this directly - confirm against the dispatching wrapper.
void FloatToFused8BitRowwiseQuantizedSBFloatAvx2(
const float* input,
int input_rows,
int input_columns,
std::uint8_t* output);

} // namespace fbgemm
44 changes: 44 additions & 0 deletions src/QuantUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -555,4 +555,48 @@ void FloatToFusedNBitRowwiseQuantizedSBHalf(
}
}

/**
 * Reference (scalar) fused 8-bit rowwise quantization with float scale/bias.
 *
 * For every input row the minimum and the range (max - min) are computed.
 * Each element is mapped to round((x - min) * 255 / (range + epsilon)) and
 * stored as one byte. The row is followed by two floats: the scale
 * (range / 255) and the bias (the row minimum), so that x ~= q * scale + bias.
 *
 * @param input row-major float matrix, input_rows x input_columns.
 * @param input_rows number of rows; values <= 0 produce no output.
 * @param input_columns number of columns; must be > 0 when there is at least
 *        one row, because the row minimum/maximum of an empty row is
 *        undefined.
 * @param output destination buffer of at least
 *        input_rows * (input_columns + 2 * sizeof(float)) bytes. No alignment
 *        is required.
 */
void FloatToFused8BitRowwiseQuantizedSBFloatRef(
    const float* input,
    int input_rows,
    int input_columns,
    std::uint8_t* output) {
  // Keeps the inverse scale finite when all elements of a row are equal.
  constexpr float kEpsilon = 1e-8f;

  // Row strides are kept in size_t so that large tables do not overflow int.
  const std::size_t input_stride = static_cast<std::size_t>(input_columns);
  const std::size_t output_stride = input_stride + 2 * sizeof(float);

  // Signed loop counters: comparing a size_t index against the int parameters
  // would turn a negative count into a huge unsigned trip count.
  for (int row = 0; row < input_rows; ++row) {
    const float* input_row = input + row * input_stride;
    std::uint8_t* output_row = output + row * output_stride;

    const float minimum_element =
        *std::min_element(input_row, input_row + input_columns);
    const float maximum_element =
        *std::max_element(input_row, input_row + input_columns);
    const float range = maximum_element - minimum_element;

    // The scale/bias trailer starts right after input_columns bytes, which is
    // generally not a 4-byte aligned address. Copy the object representation
    // byte-wise instead of storing through a (misaligned) float pointer.
    const float scale_bias[2] = {range / 255.0f, minimum_element};
    std::copy_n(
        reinterpret_cast<const unsigned char*>(scale_bias),
        sizeof(scale_bias),
        output_row + input_columns);

    const float inverse_scale = 255.0f / (range + kEpsilon);
    for (int col = 0; col < input_columns; ++col) {
      // lrintf rounds using the current rounding mode (nearest-even by
      // default).
      output_row[col] =
          std::lrintf((input_row[col] - minimum_element) * inverse_scale);
    }
  }
}

// Fused 8-bit rowwise quantization with float scale and bias.
// Picks the AVX2 kernel when cpuinfo initializes successfully and AVX2 is
// reported as available; otherwise falls back to the scalar reference.
void FloatToFused8BitRowwiseQuantizedSBFloat(
    const float* input,
    int input_rows,
    int input_columns,
    std::uint8_t* output) {
  // Short-circuit order matters: only query AVX2 support after cpuinfo has
  // been initialized.
  const bool use_avx2 = cpuinfo_initialize() && fbgemmHasAvx2Support();
  if (!use_avx2) {
    FloatToFused8BitRowwiseQuantizedSBFloatRef(
        input, input_rows, input_columns, output);
    return;
  }
  FloatToFused8BitRowwiseQuantizedSBFloatAvx2(
      input, input_rows, input_columns, output);
}

} // namespace fbgemm
101 changes: 101 additions & 0 deletions src/QuantUtilsAvx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1622,4 +1622,105 @@ template void FloatToFusedNBitRowwiseQuantizedSBHalfAvx2<8>(
int input_columns,
std::uint8_t* output);

// AVX2 kernel for fused 8-bit rowwise quantization with float scale and bias.
//
// For every input row: find the row minimum and maximum, store
// scale = (max - min) / 255 and bias = min as two floats directly after the
// row's input_columns quantized bytes, then write
// round((x - min) * 255 / (max - min + kEpsilon)) for each element as one byte.
//
// input:  row-major floats, input_rows x input_columns.
// output: input_rows rows of (input_columns + 2 * sizeof(float)) bytes each.
void FloatToFused8BitRowwiseQuantizedSBFloatAvx2(
const float* input,
int input_rows,
int input_columns,
std::uint8_t* output) {
// Number of floats in one 256-bit vector.
constexpr int VLEN = 8;
// Added to the range so the inverse scale stays finite for constant rows.
constexpr float kEpsilon = 1e-8f;

// Dword order (low to high) 0,4,1,5,2,6,3,7: interleaves the two 128-bit
// lanes to undo the per-lane grouping produced by the pack instructions in
// the 32-element loop below.
__m256i permute_mask1_v =
_mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
// Per 128-bit lane: gather byte 0 of each of the four 32-bit elements into the
// lane's low 4 bytes; 0xff selectors zero the remaining bytes.
// clang-format off
__m256i shuffle_mask_v = _mm256_set_epi8(
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00,
0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00);
// clang-format on

// Moves dword 0 of the low lane and dword 0 of the high lane (index 4) into
// dwords 0 and 1, so the 8 result bytes are contiguous in the low 64 bits.
__m256i permute_mask2_v =
_mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);

// Bytes per output row: quantized payload plus float scale and float bias.
int output_columns = input_columns + 2 * sizeof(float);
for (std::size_t row = 0; row < input_rows; ++row) {
const float* input_row = input + row * input_columns;
std::uint8_t* output_row = output + row * output_columns;
// NOTE(review): output_row + input_columns is not necessarily 4-byte aligned
// (e.g. odd input_columns); the stores through this float* below rely on
// unaligned float access being acceptable - confirm.
float* output_row_scale_bias =
reinterpret_cast<float*>(output_row + input_columns);

// Pass 1: row minimum / maximum. Vector accumulators cover the multiple-of-8
// prefix of the row.
float minimum_element = FLT_MAX;
float maximum_element = -FLT_MAX;
__m256 min_v = _mm256_set1_ps(minimum_element);
__m256 max_v = _mm256_set1_ps(maximum_element);
std::size_t col;
for (col = 0; col < input_columns / VLEN * VLEN; col += VLEN) {
__m256 in_v = _mm256_loadu_ps(input_row + col);
min_v = _mm256_min_ps(min_v, in_v);
max_v = _mm256_max_ps(max_v, in_v);
}
// Horizontal reduction of the 8 lanes via aligned scratch buffers (the
// aligned store is valid because of alignas(64)).
alignas(64) float min_buf[VLEN], max_buf[VLEN];
_mm256_store_ps(min_buf, min_v);
_mm256_store_ps(max_buf, max_v);
for (int i = 0; i < VLEN; ++i) {
minimum_element = std::min(minimum_element, min_buf[i]);
maximum_element = std::max(maximum_element, max_buf[i]);
}
// Scalar tail: the last input_columns % 8 elements.
for (; col < input_columns; ++col) {
minimum_element = std::min(minimum_element, input_row[col]);
maximum_element = std::max(maximum_element, input_row[col]);
}

float range = maximum_element - minimum_element;

// Fused per-row trailer: scale first, then bias.
output_row_scale_bias[0] = range / 255.0f;
output_row_scale_bias[1] = minimum_element;
const auto inverse_scale = 255.0f / (range + kEpsilon);
min_v = _mm256_set1_ps(minimum_element);
__m256 inverse_scale_v = _mm256_set1_ps(inverse_scale);

// Pass 2a: quantize 32 elements (four vectors) per iteration.
for (col = 0; col < input_columns / (4 * VLEN) * (4 * VLEN);
col += 4 * VLEN) {
__m256i x_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
_mm256_sub_ps(_mm256_loadu_ps(input_row + col), min_v),
inverse_scale_v));
__m256i y_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
_mm256_sub_ps(_mm256_loadu_ps(input_row + col + VLEN), min_v),
inverse_scale_v));
__m256i z_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
_mm256_sub_ps(_mm256_loadu_ps(input_row + col + 2 * VLEN), min_v),
inverse_scale_v));
__m256i w_rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
_mm256_sub_ps(_mm256_loadu_ps(input_row + col + 3 * VLEN), min_v),
inverse_scale_v));

// An instruction sequence to save 32 32-bit integers as 8-bit integers:
// int32 -> int16 (signed saturation) -> uint8 (unsigned saturation). Both
// packs operate within each 128-bit lane, leaving the dwords ordered as
// x0-3, y0-3, z0-3, w0-3 | x4-7, y4-7, z4-7, w4-7; the permute restores
// x0-7, y0-7, z0-7, w0-7.
__m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v);
__m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v);
__m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
xyzw_packed_v =
_mm256_permutevar8x32_epi32(xyzw_packed_v, permute_mask1_v);
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(output_row + col), xyzw_packed_v);
}
// Pass 2b: remaining full vectors, 8 elements per iteration.
for (; col < input_columns / VLEN * VLEN; col += VLEN) {
__m256i rounded_v = _mm256_cvtps_epi32(_mm256_mul_ps(
_mm256_sub_ps(_mm256_loadu_ps(input_row + col), min_v),
inverse_scale_v));

// An instruction sequence to save 8 32-bit integers as 8-bit integers.
// Unlike the 32-element path this keeps only the low byte of each element
// (no saturation), then gathers both lanes' 4 bytes into the low 64 bits.
rounded_v = _mm256_shuffle_epi8(rounded_v, shuffle_mask_v);
rounded_v = _mm256_permutevar8x32_epi32(rounded_v, permute_mask2_v);
_mm_storel_epi64(
reinterpret_cast<__m128i*>(output_row + col),
_mm256_castsi256_si128(rounded_v));
}
// Pass 2c: scalar tail for the last input_columns % 8 elements.
for (; col < input_columns; ++col) {
output_row[col] =
std::lrintf((input_row[col] - minimum_element) * inverse_scale);
}
}
}

} // namespace fbgemm
50 changes: 49 additions & 1 deletion test/QuantUtilsTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ class FusedQuantizeDequantizeTest : public testing::TestWithParam<int> {};
class EmbeddingQuantizeTest
: public testing::TestWithParam<tuple<int, int, int>> {};

// Test fixture for 8-bit embedding quantization where the per-row scale and
// bias are stored as float (SBFloat).
// Parameters are (input rows, input columns).
class EmbeddingQuantizeSBFloatTest
: public testing::TestWithParam<tuple<int, int>> {};

INSTANTIATE_TEST_CASE_P(
InstantiationName,
QuantizeGroupwiseTest,
Expand All @@ -57,6 +62,13 @@ INSTANTIATE_TEST_CASE_P(
::testing::ValuesIn({1, 2, 3}),
::testing::ValuesIn({1, 2, 5, 8, 9, 16, 20, 28, 32, 33, 64, 65})));

// Runs EmbeddingQuantizeSBFloatTest over every (rows, columns) combination of
// the two value lists below. The column counts include values below, equal to
// and just above multiples of 8 and 32.
// NOTE(review): presumably chosen to hit the 32-wide, 8-wide and scalar-tail
// loops of the AVX2 kernel - confirm.
INSTANTIATE_TEST_CASE_P(
InstantiationName,
EmbeddingQuantizeSBFloatTest,
::testing::Combine(
::testing::ValuesIn({1, 2, 3}),
::testing::ValuesIn({1, 2, 5, 8, 9, 16, 20, 28, 32, 33, 64, 65})));

template <typename T, layout_t LT>
void ref_impl(
const vector<float>& src,
Expand Down Expand Up @@ -185,7 +197,14 @@ ::testing::AssertionResult isQEmbeddingClose(
res_ref.data() + i * ld + out_emb_cols)[1]);
} else {
// float scale and bias
// TODO:
scaleTest = reinterpret_cast<const float*>(
res.data() + i * ld + out_emb_cols)[0];
biasTest = reinterpret_cast<const float*>(
res.data() + i * ld + out_emb_cols)[1];
scaleRef = reinterpret_cast<const float*>(
res_ref.data() + i * ld + out_emb_cols)[0];
biasRef = reinterpret_cast<const float*>(
res_ref.data() + i * ld + out_emb_cols)[1];
}
if (fabs(scaleTest - scaleRef) > std::numeric_limits<float>::epsilon()) {
ss << " scale mismatch for row:" << i;
Expand Down Expand Up @@ -548,3 +567,32 @@ TEST_P(EmbeddingQuantizeTest, embeddingHalfTest) {

EXPECT_TRUE(isQEmbeddingClose<float16>(outVecTest, outVecRef, rows, out_emb_cols));
}

// Quantizes a random (rows x cols) float matrix with both the reference and
// the dispatching implementation and expects the fused outputs (quantized
// bytes plus float scale/bias per row) to agree.
TEST_P(EmbeddingQuantizeSBFloatTest, embeddingFloatTest) {
  const int rows = std::get<0>(GetParam());
  const int cols = std::get<1>(GetParam());

  random_device seedSource;
  mt19937 rng(seedSource());
  uniform_real_distribution<float> valueDist(-10.0f, 10.0f);

  vector<float> inpVec(rows * cols);
  for (float& value : inpVec) {
    value = valueDist(rng);
  }

  // Every output row carries `cols` quantized bytes followed by a float scale
  // and a float bias.
  const int outputBytes = rows * (cols + 2 * sizeof(float));
  vector<uint8_t> expected(outputBytes);
  vector<uint8_t> actual(outputBytes);

  FloatToFused8BitRowwiseQuantizedSBFloatRef(
      inpVec.data(), rows, cols, expected.data());
  FloatToFused8BitRowwiseQuantizedSBFloat(
      inpVec.data(), rows, cols, actual.data());

  // For 8-bit quantization the number of output payload columns equals the
  // number of input columns.
  EXPECT_TRUE(isQEmbeddingClose<float>(actual, expected, rows, cols));
}

0 comments on commit 7e401c0

Please sign in to comment.