Skip to content

Commit

Permalink
Move embedding quantization kernels to fbgemm for better sharing between C2/PT (pytorch#419)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: pytorch#419

Moves the kernel to fbgemm that quantizes embedding to 8, 4 or 2 bits with half scale/bias.

```
bit_rate,   rows,  cols,  elems_per_usec,    GB/Sec
       2,   100,     16,           94.28,       0.38
       2,   100,     64,          593.92,       2.38
       2,   100,    128,          777.68,       3.11
       2,   100,    256,          976.79,       3.91
       2,   100,    512,         1146.61,       4.59
       2,   100,   1024,         1314.89,       5.26
       2,   100,   2048,         1312.45,       5.25
       2,   120,     16,           94.59,       0.38
       2,   120,     64,          584.89,       2.34
       2,   120,    128,          809.44,       3.24
       2,   120,    256,         1019.32,       4.08
       2,   120,    512,         1137.77,       4.55
       2,   120,   1024,         1156.94,       4.63
       2,   120,   2048,         1321.16,       5.28
       2,  1000,     16,           96.18,       0.38
       2,  1000,     64,          612.17,       2.45
       2,  1000,    128,          776.04,       3.10
       2,  1000,    256,          981.44,       3.93
       2,  1000,    512,         1117.79,       4.47
       2,  1000,   1024,         1262.67,       5.05
       2,  1000,   2048,         1397.23,       5.59
       4,   100,     16,           93.90,       0.38
       4,   100,     64,          607.12,       2.43
       4,   100,    128,          817.09,       3.27
       4,   100,    256,         1034.85,       4.14
       4,   100,    512,         1251.44,       5.01
       4,   100,   1024,         1414.89,       5.66
       4,   100,   2048,         1413.13,       5.65
       4,   120,     16,           93.74,       0.37
       4,   120,     64,          612.77,       2.45
       4,   120,    128,          847.48,       3.39
       4,   120,    256,         1110.61,       4.44
       4,   120,    512,         1265.75,       5.06
       4,   120,   1024,         1239.66,       4.96
       4,   120,   2048,         1433.69,       5.73
       4,  1000,     16,           93.94,       0.38
       4,  1000,     64,          664.46,       2.66
       4,  1000,    128,          819.37,       3.28
       4,  1000,    256,         1055.88,       4.22
       4,  1000,    512,         1217.72,       4.87
       4,  1000,   1024,         1374.96,       5.50
       4,  1000,   2048,         1522.50,       6.09
       8,   100,     16,           98.71,       0.39
       8,   100,     64,          169.60,       0.68
       8,   100,    128,          196.60,       0.79
       8,   100,    256,          212.97,       0.85
       8,   100,    512,          223.10,       0.89
       8,   100,   1024,          226.54,       0.91
       8,   100,   2048,          231.72,       0.93
       8,   120,     16,           99.22,       0.40
       8,   120,     64,          161.74,       0.65
       8,   120,    128,          188.89,       0.76
       8,   120,    256,          213.85,       0.86
       8,   120,    512,          224.18,       0.90
       8,   120,   1024,          226.87,       0.91
       8,   120,   2048,          231.60,       0.93
       8,  1000,     16,           97.97,       0.39
       8,  1000,     64,          172.09,       0.69
       8,  1000,    128,          196.11,       0.78
       8,  1000,    256,          213.29,       0.85
       8,  1000,    512,          222.82,       0.89
       8,  1000,   1024,          225.41,       0.90
       8,  1000,   2048,          231.66,       0.93
```

Reviewed By: supriyar

Differential Revision: D23320018

fbshipit-source-id: a12ec2bec80f77b7c3e15aadee6da94a2b7182fd
  • Loading branch information
dskhudia authored and facebook-github-bot committed Aug 31, 2020
1 parent 0048c0d commit a9397bf
Show file tree
Hide file tree
Showing 6 changed files with 512 additions and 5 deletions.
83 changes: 83 additions & 0 deletions bench/EmbeddingQuantizeBenchmark.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <chrono>
#include <initializer_list>
#include <iomanip>
#include <iostream>

#ifdef _OPENMP
#include <omp.h>
#endif

#include "./BenchUtils.h"
#include "fbgemm/QuantUtils.h"
#include "fbgemm/Types.h"

using namespace std;
using namespace fbgemm;

// Benchmarks FloatToFusedNBitRowwiseQuantizedSBHalf over bit rates {2,4,8}
// and a grid of matrix shapes, printing throughput as elements/usec and
// GB/s (based on bytes of float input read).
void performance_test() {
  constexpr int NWARMUP = 4;
  constexpr int NITER = 256;

  cout << setw(8) << "bit_rate"
       << ", " << setw(6) << "rows"
       << "," << setw(6) << "cols"
       << "," << setw(16) << "elems_per_usec"
       << "," << setw(10) << "GB/Sec" << endl;
  for (int bit_rate : {2, 4, 8}) {
    for (int rowSize : {100, 120, 1000}) {
      for (int colSize : {16, 64, 128, 256, 512, 1024, 2048}) {
        aligned_vector<float> inpVec(rowSize * colSize);
        randFill<float>(inpVec, -10.0f, 10.0f);

        // Each output row holds ceil(cols / elements_per_byte) packed bytes
        // plus a fused fp16 scale and fp16 bias (2 * sizeof(float16) bytes).
        int elements_per_byte = 8 / bit_rate;
        int out_emb_cols =
            (colSize + elements_per_byte - 1) / elements_per_byte;
        int outVecSize = rowSize * (out_emb_cols + 2 * sizeof(float16));
        aligned_vector<uint8_t> outVec(outVecSize);

        // measureWithWarmup returns seconds; input/output are evicted
        // between iterations so we measure cold-cache behavior.
        double duration = measureWithWarmup(
            [&]() {
              FloatToFusedNBitRowwiseQuantizedSBHalf(
                  bit_rate, inpVec.data(), rowSize, colSize, outVec.data());
            },
            NWARMUP,
            NITER,
            [&]() {
              cache_evict(inpVec);
              cache_evict(outVec);
            });

        float elements_per_usec = rowSize * colSize / (duration * 1e6);

        duration *= 1e9; // convert to ns; bytes / ns == GB/s
        long bytes_read = rowSize * colSize * sizeof(float);
        float gigabytes_per_sec = bytes_read / duration;

        cout << setw(8) << bit_rate << "," << setw(6) << rowSize << ", "
             << setw(6) << colSize << ",";
        cout << setw(16) << std::fixed << std::setprecision(2)
             << elements_per_usec << ", ";
        cout << setw(10) << std::fixed << std::setprecision(2)
             << gigabytes_per_sec << endl;
      } // for each cols
    } // for each rows
  } // for each bit_rate
} // performance_test

int main() {
#ifdef _OPENMP
  // Run single-threaded unless the user explicitly set OMP_NUM_THREADS,
  // so default benchmark numbers are reproducible.
  const char* num_threads_env = getenv("OMP_NUM_THREADS");
  const bool user_set_threads =
      num_threads_env != nullptr && *num_threads_env != '\0';
  if (!user_set_threads) {
    omp_set_num_threads(1);
  }
#endif
  performance_test();
  return 0;
}
34 changes: 30 additions & 4 deletions include/fbgemm/QuantUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,15 @@ void Dequantize(

template <typename T>
T FusedQuantizeDequantize(float src, const TensorQuantizationParams& qparams) {
T q = Quantize<T, false>(src, qparams.zero_point, qparams.scale, qparams.precision);
T q = Quantize<T, false>(
src, qparams.zero_point, qparams.scale, qparams.precision);
return Dequantize<T>(q, qparams);
}

/*
Fused integer quantization dequantization kernel to accelerate quantization-aware training.
Quantize fp32 values in src to (u)int8 using the provided qparams, and dequantize quantized
integer values back into fp32.
Fused integer quantization dequantization kernel to accelerate
quantization-aware training. Quantize fp32 values in src to (u)int8 using the
provided qparams, and dequantize quantized integer values back into fp32.
*/
template <typename T>
FBGEMM_API void FusedQuantizeDequantize(
Expand Down Expand Up @@ -248,4 +249,29 @@ FBGEMM_API void Requantize(
int thread_id = 0,
int num_threads = 1);

/**
 * Convert a row-major float matrix into rowwise-quantized output with fused
 * fp16 scale/bias ("SBHalf"). Each output row stores the values packed at
 * bit_rate bits per element, followed by that row's scale and bias, each
 * stored as fp16, at the end of the row itself (fused).
 *
 * Dispatches to an AVX2-optimized kernel when the CPU supports AVX2 and
 * bit_rate is 2, 4, or 8; all other cases use the reference implementation.
 *
 * @param bit_rate bits per quantized element; can be 2, 4, or 8
 * @param input pointer to input_rows * input_columns floats (row-major)
 * @param input_rows number of rows in the input matrix
 * @param input_columns number of columns in the input matrix
 * @param output buffer of at least input_rows * output_columns bytes, where
 *        output_columns = ceil(input_columns / (8 / bit_rate)) +
 *        2 * sizeof(float16)
 */
FBGEMM_API void FloatToFusedNBitRowwiseQuantizedSBHalf(
int bit_rate,
const float* input,
int input_rows,
int input_columns,
std::uint8_t* output);

/**
 * Reference (scalar) version of FloatToFusedNBitRowwiseQuantizedSBHalf with
 * the same output layout but no SIMD optimization.
 * This should not be called directly except in testing.
 */
FBGEMM_API void FloatToFusedNBitRowwiseQuantizedSBHalfRef(
int bit_rate,
const float* input,
int input_rows,
int input_columns,
std::uint8_t* output);
} // namespace fbgemm
7 changes: 7 additions & 0 deletions include/fbgemm/QuantUtilsAvx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,11 @@ FBGEMM_API void requantizeForFloatAvx2(
int ld_in,
const requantizationForFloatParams_t& r);

/// AVX2 kernel behind FloatToFusedNBitRowwiseQuantizedSBHalf; the dispatcher
/// instantiates it with BIT_RATE = 2, 4, or 8.
/// NOTE(review): presumably produces the same fused row layout as the
/// reference implementation (packed values then fp16 scale/bias) — the
/// implementation is not visible in this diff, so confirm there.
template <int BIT_RATE>
void FloatToFusedNBitRowwiseQuantizedSBHalfAvx2(
const float* input,
int input_rows,
int input_columns,
std::uint8_t* output);

} // namespace fbgemm
91 changes: 91 additions & 0 deletions src/QuantUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

#include "fbgemm/Fbgemm.h"

#include "fbgemm/Types.h"

namespace fbgemm {

using namespace std;
Expand Down Expand Up @@ -464,4 +466,93 @@ FBGEMM_API void RequantizeFixedPoint<uint8_t>(
}
}

/**
 * Reference (scalar) implementation of FloatToFusedNBitRowwiseQuantizedSBHalf.
 *
 * Output row layout: ceil(input_columns / (8 / bit_rate)) bytes of packed
 * quantized values, followed by the row's scale and bias stored as fp16.
 * Each element x is quantized to round((x - min) * inverse_scale) clamped
 * to [0, 2^bit_rate - 1], with scale = (max - min) / (2^bit_rate - 1).
 *
 * @param bit_rate bits per quantized element (e.g. 2, 4, or 8)
 * @param input pointer to input_rows * input_columns floats (row-major)
 * @param output buffer of input_rows * output_columns bytes (see layout)
 */
void FloatToFusedNBitRowwiseQuantizedSBHalfRef(
    int bit_rate,
    const float* input,
    int input_rows,
    int input_columns,
    std::uint8_t* output) {
  const int num_elem_per_byte = 8 / bit_rate;
  const int output_columns =
      (input_columns + num_elem_per_byte - 1) / num_elem_per_byte +
      2 * sizeof(float16);
  // int indices match the int parameters; the previous std::size_t indices
  // triggered signed/unsigned comparison against the int bounds.
  for (int row = 0; row < input_rows; ++row) {
    const float* input_row = input + row * input_columns;
    std::uint8_t* output_row = output + row * output_columns;
    // The fp16 scale/bias pair lives immediately after this row's packed
    // payload.
    float16* output_row_scale_bias = reinterpret_cast<float16*>(
        output_row +
        (input_columns + num_elem_per_byte - 1) / num_elem_per_byte);

    float minimum_element =
        *std::min_element(input_row, input_row + input_columns);
    const float maximum_element =
        *std::max_element(input_row, input_row + input_columns);

    // Round the bias through fp16 first so the stored fp16 bias is exactly
    // the value used during quantization below.
    const float16 minimum_element_fp16 = cpu_float2half_rn(minimum_element);
    minimum_element = cpu_half2float(minimum_element_fp16);
    const float range = maximum_element - minimum_element;

    float scale = range == 0 ? 1.0f : range / ((1 << bit_rate) - 1);
    // Round the scale through fp16 for the same reason as the bias.
    const float16 scale_fp16 = cpu_float2half_rn(scale);
    scale = cpu_half2float(scale_fp16);
    if (scale == 0) {
      // Corner case handling when maximum_element == minimum_element
      // Any scale would work because X - minimum_element will be 0 for all X
      scale = 1.0f;
    }
    float inverse_scale = 1.0f / scale;
    if (std::isinf(inverse_scale)) {
      // scale was tiny enough (fp16 subnormal) that 1/scale overflowed.
      scale = 1.0f;
      inverse_scale = 1.0f;
    }

    output_row_scale_bias[0] = cpu_float2half_rn(scale);
    output_row_scale_bias[1] = minimum_element_fp16;
    for (int col = 0; col < input_columns; ++col) {
      const float X = input_row[col];
      const std::uint8_t quantized = std::max(
          0,
          std::min<int>(
              std::lrintf((X - minimum_element) * inverse_scale),
              (1 << bit_rate) - 1));
      if (col % num_elem_per_byte == 0) {
        // First element in the byte: plain assign, which also clears any
        // stale bits from a previous use of the output buffer.
        output_row[col / num_elem_per_byte] = quantized;
      } else {
        output_row[col / num_elem_per_byte] |=
            (quantized << ((col % num_elem_per_byte) * bit_rate));
      }
    }
  }
}

/**
 * Quantize a float matrix to fused N-bit rows with fp16 scale/bias.
 * Uses the AVX2 kernel for bit rates 2, 4, and 8 when the host CPU supports
 * AVX2; every other combination falls back to the reference implementation.
 */
void FloatToFusedNBitRowwiseQuantizedSBHalf(
    int bit_rate,
    const float* input,
    int input_rows,
    int input_columns,
    std::uint8_t* output) {
  const bool has_avx2 = cpuinfo_initialize() && fbgemmHasAvx2Support();
  if (has_avx2 && bit_rate == 2) {
    FloatToFusedNBitRowwiseQuantizedSBHalfAvx2<2>(
        input, input_rows, input_columns, output);
  } else if (has_avx2 && bit_rate == 4) {
    FloatToFusedNBitRowwiseQuantizedSBHalfAvx2<4>(
        input, input_rows, input_columns, output);
  } else if (has_avx2 && bit_rate == 8) {
    FloatToFusedNBitRowwiseQuantizedSBHalfAvx2<8>(
        input, input_rows, input_columns, output);
  } else {
    FloatToFusedNBitRowwiseQuantizedSBHalfRef(
        bit_rate, input, input_rows, input_columns, output);
  }
}

} // namespace fbgemm
Loading

0 comments on commit a9397bf

Please sign in to comment.