Back out "have bf16 quantize kernels use 2d blocks with loop to cover large tensors" (pytorch#572)

Summary:
Pull Request resolved: pytorch#572

Original commit changeset: dfc437200807

Reviewed By: yinbinm, xw285cornell

Differential Revision: D27256808

fbshipit-source-id: c32197e340c4602be31f5915134ec8e17996d52e
jianyuh authored and facebook-github-bot committed Mar 24, 2021
1 parent 1d7eaf3 · commit fc2c734
Showing 1 changed file with 21 additions and 30 deletions.

fbgemm_gpu/include/fbgemm_gpu/quantize_ops.cuh (21 additions, 30 deletions)
@@ -144,22 +144,19 @@ __global__ void _fake_8bit_quantize_cuda_kernel(
   constexpr float kEpsilon = 1e-8f;
   int row = (int)blockIdx.x * blockDim.x + threadIdx.x;
   int col = (int)blockIdx.y * blockDim.y + threadIdx.y;
-  const int row_incre = blockDim.y * gridDim.y;
-  const int col_incre = blockDim.x * gridDim.x;
-  for (/*row*/; row < nrows; row += row_incre) {
+
+  if (row < nrows && col < ncols) {
     const float* input_row = input + row * ncols;
     float* output_row = output + row * ncols;
-    for (/*col*/; col < ncols; col += col_incre) {
-      float minimum_element =
-          *thrust::min_element(thrust::device, input_row, input_row + ncols);
-      float maximum_element =
-          *thrust::max_element(thrust::device, input_row, input_row + ncols);
-      float range = maximum_element - minimum_element;
-      const auto inverse_scale = 255.0f / (range + kEpsilon);
-      std::uint8_t quantized_val =
-          std::lrintf((input_row[col] - minimum_element) * inverse_scale);
-      output_row[col] = quantized_val * (range / 255.0f) + minimum_element;
-    }
+    float minimum_element =
+        *thrust::min_element(thrust::device, input_row, input_row + ncols);
+    float maximum_element =
+        *thrust::max_element(thrust::device, input_row, input_row + ncols);
+    float range = maximum_element - minimum_element;
+    const auto inverse_scale = 255.0f / (range + kEpsilon);
+    std::uint8_t quantized_val =
+        std::lrintf((input_row[col] - minimum_element) * inverse_scale);
+    output_row[col] = quantized_val * (range / 255.0f) + minimum_element;
   }
 }
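For context on what the restored branch computes: each element is quantized against its row's min/max and immediately dequantized, so the kernel emulates uint8 rowwise quantization error in float output. A minimal host-side sketch of that round trip (plain C++ reference compiled fine by nvcc; fake_quantize_row is a hypothetical helper, not the FBGEMM API):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Reference for one row: quantize each value against the row's [min, max]
// to uint8, then dequantize back, as the restored kernel branch does.
void fake_quantize_row(const float* row, float* out, int ncols) {
  constexpr float kEpsilon = 1e-8f;  // guards the zero-range case
  const float mn = *std::min_element(row, row + ncols);
  const float mx = *std::max_element(row, row + ncols);
  const float range = mx - mn;
  const float inverse_scale = 255.0f / (range + kEpsilon);
  for (int c = 0; c < ncols; ++c) {
    const std::uint8_t q =
        static_cast<std::uint8_t>(std::lrintf((row[c] - mn) * inverse_scale));
    out[c] = q * (range / 255.0f) + mn;  // dequantize: error is now visible
  }
}

int main() {
  const float row[4] = {-1.0f, 0.0f, 0.4f, 1.0f};
  float out[4];
  fake_quantize_row(row, out, 4);
  for (float v : out) std::printf("%f\n", v);  // values snapped to 255 levels
}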

@@ -175,8 +172,8 @@ __global__ void _float_to_fusednbitrowwise_cuda_kernel(
       (ncols + num_elem_per_byte - 1) / num_elem_per_byte + 2 * sizeof(__half);

   int row = (int)blockIdx.x * blockDim.x + threadIdx.x;
-  const int row_incre = blockDim.x * gridDim.x;
-  for (/*row*/; row < nrows; row += row_incre) {
+
+  if (row < nrows) {
     const float* input_row = input + row * ncols;
     std::uint8_t* output_row = output + row * output_columns;
     __half* output_row_scale_bias = reinterpret_cast<__half*>(
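The fused n-bit rowwise format visible in this hunk's context lines packs each row's n-bit elements and then appends a __half scale and a __half bias, which is where the 2 * sizeof(__half) in output_columns comes from. A small worked computation of the row width (standalone sketch; bit_rate and the example values are illustrative, not taken from this commit):

#include <cstdio>

// Bytes per output row: packed n-bit elements, then fp16 scale + fp16 bias.
int fused_row_bytes(int ncols, int bit_rate) {
  const int num_elem_per_byte = 8 / bit_rate;
  const int half_bytes = 2;  // sizeof(__half)
  return (ncols + num_elem_per_byte - 1) / num_elem_per_byte + 2 * half_bytes;
}

int main() {
  // 128 columns at 4 bits each: 64 data bytes + 4 scale/bias bytes = 68.
  std::printf("%d\n", fused_row_bytes(128, 4));
}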
@@ -262,14 +259,12 @@ __global__ void _float_to_bfloat16_cuda_kernel(
     const int nrows,
     const int ncols,
     uint16_t* __restrict__ output) {
-  int row = (int)blockIdx.y * blockDim.y + threadIdx.y;
-  int col = (int)blockIdx.x * blockDim.x + threadIdx.x;
-  const int row_incre = blockDim.y * gridDim.y;
-  const int col_incre = blockDim.x * gridDim.x;
-  for (/*row*/; row < nrows; row += row_incre) {
+  int row = (int)blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (row < nrows) {
     const float* input_row = input + row * ncols;
     uint16_t* output_row = output + row * ncols;
-    for (/*col*/; col < ncols; col += col_incre) {
+    for (std::size_t col = 0; col < ncols; ++col) {
       // Add 2^15 and right shift 16 to do round-nearest
       output_row[col] =
           (*reinterpret_cast<const uint32_t*>(input_row + col) + (1 << 15)) >>
@@ -285,11 +280,10 @@ __global__ void _bfloat16_to_float_cuda_kernel(
     const int ncols,
     float* __restrict__ output) {
   int row = (int)blockIdx.y * blockDim.y + threadIdx.y;
-  int col = (int)blockIdx.x * blockDim.x + threadIdx.x;
+  const int col = (int)blockIdx.x * blockDim.x + threadIdx.x;
   const int row_incre = blockDim.y * gridDim.y;
-  const int col_incre = blockDim.x * gridDim.x;
   for (/*row*/; row < nrows; row += row_incre) {
-    for (/*col*/; col < ncols; col += col_incre) {
+    if (col < ncols) {
       const uint16_t* input_row = input + row * ncols;
       float* output_row = output + row * ncols;
       uint32_t val_fp32 = static_cast<uint32_t>(
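Both bfloat16 kernels rely on bfloat16 being the high 16 bits of an IEEE-754 float32: float-to-bf16 adds 2^15 before truncating so the result rounds to nearest rather than toward zero, and bf16-to-float just zero-extends the discarded mantissa bits. A scalar model of the same bit tricks (host-side sketch; memcpy stands in for the kernels' reinterpret_cast to keep the type punning well defined):

#include <cstdint>
#include <cstdio>
#include <cstring>

std::uint16_t float_to_bfloat16(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  // Add 2^15 and right shift 16 to do round-nearest, as in the kernel.
  return static_cast<std::uint16_t>((bits + (1u << 15)) >> 16);
}

float bfloat16_to_float(std::uint16_t b) {
  const std::uint32_t bits = static_cast<std::uint32_t>(b) << 16;
  float x;
  std::memcpy(&x, &bits, sizeof(x));
  return x;  // exact: every bf16 value is representable in float32
}

int main() {
  const float x = 1.2345f;
  const float r = bfloat16_to_float(float_to_bfloat16(x));
  std::printf("%.6f -> %.6f\n", x, r);  // round trip drops low mantissa bits
}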
@@ -408,11 +402,8 @@ __global__ void _float_to_hfp8_cuda_kernel(
     float max_pos) {
   int row = (int)blockIdx.y * blockDim.y + threadIdx.y;
   int col = (int)blockIdx.x * blockDim.x + threadIdx.x;
-  const int row_incre = blockDim.y * gridDim.y;
-  const int col_incre = blockDim.x * gridDim.x;
-
-  for (int i = row; i < nrows; i += row_incre) {
-    for (int j = col; j < ncols; j += col_incre) {
+  for (int i = row; i < nrows; i += blockDim.y) {
+    for (int j = col; j < ncols; j += blockDim.x) {
       output[i * ncols + j] = float_to_hfp8(
           input[i * ncols + j], ebits, mbits, bias, min_pos, max_pos);
     }
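This last hunk shows most directly what the backout undoes: the reverted change strode by row_incre = blockDim.y * gridDim.y and col_incre = blockDim.x * gridDim.x, the grid-stride pattern whose point is that a single launch covers tensors larger than the grid, while the restored code strides by blockDim.y and blockDim.x alone. A minimal self-contained example of the grid-stride pattern named in the commit title (illustrative scale_2d kernel, not the FBGEMM source):

#include <cstdio>

// 2-D grid-stride loop: each thread hops by the whole grid extent, so any
// nrows x ncols tensor is covered regardless of how small the launch is.
__global__ void scale_2d(const float* in, float* out, int nrows, int ncols) {
  const int row0 = blockIdx.y * blockDim.y + threadIdx.y;
  const int col0 = blockIdx.x * blockDim.x + threadIdx.x;
  const int row_incre = blockDim.y * gridDim.y;  // full grid height
  const int col_incre = blockDim.x * gridDim.x;  // full grid width
  for (int i = row0; i < nrows; i += row_incre) {
    for (int j = col0; j < ncols; j += col_incre) {
      out[i * ncols + j] = 2.0f * in[i * ncols + j];
    }
  }
}

int main() {
  const int nrows = 1000, ncols = 777;  // deliberately not grid-aligned
  float *in, *out;
  cudaMallocManaged(&in, nrows * ncols * sizeof(float));
  cudaMallocManaged(&out, nrows * ncols * sizeof(float));
  for (int i = 0; i < nrows * ncols; ++i) in[i] = 1.0f;
  dim3 block(32, 8), grid(4, 4);  // far fewer threads than elements
  scale_2d<<<grid, block>>>(in, out, nrows, ncols);
  cudaDeviceSynchronize();
  std::printf("out[last] = %f\n", out[nrows * ncols - 1]);  // expect 2.0
  cudaFree(in);
  cudaFree(out);
  return 0;
}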
