fix bf16&fp16 quantize to nf4&fp4 (PaddlePaddle#1805)
ceci3 authored Nov 9, 2023
1 parent e085e0d commit a08750c
Showing 4 changed files with 97 additions and 29 deletions.
24 changes: 24 additions & 0 deletions csrc/lc/common.h
@@ -28,6 +28,30 @@ typedef enum LC_DataType_t
 template <typename T, int DATA_TYPE> void quantize_blockwise(const float * code, const T *A, float *absmax, unsigned char *out, int blocksize, int n);
 template<typename T, int DATA_TYPE> void dequantize_blockwise(const float *code, const unsigned char *A, float *absmax, T *out, int block_size, int n);
 
+template <paddle::DataType D>
+class PDTraits;
+
+template <>
+class PDTraits<paddle::DataType::FLOAT32> {
+public:
+    typedef float DataType;
+    typedef float data_t;
+};
+
+template <>
+class PDTraits<paddle::DataType::FLOAT16> {
+public:
+    typedef half DataType;
+    typedef paddle::float16 data_t;
+};
+
+template <>
+class PDTraits<paddle::DataType::BFLOAT16> {
+public:
+    typedef __nv_bfloat16 DataType;
+    typedef paddle::bfloat16 data_t;
+};
+
 
 #define CUDA_CHECK_RETURN(value) { \
     cudaError_t _m_cudaStat = value; \
49 changes: 22 additions & 27 deletions csrc/lc/quantize_blockwise.cu
@@ -362,42 +362,37 @@ MAKE_kQuantizeBlockwise(__nv_bfloat16, 64, 2, NF4)



-template <typename T, int DATA_TYPE> void quantize_blockwise(const float *code, const T *A, float *absmax, unsigned char *out, int blocksize, int n)
+template <paddle::DataType D, int DATA_TYPE> void quantize_blockwise(const float *code, const paddle::Tensor& A, float *absmax, unsigned char *out, int blocksize, int n)
 {
+    typedef PDTraits<D> traits_;
+    typedef typename traits_::DataType DataType_;
+    typedef typename traits_::data_t data_t;
+
     int num_blocks = n/blocksize;
     num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
 
+    const DataType_* A_data = reinterpret_cast<const DataType_*>(A.data<data_t>());
     if(blocksize == 4096)
-        kQuantizeBlockwise<T, 4096, 4, 0><<<num_blocks, 1024>>>(code, A, absmax, out, n);
+        kQuantizeBlockwise<DataType_, 4096, 4, 0><<<num_blocks, 1024>>>(code, A_data, absmax, out, n);
     else if(blocksize == 2048)
-        kQuantizeBlockwise<T, 2048, 4, DATA_TYPE><<<num_blocks, 512>>>(code, A, absmax, out, n);
+        kQuantizeBlockwise<DataType_, 2048, 4, DATA_TYPE><<<num_blocks, 512>>>(code, A_data, absmax, out, n);
     else if(blocksize == 1024)
-        kQuantizeBlockwise<T, 1024, 4, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, n);
+        kQuantizeBlockwise<DataType_, 1024, 4, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax, out, n);
     else if(blocksize == 512)
-        kQuantizeBlockwise<T, 512, 2, DATA_TYPE><<<num_blocks, 256>>>(code, A, absmax, out, n);
+        kQuantizeBlockwise<DataType_, 512, 2, DATA_TYPE><<<num_blocks, 256>>>(code, A_data, absmax, out, n);
     else if(blocksize == 256)
-        kQuantizeBlockwise<T, 256, 2, DATA_TYPE><<<num_blocks, 128>>>(code, A, absmax, out, n);
+        kQuantizeBlockwise<DataType_, 256, 2, DATA_TYPE><<<num_blocks, 128>>>(code, A_data, absmax, out, n);
     else if(blocksize == 128)
-        kQuantizeBlockwise<T, 128, 2, DATA_TYPE><<<num_blocks, 64>>>(code, A, absmax, out, n);
+        kQuantizeBlockwise<DataType_, 128, 2, DATA_TYPE><<<num_blocks, 64>>>(code, A_data, absmax, out, n);
     else if(blocksize == 64)
-        kQuantizeBlockwise<T, 64, 2, DATA_TYPE><<<num_blocks, 32>>>(code, A, absmax, out, n);
+        kQuantizeBlockwise<DataType_, 64, 2, DATA_TYPE><<<num_blocks, 32>>>(code, A_data, absmax, out, n);
     else
         PD_THROW("only support blocksize is [64, 128, 256, 512, 1024, 2048, 4096].");
 
 
     CUDA_CHECK_RETURN(cudaPeekAtLastError());
 }
 
-template void quantize_blockwise<half, General8bit>(const float *code, const half *A, float *absmax, unsigned char *out, int blocksize, int n);
-template void quantize_blockwise<half, FP4>(const float *code, const half *A, float *absmax, unsigned char *out, int blocksize, int n);
-template void quantize_blockwise<half, NF4>(const float *code, const half *A, float *absmax, unsigned char *out, int blocksize, int n);
-template void quantize_blockwise<float, General8bit>(const float *code, const float *A, float *absmax, unsigned char *out, int blocksize, int n);
-template void quantize_blockwise<float, FP4>(const float *code, const float *A, float *absmax, unsigned char *out, int blocksize, int n);
-template void quantize_blockwise<float, NF4>(const float *code, const float *A, float *absmax, unsigned char *out, int blocksize, int n);
-template void quantize_blockwise<__nv_bfloat16, General8bit>(const float *code, const __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, int n);
-template void quantize_blockwise<__nv_bfloat16, FP4>(const float *code, const __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, int n);
-template void quantize_blockwise<__nv_bfloat16, NF4>(const float *code, const __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, int n);
 
 std::vector<paddle::Tensor> QuantizeBlockwise(const paddle::Tensor& input, const paddle::Tensor& code, int blocksize, std::string quant_type) {
     int n = input.numel();
     std::vector<int64_t> out_shape = input.shape();
@@ -410,28 +405,28 @@ std::vector<paddle::Tensor> QuantizeBlockwise(const paddle::Tensor& input, const
 switch(input.type()) {
     case paddle::DataType::FLOAT32:
         if (quant_type == "8bit")
-            quantize_blockwise<float, General8bit>(code.data<float>(), input.data<float>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+            quantize_blockwise<paddle::DataType::FLOAT32, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
         else if (quant_type == "nf4") {
-            quantize_blockwise<float, NF4>(NULL, input.data<float>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+            quantize_blockwise<paddle::DataType::FLOAT32, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
         }
         else if (quant_type == "fp4")
-            quantize_blockwise<float, FP4>(NULL, input.data<float>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+            quantize_blockwise<paddle::DataType::FLOAT32, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
         return {out, absmax};
     case paddle::DataType::FLOAT16:
         if (quant_type == "8bit")
-            quantize_blockwise<half, General8bit>(code.data<float>(), input.data<half>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+            quantize_blockwise<paddle::DataType::FLOAT16, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
         else if (quant_type == "nf4")
-            quantize_blockwise<half, NF4>(NULL, input.data<half>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+            quantize_blockwise<paddle::DataType::FLOAT16, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
         else if (quant_type == "fp4")
-            quantize_blockwise<half, FP4>(NULL, input.data<half>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+            quantize_blockwise<paddle::DataType::FLOAT16, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
         return {out, absmax};
     case paddle::DataType::BFLOAT16:
         if (quant_type == "8bit")
-            quantize_blockwise<__nv_bfloat16, General8bit>(code.data<float>(), input.data<__nv_bfloat16>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+            quantize_blockwise<paddle::DataType::BFLOAT16, General8bit>(code.data<float>(), input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
         else if (quant_type == "nf4")
-            quantize_blockwise<__nv_bfloat16, NF4>(NULL, input.data<__nv_bfloat16>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+            quantize_blockwise<paddle::DataType::BFLOAT16, NF4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
         else if (quant_type == "fp4")
-            quantize_blockwise<__nv_bfloat16, FP4>(NULL, input.data<__nv_bfloat16>(), absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
+            quantize_blockwise<paddle::DataType::BFLOAT16, FP4>(NULL, input, absmax.data<float>(), out.data<unsigned char>(), blocksize, n);
         return {out, absmax};
 
     default:
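This dispatch is the substance of the fix: the old code read fp16/bf16 buffers with input.data<half>() and input.data<__nv_bfloat16>(), device-side element types rather than the paddle host types, while the new code passes the tensor through and lets PDTraits pair the paddle type (for data<T>()) with the CUDA kernel type (via reinterpret_cast). A minimal smoke check from the Python side — a sketch only, assuming the paddleslim_ops extension is built, a GPU is available, and that quantize_nf4/dequantize_nf4 (imported in the test file below) wrap quant_blockwise:

    import paddle
    from paddleslim.lc.quantizers.quant_func import quantize_nf4, dequantize_nf4

    # bf16 input: one of the dtypes this commit fixes (fp16 behaves the same)
    a = paddle.uniform([2, 64], dtype="float32").cast("bfloat16")
    q, absmax = quantize_nf4(a, 64)            # packed 4-bit codes + per-block absmax
    rec = dequantize_nf4(q, absmax, 64).cast("bfloat16")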
6 changes: 4 additions & 2 deletions paddleslim/lc/quantizers/quant_func.py
@@ -1,3 +1,4 @@
+import itertools
 import paddle
 from paddleslim_ops import quant_blockwise, dequant_blockwise
 
@@ -90,6 +91,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     for i in range(gap):
         values.append(0)
     values.sort()
+    code = paddle.to_tensor(values)
     code /= code.max()
 
     return code
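With the added paddle.to_tensor, create_fp8_map now returns a ready-to-use, max-normalized code book, which is why the call sites below drop their own wrapping. Roughly (a sketch; the parameter values are copied from quantize_8bit below):

    code = create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4)
    # code is a paddle.Tensor normalized so that float(code.max()) == 1.0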
@@ -110,15 +112,15 @@ def dequantize_fp4(x, absmax, blocksize):
 def quantize_8bit(x, code, blocksize, quant_type="fp8"):
     if code is None:
         if quant_type=="fp8":
-            code = paddle.to_tensor(create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4))
+            code = create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4)
         else:
             code = paddle.to_tensor(create_dynamic_map())
     return quant_blockwise(x, code, blocksize=blocksize, quant_type="8bit")
 
 def dequantize_8bit(x, code, absmax, blocksize, quant_type="fp8"):
     if code is None:
         if quant_type=="fp8":
-            code = paddle.to_tensor(create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4))
+            code = create_fp8_map(signed=True, exponent_bits=2, precision_bits=1, total_bits=4)
         else:
             code = paddle.to_tensor(create_dynamic_map())
 
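Callers are unaffected: with code=None the quantizer builds the code book itself, and the dequantizer must be called with the same quant_type so both sides derive the identical map. A round-trip sketch (argument order per the test file below):

    x = paddle.uniform([2, 64], dtype="float16")
    q, absmax = quantize_8bit(x, None, 64, quant_type="fp8")
    y = dequantize_8bit(q, None, absmax, 64, quant_type="fp8").cast("float16")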
47 changes: 47 additions & 0 deletions tests/lc/test_func.py
@@ -0,0 +1,47 @@
+import sys
+sys.path.append("../../")
+import numpy as np
+import unittest
+import paddle
+from paddleslim.lc.layers import NF4Linear, FP4Linear
+from paddleslim.lc.quantizers.quant_func import quantize_nf4, quantize_fp4, dequantize_nf4, dequantize_fp4, quantize_8bit, dequantize_8bit
+
+class NF4(unittest.TestCase):
+    def setUp(self):
+        self.quant_type = "nf4"
+        self.blocksize = 64
+
+    def test_nf4_fp16(self):
+        a = paddle.uniform([2, 64], dtype="float16")
+        nf4_a, scale_a = quantize_nf4(a, self.blocksize)
+        fp16_a = dequantize_nf4(nf4_a, scale_a, self.blocksize).cast("float16")
+
+class FP4(unittest.TestCase):
+    def setUp(self):
+        self.quant_type = "fp4"
+        self.blocksize = 64
+
+    def test_fp4_fp16(self):
+        a = paddle.uniform([2, 64], dtype="float16")
+        nf4_a, scale_a = quantize_fp4(a, self.blocksize)
+        fp16_a = dequantize_fp4(nf4_a, scale_a, self.blocksize).cast("float16")
+
+class BIT8(unittest.TestCase):
+    def setUp(self):
+        self.quant_type = "fp8"
+        self.blocksize = 64
+
+    def test_fp8_fp16(self):
+        a = paddle.uniform([2, 64], dtype="float16")
+        nf4_a, scale_a = quantize_8bit(a, None, self.blocksize, quant_type="fp8")
+        fp16_a = dequantize_8bit(nf4_a, None, scale_a, self.blocksize, quant_type="fp8").cast("float16")
+
+    def test_dynamic_fp8_fp16(self):
+        a = paddle.uniform([2, 64], dtype="float16")
+        nf4_a, scale_a = quantize_8bit(a, None, self.blocksize, quant_type="dynamic_fp8")
+        fp16_a = dequantize_8bit(nf4_a, None, scale_a, self.blocksize, quant_type="dynamic_fp8").cast("float16")
+
+if __name__ == '__main__':
+    unittest.main()
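Note that these tests only exercise the kernels end to end; none of them asserts on the output. A possible tightening — not part of this commit, and the 0.1 tolerance is a loose assumption for paddle.uniform's default [-1, 1] range — would check the round-trip error, e.g. inside the NF4 case:

    def test_nf4_roundtrip(self):
        a = paddle.uniform([2, 64], dtype="float16")
        nf4_a, scale_a = quantize_nf4(a, self.blocksize)
        fp16_a = dequantize_nf4(nf4_a, scale_a, self.blocksize).cast("float16")
        # mean absolute reconstruction error should stay small for 4-bit NF4
        self.assertLess(float((a - fp16_a).abs().mean()), 0.1)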

