Add quantize and dequantize functions between rowwise-quantized int8 + fp32 scale/bias and fp32 embeddings. (pytorch#623)

Summary:
Pull Request resolved: pytorch#623

Templatize FloatToFused8BitRowwiseQuantizedSBFloat and Fused8BitRowwiseQuantizedSBFloatToFloat, in both the optimized (opt) and reference (ref) versions.
Similar to D28620537 (pytorch@9cb33bc) and D28875981 (pytorch@77a4792).

Reviewed By: dskhudia

Differential Revision: D28918591

fbshipit-source-id: c70a1552d3d01648e848c710d16885ad5b1b5e47
caogao authored and facebook-github-bot committed Jun 10, 2021
1 parent 0520ad5 commit 3182125
Showing 5 changed files with 295 additions and 77 deletions.
49 changes: 41 additions & 8 deletions include/fbgemm/QuantUtils.h
@@ -326,29 +326,60 @@ FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf(
* This version intentionally supports only 8-bit because we want to discourage
* the usage of float scale and bias with 2 and 4 bit cases as that diminishes
* the overall memory savings.
*
* TODO(T91361248): deprecate and replace with FloatOrHalfToFused8BitRowwiseQuantizedSBFloat.
*/
FBGEMM_API void FloatToFused8BitRowwiseQuantizedSBFloat(
const float* input,
int input_rows,
int input_columns,
std::uint8_t* output);

/**
* Convert float or half inputs to rowwise quantized (8-bit) outputs.
* Scale and Bias are in float. Each row's Scale and Bias are stored in
* the row itself (fused) at the end.
*
* This version intentionally supports only 8-bit because we want to discourage
* the usage of float scale and bias with 2 and 4 bit cases as that diminishes
* the overall memory savings.
*/
template <typename InputType>
FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat(
const InputType* input,
int input_rows,
int input_columns,
std::uint8_t* output);
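For illustration only (a hypothetical sketch, not part of this diff): each fused output row stores input_columns uint8 codes followed by the fp32 scale and fp32 bias, so the caller sizes the output buffer at input_columns + 2 * sizeof(float) bytes per row. A minimal call might look like:

#include "fbgemm/QuantUtils.h" // the header shown in this diff
#include <cstdint>
#include <vector>

int main() {
  const int rows = 2, cols = 4;
  std::vector<float> input = {0.f, 1.f, 2.f, 3.f,
                              -1.f, 0.f, 1.f, 2.f};
  // Each fused row: cols uint8 codes, then a 4-byte scale, then a 4-byte bias.
  std::vector<std::uint8_t> fused(rows * (cols + 2 * sizeof(float)));
  fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<float>(
      input.data(), rows, cols, fused.data());
}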

/**
* Convert fused rowwise quantized (8-bit) inputs to float outputs.
* Scale and Bias are in float. Each row's Scale and Bias are stored in
* the row itself (fused) at the end.
*
* This version intentionally supports only 8-bit because
* the corresponding quantize version only supports 8-bit.
*
* TODO(T91361248): deprecate and replace with Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf.
*/
FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloat(
const uint8_t* input,
int input_rows,
int input_columns,
float* output);

/**
* Convert fused rowwise quantized (8-bit) inputs to float or half outputs.
* Scale and Bias are in float. Each row's Scale and Bias are stored in
* the row itself (fused) at the end.
*
* This version intentionally supports only 8-bit because
* the corresponding quantize version only supports 8-bit.
*/
template <typename OutputType>
FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
const uint8_t* input,
int input_rows,
int input_columns,
OutputType* output);
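A matching dequantize sketch (again hypothetical, continuing inside the main() above); note that input_columns here is the fused width, i.e. the original column count plus the 8 scale/bias bytes:

  // Recover approximate fp32 values from the fused buffer.
  const int fused_cols = cols + 2 * sizeof(float);
  std::vector<float> recovered(rows * cols);
  fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<float>(
      fused.data(), rows, fused_cols, recovered.data());
  // recovered[i] is approximately input[i]; the per-element error
  // is at most scale / 2.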

/**
* Same as ToFusedNBitRowwiseQuantizedSBHalf but unoptimized.
* This should not be called directly except in testing.
@@ -362,11 +393,12 @@ FBGEMM_API void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef(
std::uint8_t* output);

/**
* Same as FloatToFused8BitRowwiseQuantizedSBFloat but unoptimized.
* Same as FloatOrHalfToFused8BitRowwiseQuantizedSBFloat but unoptimized.
* This should not be called directly except in testing.
*/
FBGEMM_API void FloatToFused8BitRowwiseQuantizedSBFloatRef(
const float* input,
template <typename InputType>
FBGEMM_API void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef(
const InputType* input,
int input_rows,
int input_columns,
std::uint8_t* output);
Expand All @@ -384,13 +416,14 @@ FBGEMM_API void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
OutputType* output);

/**
* Same as Fused8BitRowwiseQuantizedSBFloatToFloat but unoptimized.
* Same as Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf but unoptimized.
* This should not be called directly except in testing.
*/
FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatRef(
template <typename OutputType>
FBGEMM_API void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef(
const uint8_t* input,
int input_rows,
int input_columns,
float* output);
OutputType* output);

} // namespace fbgemm
10 changes: 6 additions & 4 deletions include/fbgemm/QuantUtilsAvx2.h
@@ -132,8 +132,9 @@ void FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfAvx2(
int input_columns,
std::uint8_t* output);

void FloatToFused8BitRowwiseQuantizedSBFloatAvx2(
const float* input,
template <typename InputType>
void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2(
const InputType* input,
int input_rows,
int input_columns,
std::uint8_t* output);
@@ -145,10 +146,11 @@ void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfAvx2(
int input_columns,
OutputType* output);

void Fused8BitRowwiseQuantizedSBFloatToFloatAvx2(
template <typename OutputType>
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2(
const std::uint8_t* input,
int input_rows,
int input_columns,
float* output);
OutputType* output);

} // namespace fbgemm
92 changes: 73 additions & 19 deletions src/QuantUtils.cc
@@ -619,50 +619,70 @@ void FloatToFusedNBitRowwiseQuantizedSBHalf(
bit_rate, input, input_rows, input_columns, output);
}

void FloatToFused8BitRowwiseQuantizedSBFloatRef(
const float* input,
template <typename InputType>
void FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef(
const InputType* input,
int input_rows,
int input_columns,
std::uint8_t* output) {
constexpr float kEpsilon = 1e-8f;

int output_columns = input_columns + 2 * sizeof(float);
std::vector<float> input_row_float(input_columns);
for (std::size_t row = 0; row < input_rows; ++row) {
const float* input_row = input + row * input_columns;
const InputType* input_row = input + row * input_columns;
std::uint8_t* output_row = output + row * output_columns;
float* output_row_scale_bias =
reinterpret_cast<float*>(output_row + input_columns);

for (std::size_t col = 0; col < input_columns; ++col) {
if (std::is_same<InputType, float>()) {
input_row_float[col] = input_row[col];
} else {
input_row_float[col] = cpu_half2float(input_row[col]);
}
}

float minimum_element =
*std::min_element(input_row, input_row + input_columns);
*std::min_element(input_row_float.begin(), input_row_float.end());
float maximum_element =
*std::max_element(input_row, input_row + input_columns);
*std::max_element(input_row_float.begin(), input_row_float.end());
float range = maximum_element - minimum_element;

output_row_scale_bias[0] = range / 255.0f;
output_row_scale_bias[1] = minimum_element;
const auto inverse_scale = 255.0f / (range + kEpsilon);
for (std::size_t col = 0; col < input_columns; ++col) {
output_row[col] =
std::lrintf((input_row[col] - minimum_element) * inverse_scale);
std::lrintf((input_row_float[col] - minimum_element) * inverse_scale);
}
}
}
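A worked pass through this reference math, with illustrative values: for the row [0, 1, 2, 3], minimum_element = 0 and range = 3, so the stored scale is 3 / 255 ≈ 0.01176 and the bias is 0; inverse_scale = 255 / (3 + 1e-8) ≈ 85, and the quantized bytes come out as lrintf(0 * 85) = 0, lrintf(1 * 85) = 85, lrintf(2 * 85) = 170, and lrintf(3 * 85) = 255. The kEpsilon term only matters for constant rows, where it keeps inverse_scale finite.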

void FloatToFused8BitRowwiseQuantizedSBFloat(
const float* input,
template <typename InputType>
void FloatOrHalfToFused8BitRowwiseQuantizedSBFloat(
const InputType* input,
int input_rows,
int input_columns,
std::uint8_t* output) {
if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
FloatToFused8BitRowwiseQuantizedSBFloatAvx2(
FloatOrHalfToFused8BitRowwiseQuantizedSBFloatAvx2<InputType>(
input, input_rows, input_columns, output);
} else {
FloatToFused8BitRowwiseQuantizedSBFloatRef(
FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef<InputType>(
input, input_rows, input_columns, output);
}
}
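Design note: this public entry point dispatches at runtime, taking the AVX2 path only when cpuinfo reports support, so a single binary serves both newer and older x86 CPUs; the scalar Ref path above doubles as the correctness baseline for testing.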

void FloatToFused8BitRowwiseQuantizedSBFloat(
const float* input,
int input_rows,
int input_columns,
std::uint8_t* output) {
FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<float>(
input, input_rows, input_columns, output);
}
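The pre-existing float-only entry point is kept as a thin forwarding wrapper, so existing callers continue to compile until the TODO(T91361248) deprecation lands.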

template <typename OutputType>
void FusedNBitRowwiseQuantizedSBHalfToFloatOrHalfRef(
int bit_rate,
@@ -741,40 +761,56 @@ void FusedNBitRowwiseQuantizedSBHalfToFloat(
bit_rate, input, input_rows, input_columns, output);
}

void Fused8BitRowwiseQuantizedSBFloatToFloatRef(
template <typename OutputType>
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef(
const std::uint8_t* input,
int input_rows,
int input_columns,
float* output) {
OutputType* output) {
int output_columns = input_columns - 2 * sizeof(float);

for (std::size_t row = 0; row < input_rows; ++row) {
const std::uint8_t* input_row = input + row * input_columns;
const float* input_row_scale_bias =
reinterpret_cast<const float*>(input_row + output_columns);
float* output_row = output + row * output_columns;
OutputType* output_row = output + row * output_columns;

for (std::size_t col = 0; col < output_columns; ++col) {
output_row[col] =
float output_value =
input_row[col] * input_row_scale_bias[0] + input_row_scale_bias[1];
if (std::is_same<OutputType, float>()) {
output_row[col] = output_value;
} else {
output_row[col] = cpu_float2half_rn(output_value);
}
}
}
}
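Continuing the worked example from the quantize side: byte 85 with scale 3 / 255 and bias 0 dequantizes to 85 * (3 / 255) + 0 = 1.0 (up to float rounding), and in general the reconstruction error per element is bounded by half a quantization step, i.e. range / 510.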

void Fused8BitRowwiseQuantizedSBFloatToFloat(
template <typename OutputType>
void Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf(
const std::uint8_t* input,
int input_rows,
int input_columns,
float* output) {
OutputType* output) {
if (cpuinfo_initialize() && fbgemmHasAvx2Support()) {
Fused8BitRowwiseQuantizedSBFloatToFloatAvx2(
Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfAvx2<OutputType>(
input, input_rows, input_columns, output);
} else {
Fused8BitRowwiseQuantizedSBFloatToFloatRef(
Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<OutputType>(
input, input_rows, input_columns, output);
}
}

void Fused8BitRowwiseQuantizedSBFloatToFloat(
const uint8_t* input,
int input_rows,
int input_columns,
float* output) {
Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<float>(
input, input_rows, input_columns, output);
}

#define INSTANTIATE_QuantizationFunctions(type) \
template FBGEMM_API void \
FloatOrHalfToFusedNBitRowwiseQuantizedSBHalfRef<type>( \
@@ -801,7 +837,25 @@ void Fused8BitRowwiseQuantizedSBFloatToFloat(
const uint8_t* input, \
int input_rows, \
int input_columns, \
type* output);
type* output); \
template FBGEMM_API void \
FloatOrHalfToFused8BitRowwiseQuantizedSBFloatRef<type>( \
const type* input, \
int input_rows, \
int input_columns, \
std::uint8_t* output); \
template FBGEMM_API void \
FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<type>( \
const type* input, \
int input_rows, \
int input_columns, \
std::uint8_t* output); \
template FBGEMM_API void \
Fused8BitRowwiseQuantizedSBFloatToFloatOrHalfRef<type>( \
const uint8_t* input, int input_rows, int input_columns, type* output); \
template FBGEMM_API void \
Fused8BitRowwiseQuantizedSBFloatToFloatOrHalf<type>( \
const uint8_t* input, int input_rows, int input_columns, type* output);

// clang-format off
INSTANTIATE_QuantizationFunctions(float)
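Because the template definitions live in this .cc file rather than the header, each supported type must be explicitly instantiated (and exported via FBGEMM_API) before callers can link against it; the macro is invoked for float here and, per the float16 support added throughout this diff, presumably for fbgemm::float16 as well. With that instantiation in place, a hypothetical half-precision call would be:

// Hypothetical: half_input is a const fbgemm::float16* prepared by the caller.
fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<fbgemm::float16>(
    half_input, rows, cols, fused.data());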