make functions defined in quantize_ops.cuh inline (pytorch#612)
Summary:
Pull Request resolved: pytorch#612

It is dangerous to define non-inline functions in a header file. This happens to work right now only because a single compilation unit includes quantize_ops.cuh; as a follow-up diff shows, including the header from more than one compilation unit produces duplicate symbols at link time.
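
As a minimal sketch of the failure mode (the header and kernel names below are hypothetical, for illustration only): a kernel defined in a header needs the inline keyword so that the copies emitted by each including compilation unit can be merged at link time.

// scale_ops.cuh -- hypothetical header, not part of this diff
#pragma once

// Defined in a header: without inline, every .cu file that includes this
// header emits its own definition of the kernel, and linking two such
// compilation units fails with duplicate-symbol errors. Marking the
// definition inline lets the linker merge the copies.
__global__ inline void _scale_rows_cuda_kernel(
    const float* __restrict__ input,
    int nrows,
    int ncols,
    float factor,
    float* __restrict__ output) {
  int row = blockIdx.y * blockDim.y + threadIdx.y;
  int col = blockIdx.x * blockDim.x + threadIdx.x;
  if (row < nrows && col < ncols) {
    output[row * ncols + col] = factor * input[row * ncols + col];
  }
}

The existing quantize_ops_shfl_xor helper in this header is already declared __device__ inline; this diff applies the same treatment to the remaining kernels and device functions.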

Reviewed By: jianyuh

Differential Revision: D28632972

fbshipit-source-id: e0a51b7405ee72ba8dce0da8ba6fd92c9bb46ce6
jspark1105 authored and facebook-github-bot committed May 24, 2021
1 parent 5bbb70d commit dd3abd3
Showing 1 changed file with 13 additions and 13 deletions.
fbgemm_gpu/include/fbgemm_gpu/quantize_ops.cuh (13 additions, 13 deletions)
@@ -23,7 +23,7 @@ __device__ inline __attribute__((always_inline)) T quantize_ops_shfl_xor(const T
#endif
}

-__global__ void _get_8bit_qparam_cuda_kernel(
+__global__ inline void _get_8bit_qparam_cuda_kernel(
const float* __restrict__ input,
int nrows,
int ncols,
@@ -78,7 +78,7 @@ __global__ void _get_8bit_qparam_cuda_kernel(
range_list[row] = range;
}

-__global__ void _compute_8bit_quantize_cuda_kernel(
+__global__ inline void _compute_8bit_quantize_cuda_kernel(
const float* const __restrict__ input,
const float* const __restrict__ range_list,
const int nrows,
@@ -110,7 +110,7 @@ __global__ void _compute_8bit_quantize_cuda_kernel(
}

// FP32 -> Fused 8-bit rowwise kernel
-__global__ void _float_to_fused8bitrowwise_cuda_kernel(
+__global__ inline void _float_to_fused8bitrowwise_cuda_kernel(
const float* __restrict__ input,
int nrows,
int ncols,
@@ -145,7 +145,7 @@ __global__ void _float_to_fused8bitrowwise_cuda_kernel(
}

// Fused 8-bit rowwise -> FP32 kernel
-__global__ void _fused8bitrowwise_to_float_cuda_kernel(
+__global__ inline void _fused8bitrowwise_to_float_cuda_kernel(
const std::uint8_t* const __restrict__ input,
const int nrows,
const int ncols,
@@ -169,7 +169,7 @@ __global__ void _fused8bitrowwise_to_float_cuda_kernel(
}

// Fake 8-bit quantize kernel: FP32 -> UINT8 rowwise -> FP32
-__global__ void _fake_8bit_quantize_cuda_kernel(
+__global__ inline void _fake_8bit_quantize_cuda_kernel(
const float* __restrict__ input,
int nrows,
int ncols,
@@ -197,7 +197,7 @@ __global__ void _fake_8bit_quantize_cuda_kernel(
}

// FP32 -> Fused 4/2-bit rowwise kernel
-__global__ void _float_to_fusednbitrowwise_cuda_kernel(
+__global__ inline void _float_to_fusednbitrowwise_cuda_kernel(
int bit_rate,
const float* __restrict__ input,
int nrows,
@@ -259,7 +259,7 @@ __global__ void _float_to_fusednbitrowwise_cuda_kernel(
}

// Fused 4/2-bit rowwise -> FP32 kernel
-__global__ void _fusednbitrowwise_to_float_cuda_kernel(
+__global__ inline void _fusednbitrowwise_to_float_cuda_kernel(
const int bit_rate,
const std::uint8_t* input,
const int nrows,
@@ -290,7 +290,7 @@ __global__ void _fusednbitrowwise_to_float_cuda_kernel(
}

// FP32 -> BF16 kernel
-__global__ void _float_to_bfloat16_cuda_kernel(
+__global__ inline void _float_to_bfloat16_cuda_kernel(
const float* __restrict__ input,
const int nrows,
const int ncols,
@@ -312,7 +312,7 @@ __global__ void _float_to_bfloat16_cuda_kernel(
}

// BF16 -> FP32 kernel
-__global__ void _bfloat16_to_float_cuda_kernel(
+__global__ inline void _bfloat16_to_float_cuda_kernel(
const uint16_t* __restrict__ input,
const int nrows,
const int ncols,
@@ -340,7 +340,7 @@ typedef union {

// TODO: add a flag later to control whether underflow
// flushes to 0 or clips to smallest denorm number.
-__device__ uint8_t float_to_hfp8(
+__device__ inline uint8_t float_to_hfp8(
float val_fp,
int ebits,
int mbits,
@@ -399,7 +399,7 @@ __device__ uint8_t float_to_hfp8(
return bfp8_val;
}

-__device__ float
+__device__ inline float
hfp8_to_float(uint8_t hfp8_val, int ebits, int mbits, int bias) {
fint32 val_out, sign, multiplier;

@@ -429,7 +429,7 @@ hfp8_to_float(uint8_t hfp8_val, int ebits, int mbits, int bias) {
return val_out.F;
}

-__global__ void _float_to_hfp8_cuda_kernel(
+__global__ inline void _float_to_hfp8_cuda_kernel(
const float* __restrict__ input,
const int nrows,
const int ncols,
@@ -452,7 +452,7 @@ __global__ void _float_to_hfp8_cuda_kernel(
}
}

-__global__ void _hfp8_to_float_cuda_kernel(
+__global__ inline void _hfp8_to_float_cuda_kernel(
const uint8_t* __restrict__ input,
const int nrows,
const int ncols,