Re-apply "[fbgemm_gpu] have bf16 quantize kernels use 2d blocks with …
Browse files Browse the repository at this point in the history
…loop to cover large tensors" (pytorch#583)

Summary:
Pull Request resolved: pytorch#583

Original commit changeset: c32197e340c4

Difference with the original version D27002397 (pytorch@2b5c02d):
D27389636

Reviewed By: jspark1105, yinbinm

Differential Revision: D27389576

fbshipit-source-id: c320e0090da1d4aa6539b604f7b986c5265381ee
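For context, the pattern this commit re-applies replaces a single bounds-checked thread index (e.g. `if (row < nrows && col < ncols)`) with 2D grid-stride loops: each thread starts at its (row, col) coordinate and strides by the total thread count in each dimension, so a fixed-size grid can cover tensors larger than the launch. A minimal standalone sketch of that pattern, using a hypothetical copy kernel rather than code from quantize_ops.cuh:

```cuda
#include <cuda_runtime.h>

// Hypothetical element-wise kernel illustrating the 2D grid-stride pattern:
// each thread handles the element at its (row, col) coordinate, then strides
// by the total number of launched threads in that dimension, so nrows/ncols
// may exceed what a single grid pass covers.
__global__ void rowwise_copy_kernel(
    const float* __restrict__ input,
    int nrows,
    int ncols,
    float* __restrict__ output) {
  const int row_incre = blockDim.y * gridDim.y;
  const int col_incre = blockDim.x * gridDim.x;
  for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows;
       row += row_incre) {
    for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols;
         col += col_incre) {
      output[row * ncols + col] = input[row * ncols + col];
    }
  }
}
```

Each kernel in the diff below adopts this loop structure while keeping its existing per-element body.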
jianyuh authored and facebook-github-bot committed Mar 31, 2021
1 parent e839120 commit 5a8d11f
Showing 1 changed file with 38 additions and 29 deletions.
67 changes: 38 additions & 29 deletions fbgemm_gpu/include/fbgemm_gpu/quantize_ops.cuh
@@ -142,21 +142,24 @@ __global__ void _fake_8bit_quantize_cuda_kernel(
     int ncols,
     float* __restrict__ output) {
   constexpr float kEpsilon = 1e-8f;
-  int row = (int)blockIdx.x * blockDim.x + threadIdx.x;
-  int col = (int)blockIdx.y * blockDim.y + threadIdx.y;
-
-  if (row < nrows && col < ncols) {
+  const int row_incre = blockDim.y * gridDim.y;
+  for (int row = blockIdx.x * blockDim.x + threadIdx.x; row < nrows;
+       row += row_incre) {
     const float* input_row = input + row * ncols;
     float* output_row = output + row * ncols;
-    float minimum_element =
-        *thrust::min_element(thrust::device, input_row, input_row + ncols);
-    float maximum_element =
-        *thrust::max_element(thrust::device, input_row, input_row + ncols);
-    float range = maximum_element - minimum_element;
-    const auto inverse_scale = 255.0f / (range + kEpsilon);
-    std::uint8_t quantized_val =
-        std::lrintf((input_row[col] - minimum_element) * inverse_scale);
-    output_row[col] = quantized_val * (range / 255.0f) + minimum_element;
+    const int col_incre = blockDim.x * gridDim.x;
+    for (int col = blockIdx.y * blockDim.y + threadIdx.y; col < ncols;
+         col += col_incre) {
+      float minimum_element =
+          *thrust::min_element(thrust::device, input_row, input_row + ncols);
+      float maximum_element =
+          *thrust::max_element(thrust::device, input_row, input_row + ncols);
+      float range = maximum_element - minimum_element;
+      const auto inverse_scale = 255.0f / (range + kEpsilon);
+      std::uint8_t quantized_val =
+          std::lrintf((input_row[col] - minimum_element) * inverse_scale);
+      output_row[col] = quantized_val * (range / 255.0f) + minimum_element;
+    }
   }
 }

@@ -172,8 +175,8 @@ __global__ void _float_to_fusednbitrowwise_cuda_kernel(
       (ncols + num_elem_per_byte - 1) / num_elem_per_byte + 2 * sizeof(__half);
 
   int row = (int)blockIdx.x * blockDim.x + threadIdx.x;
-
-  if (row < nrows) {
+  const int row_incre = blockDim.x * gridDim.x;
+  for (/*row*/; row < nrows; row += row_incre) {
     const float* input_row = input + row * ncols;
     std::uint8_t* output_row = output + row * output_columns;
     __half* output_row_scale_bias = reinterpret_cast<__half*>(
@@ -259,12 +262,14 @@ __global__ void _float_to_bfloat16_cuda_kernel(
     const int nrows,
     const int ncols,
     uint16_t* __restrict__ output) {
-  int row = (int)blockIdx.x * blockDim.x + threadIdx.x;
-
-  if (row < nrows) {
+  const int row_incre = blockDim.y * gridDim.y;
+  const int col_incre = blockDim.x * gridDim.x;
+  for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows;
+       row += row_incre) {
     const float* input_row = input + row * ncols;
     uint16_t* output_row = output + row * ncols;
-    for (std::size_t col = 0; col < ncols; ++col) {
+    for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols;
+         col += col_incre) {
       // Add 2^15 and right shift 16 to do round-nearest
       output_row[col] =
           (*reinterpret_cast<const uint32_t*>(input_row + col) + (1 << 15)) >>
@@ -279,11 +284,12 @@ __global__ void _bfloat16_to_float_cuda_kernel(
     const int nrows,
     const int ncols,
     float* __restrict__ output) {
-  int row = (int)blockIdx.y * blockDim.y + threadIdx.y;
-  const int col = (int)blockIdx.x * blockDim.x + threadIdx.x;
   const int row_incre = blockDim.y * gridDim.y;
-  for (/*row*/; row < nrows; row += row_incre) {
-    if (col < ncols) {
+  const int col_incre = blockDim.x * gridDim.x;
+  for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows;
+       row += row_incre) {
+    for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols;
+         col += col_incre) {
       const uint16_t* input_row = input + row * ncols;
       float* output_row = output + row * ncols;
       uint32_t val_fp32 = static_cast<uint32_t>(
@@ -400,12 +406,15 @@ __global__ void _float_to_hfp8_cuda_kernel(
     int bias,
     float min_pos,
     float max_pos) {
-  int row = (int)blockIdx.y * blockDim.y + threadIdx.y;
-  int col = (int)blockIdx.x * blockDim.x + threadIdx.x;
-  for (int i = row; i < nrows; i += blockDim.y) {
-    for (int j = col; j < ncols; j += blockDim.x) {
-      output[i * ncols + j] = float_to_hfp8(
-          input[i * ncols + j], ebits, mbits, bias, min_pos, max_pos);
+  const int row_incre = blockDim.y * gridDim.y;
+  const int col_incre = blockDim.x * gridDim.x;
+
+  for (int row = blockIdx.y * blockDim.y + threadIdx.y; row < nrows;
+       row += row_incre) {
+    for (int col = blockIdx.x * blockDim.x + threadIdx.x; col < ncols;
+         col += col_incre) {
+      output[row * ncols + col] = float_to_hfp8(
+          input[row * ncols + col], ebits, mbits, bias, min_pos, max_pos);
     }
   }
 }
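On the host side, a kernel written this way can be launched with a capped grid; the loops then sweep any rows or columns the grid does not reach in one pass. A hedged sketch of such a launch — the block shape, helper name, and the assumption that the kernel's first parameter is the input pointer are illustrative, not taken from this commit:

```cuda
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdint>

// Assumed signature, mirroring the kernels in this diff; the real declaration
// lives in fbgemm_gpu/include/fbgemm_gpu/quantize_ops.cuh.
__global__ void _float_to_bfloat16_cuda_kernel(
    const float* __restrict__ input,
    const int nrows,
    const int ncols,
    uint16_t* __restrict__ output);

// Hypothetical host-side launch: the grid is capped (here at CUDA's 65535
// blocks per grid dimension), and the 2D grid-stride loops inside the kernel
// cover any rows/columns the grid does not reach in a single pass.
// Assumes nrows > 0 and ncols > 0.
void launch_float_to_bfloat16_sketch(
    const float* d_input, int nrows, int ncols, uint16_t* d_output) {
  const dim3 threads(32, 8);  // assumed block shape, not taken from this diff
  const dim3 blocks(
      std::min(65535, (ncols + (int)threads.x - 1) / (int)threads.x),
      std::min(65535, (nrows + (int)threads.y - 1) / (int)threads.y));
  _float_to_bfloat16_cuda_kernel<<<blocks, threads>>>(
      d_input, nrows, ncols, d_output);
}
```

Without the loop, a tensor with more rows than gridDim.y * blockDim.y threads would leave elements unprocessed or require an unbounded grid; with it, the capped grid is sufficient for any size.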
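The one non-obvious arithmetic step in the diff is the bfloat16 rounding in _float_to_bfloat16_cuda_kernel. A small host-side illustration of that line — the helper and test value are hypothetical, but the bit manipulation mirrors the kernel:

```cuda
#include <cstdint>
#include <cstring>
#include <cstdio>

// Host-side illustration of the rounding in _float_to_bfloat16_cuda_kernel:
// adding 2^15 to the fp32 bit pattern before dropping the low 16 bits rounds
// the retained bfloat16 mantissa to nearest (halfway cases round up).
static uint16_t float_to_bf16_bits(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));  // same reinterpretation the kernel does
  return static_cast<uint16_t>((bits + (1u << 15)) >> 16);
}

int main() {
  // 1.00390625f has fp32 bits 0x3F808000; plain truncation would give bf16
  // 0x3F80 (1.0), while round-nearest bumps it to 0x3F81 (1.0078125).
  std::printf("0x%04X\n", (unsigned)float_to_bf16_bits(1.00390625f));
  return 0;
}
```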
