Use QTensor with quantized FC operator (pytorch#19541)
Summary:
Pull Request resolved: pytorch#19541

For the quantized FC (fully connected) operator, replace the (Tensor, scale, zero_point) tuple with a QTensor.

Differential Revision: D14900407

fbshipit-source-id: 164df38f3564e0a68af21b9fedaba98a44ca1453
jianyuh authored and facebook-github-bot committed May 8, 2019
1 parent 8defcbf commit e8cdfb5
Showing 5 changed files with 87 additions and 72 deletions.
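In caller terms, the change means scales and zero points no longer travel as separate arguments: they ride on the QTensor (and, for the packed weight, inside PackedFCWeight). A minimal sketch of the new call pattern, assuming the ops are reached through torch.ops.quantized and using the Tensor.quantize_linear API from test_quantized.py (illustrative values, not part of the commit):

import torch

# Hypothetical shapes and quantization parameters, for illustration only.
X = torch.rand(4, 16)                      # fp32 activations: batch 4, 16 input channels
W = torch.rand(8, 16)                      # fp32 weights: 8 output channels
b_q = torch.zeros(8, dtype=torch.int32)    # already-quantized bias

X_q = X.quantize_linear(scale=1.0, zero_point=5)        # QTensor carries scale/zero_point
W_q = W.quantize_linear(scale=0.4, zero_point=2 + 128)  # uint8 storage; +128 as in the test below

# Old schema: fbgemm_linear_prepack(W, W_zero_point) and
#   fbgemm_linear(X, X_scale, X_zp, W_prepack, W_scale, W_zp, b, Y_scale, Y_zp) -> (Y, Y_scale, Y_zp)
# New schema: only tensors plus the output quantization parameters.
W_prepack = torch.ops.quantized.fbgemm_linear_prepack(W_q)
Y_q = torch.ops.quantized.fbgemm_linear(X_q, W_prepack, b_q, 125.1234, 5)  # returns a QTensor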
16 changes: 7 additions & 9 deletions aten/src/ATen/cpp_custom_type_hack.h
@@ -7,12 +7,12 @@
//
// Template argument <T> has to be registered with CAFFE_KNOWN_TYPE mechanism.

#include "ATen/ATen.h"
#include <ATen/ATen.h>

namespace at {
namespace cpp_custom_type_hack {

template<typename T>
template <typename T>
T& cast(const Tensor& packed) {
AT_CHECK(
packed.scalar_type() == kByte, "Expected temporary cpp type wrapper");
@@ -24,23 +24,21 @@ T& cast(const Tensor& packed) {
return *reinterpret_cast<T*>(packed.storage().data_ptr().get());
}

template<typename T>
template <typename T>
Tensor create(std::unique_ptr<T> ptr, TensorOptions options) {
// We store this instance away in a Tensor and register a deleter function
// so that we do not leak memory. On the other side, we pull out the storage's
// data_ptr and get the right typed pointer.
void* raw_ptr = ptr.release();
at::DataPtr at_ptr(
raw_ptr,
raw_ptr,
caffe2::TypeMeta::Make<T>().deleteFn(),
at::kCPU);
raw_ptr, raw_ptr, caffe2::TypeMeta::Make<T>().deleteFn(), at::kCPU);

// size doesn't really matter, but we can align it to the actual size of the
// returned variable, because one likely wants to use this hack from Python
auto retval = at::empty({sizeof(T)}, options.device(kCPU).dtype(at::kByte));
retval.storage().set_data_ptr(std::move(at_ptr));
return retval;
}
}
}

} // namespace cpp_custom_type_hack
} // namespace at
12 changes: 12 additions & 0 deletions aten/src/ATen/native/quantized/cpu/fbgemm_utils.h
@@ -16,6 +16,18 @@
struct FBGEMM_API PackedFCWeight {
std::unique_ptr<fbgemm::PackBMatrix<int8_t>> w;
std::vector<int32_t> col_offsets;
float w_scale;
int w_zp;
};

// Convert the weight from uint8 to int8.
static void convert_uint8_int8(int K, int N, const uint8_t* src_uint8, int8_t* dst_int8) {
for (size_t i = 0; i < N; ++i) {
for (size_t j = 0; j < K; ++j) {
dst_int8[i * K + j] =
static_cast<int8_t>(static_cast<int32_t>(src_uint8[i * K + j]) - 128);
}
}
}

#endif // USE_FBGEMM
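For orientation, a NumPy sketch (not part of the commit) of what convert_uint8_int8 above computes: the quantized weight stored as uint8 is shifted down by 128 so FBGEMM can consume it as int8, which is also why qfc_prepack.cpp below subtracts 128 from the weight zero point.

import numpy as np

def convert_uint8_int8_ref(src_uint8):
    # src_uint8: N x K uint8 matrix; returns the int8 matrix with every entry shifted by -128.
    return (src_uint8.astype(np.int32) - 128).astype(np.int8)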
40 changes: 19 additions & 21 deletions aten/src/ATen/native/quantized/cpu/qfc.cpp
@@ -1,10 +1,11 @@
#include <ATen/ATen.h>
#include <ATen/core/Type.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/quantized/Quantizer.h>

#include <algorithm>
#include <tuple>

namespace at {
namespace native {
@@ -13,13 +14,9 @@ namespace {
class QFCInt8 final : public c10::OperatorKernel {
public:
#ifdef USE_FBGEMM
std::tuple<at::Tensor, double, int64_t> operator()(
at::Tensor operator()(
at::Tensor input,
double input_scale,
int64_t input_zero_point,
at::Tensor packed_weight,
double weight_scale,
int64_t weight_zero_point,
at::Tensor bias,
double output_scale,
int64_t output_zero_point) {
@@ -33,7 +30,8 @@ class QFCInt8 final : public c10::OperatorKernel {

// TODO: contiguous is called for further jit optimizations.
auto input_contig = input.contiguous();
const auto* input_ptr = input_contig.data<uint8_t>();
const auto* input_ptr =
reinterpret_cast<uint8_t*>(input_contig.data<c10::qint8>());

AT_ASSERT(input.dim() >= 2);
// C(output) = A(input) x B(weight), where C, A, B are M x N, M x K, K x N
@@ -55,11 +53,11 @@ class QFCInt8 final : public c10::OperatorKernel {
AT_ASSERT(bias.size(0) == N);
AT_ASSERT(bias.dim() == 1);

float input_scale_float = static_cast<float>(input_scale);
int32_t input_zero_point_int32 = static_cast<int32_t>(input_zero_point);
float input_scale_float = input.q_scale().toFloat();
int32_t input_zero_point_int32 = input.q_zero_point().toInt();

float weight_scale_float = static_cast<float>(weight_scale);
int32_t weight_zero_point_int32 = static_cast<int32_t>(weight_zero_point);
float weight_scale_float = pack_ptr.w_scale;
int32_t weight_zero_point_int32 = pack_ptr.w_zp;

float output_multiplier_float = (input_scale_float * weight_scale_float) /
static_cast<float>(output_scale);
@@ -113,31 +111,31 @@ class QFCInt8 final : public c10::OperatorKernel {
/*nCol=*/N);

// Allocate output Tensor and a buffer for fbgemmPacked to use
auto output = at::zeros({M, N}, input.options().dtype(at::kByte));
auto output = _empty_affine_quantized(
{M, N},
at::device(kCPU).dtype(kQInt8),
output_scale,
output_zero_point);

auto buffer = at::zeros_like(output, output.options().dtype(at::kInt));

// Do the GEMM
fbgemm::fbgemmPacked(
/*packA=*/packA,
/*packB=*/*packB,
/*C=*/output.data<uint8_t>(),
/*C=*/reinterpret_cast<uint8_t*>(output.data<c10::qint8>()),
/*C_buffer=*/buffer.data<int32_t>(),
/*ldc=*/N,
/*outProcess=*/outputProcObj,
/*thread_id=*/0,
/*num_threads=*/1);

return std::make_tuple(output, output_scale, output_zero_point);
return output;
}
#else // USE_FBGEMM
std::tuple<at::Tensor, double, int64_t> operator()(
at::Tensor operator()(
at::Tensor /* input */,
double /* input_scale */,
int64_t /* input_zero_point */,
at::Tensor /* packed_weight */,
double /* weight_scale */,
int64_t /* weight_zero_point */,
at::Tensor /* bias */,
double /* output_scale */,
int64_t /* output_zero_point */) {
@@ -151,9 +149,9 @@ class QFCInt8 final : public c10::OperatorKernel {
};

static auto registry = c10::RegisterOperators().op(
"quantized::fbgemm_linear(Tensor X, float X_scale, int X_zero_point, Tensor W_prepack, float W_scale, int W_zero_point, Tensor b, float Y_scale_i, int Y_zero_point_i) -> (Tensor Y, float Y_scale_o, int Y_zero_point_o)",
"quantized::fbgemm_linear(Tensor X, Tensor W_prepack, Tensor b, float Y_scale_i, int Y_zero_point_i) -> Tensor Y",
c10::kernel<QFCInt8>(),
c10::dispatchKey(CPUTensorId()));
c10::dispatchKey(QuantizedCPUTensorId()));
} // namespace
} // namespace native
} // namespace at
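The requantization performed after the int8 GEMM above scales the int32 accumulator by output_multiplier_float = (input_scale_float * weight_scale_float) / output_scale and re-centers it at the output zero point. A NumPy sketch of that arithmetic alone (the elided outputProcObj additionally handles the row/column offsets and the bias):

import numpy as np

def requantize_ref(acc_int32, input_scale, weight_scale, output_scale, output_zero_point):
    # acc_int32: int32 accumulator with offsets and bias already applied.
    multiplier = (input_scale * weight_scale) / output_scale
    y = np.rint(acc_int32 * multiplier) + output_zero_point
    return np.clip(y, 0, 255).astype(np.uint8)  # saturate to the uint8 output range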
27 changes: 17 additions & 10 deletions aten/src/ATen/native/quantized/cpu/qfc_prepack.cpp
@@ -1,7 +1,9 @@
#include <ATen/ATen.h>
#include <ATen/core/Type.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/quantized/Quantizer.h>

#include <algorithm>
#include <vector>
@@ -20,7 +22,7 @@ namespace {
class QFCPackWeightInt8 final : public c10::OperatorKernel {
public:
#ifdef USE_FBGEMM
// Calculate the column offsets
// Calculate the column offsets.
// Note this includes the sum of the columns as well as the scalar term
// B_zero_point * K, whereas the row_offsets created by
// PackAWithQuantRowOffset is only the sum of the A rows.
@@ -39,15 +41,20 @@ class QFCPackWeightInt8 final : public c10::OperatorKernel {
}
}

at::Tensor operator()(at::Tensor weight, int64_t weight_zero_point) {
at::Tensor operator()(at::Tensor weight) {
auto N = weight.size(0);
auto K = weight.size(1);

int32_t weight_zero_point_int32 = static_cast<int32_t>(weight_zero_point);
int32_t weight_zero_point_int32 = weight.q_zero_point().toInt() - 128;

// TODO: contiguous is called for further JIT optimizations.
auto weight_contig = weight.contiguous();
auto weight_ptr_int8 = weight_contig.data<int8_t>();

std::vector<int8_t> weight_int8(K * N);
int8_t* weight_ptr_int8 = weight_int8.data();
uint8_t* weight_ptr_uint8 =
reinterpret_cast<uint8_t*>(weight_contig.data<c10::qint8>());
convert_uint8_int8(K, N, weight_ptr_uint8, weight_ptr_int8);

std::vector<int32_t> col_offsets(N);
calc_col_offsets_transpose(
@@ -66,16 +73,16 @@ class QFCPackWeightInt8 final : public c10::OperatorKernel {
/*ld=*/K,
/*pmat=*/nullptr, // PackBMatrix manages ownership of pmat
/*groups=*/1),
col_offsets});
col_offsets,
weight.q_scale().toFloat(),
weight_zero_point_int32});

// TODO: we will need to replace this with torchscript classes at a later
// point.
return cpp_custom_type_hack::create(std::move(ret_ptr), weight.options());
}
#else // USE_FBGEMM
at::Tensor operator()(
at::Tensor /* weight */,
int64_t /* weight_zero_point */
at::Tensor operator()(at::Tensor /* weight */
) {
// We make a strong guarantee that models using these operators will have
// the same numerics across different machines. Therefore, we do not provide
@@ -87,9 +94,9 @@ class QFCPackWeightInt8 final : public c10::OperatorKernel {
};

static auto registry = c10::RegisterOperators().op(
"quantized::fbgemm_linear_prepack(Tensor W, int W_zero_point) -> Tensor W_prepack",
"quantized::fbgemm_linear_prepack(Tensor W) -> Tensor W_prepack",
c10::kernel<QFCPackWeightInt8>(),
c10::dispatchKey(CPUTensorId()));
c10::dispatchKey(QuantizedCPUTensorId()));
} // namespace
} // namespace native
} // namespace at
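A NumPy sketch of the column offsets described by the comment in calc_col_offsets_transpose (the function body is elided above, so the sign convention here is an assumption: column sum minus W_zero_point * K):

import numpy as np

def calc_col_offsets_transpose_ref(W_int8, W_zero_point):
    # W_int8: N x K int8 matrix (transposed weight). Assumed convention: each
    # offset combines the column sum with the scalar term W_zero_point * K.
    K = W_int8.shape[1]
    return W_int8.astype(np.int32).sum(axis=1) - W_zero_point * K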
64 changes: 32 additions & 32 deletions test/test_quantized.py
@@ -198,70 +198,70 @@ def test_qfc(self):
X_zp = 5
X_value_min = 0
X_value_max = 225
X_q = np.round(
X_q0 = np.round(
np.random.rand(batch_size, input_channels) * (X_value_max - X_value_min)
+ X_value_min
).astype(np.uint8)

W_scale = 0.4
# W_zp is the zero point for int8 quantization.
W_zp = 2
W_value_min = -128
W_value_max = 127
W_q = np.round(
W_q0 = np.round(
np.random.rand(output_channels, input_channels)
* (W_value_max - W_value_min)
+ W_value_min
).astype(np.int8)

b_q = np.round(np.random.randn(output_channels) * 10 - 10).astype(np.int32)

# Compare X_scale * W_scale * input_channels * X_value_max * W_value_max with
# Y_scale * 255 (max for uint8).
Y_scale = 125.1234
Y_zp = 5

avoid_vpmaddubsw_overflow_fc(
batch_size,
input_channels,
output_channels,
X_q,
X_q0,
X_value_min,
X_value_max,
W_q,
W_q0,
W_value_min,
W_value_max,
)

X = torch.from_numpy(_dequantize(X_q0, X_scale, X_zp)).to(dtype=torch.float)
W = torch.from_numpy(_dequantize(W_q0, W_scale, W_zp)).to(dtype=torch.float)

X_q = X.quantize_linear(scale=X_scale, zero_point=X_zp)
# W_zp + 128 is the zero point for uint8 quantization.
W_q = W.quantize_linear(scale=W_scale, zero_point=W_zp + 128)
b_q = torch.round(torch.rand(output_channels) * 10 - 10).to(dtype=torch.int32)

# Compare X_scale * W_scale * input_channels * X_value_max * W_value_max with
# Y_scale * 255 (max for uint8).
Y_scale = 125.1234
Y_zp = 5

# Reference quantized FC operator
Y_q_ref = qfc_ref(X_q, X_scale, X_zp, W_q, W_scale, W_zp, b_q, Y_scale, Y_zp)
Y_q_ref = qfc_ref(X_q0, X_scale, X_zp, W_q0, W_scale, W_zp, b_q.numpy(), Y_scale, Y_zp)

# Weight prepacking operator for quantized FC
W_prepack = qfc_prepack(torch.from_numpy(W_q), W_zp)
[Y_q, Y_scale, Y_zp] = qfc(
torch.from_numpy(X_q),
X_scale,
X_zp,
W_prepack,
W_scale,
W_zp,
torch.from_numpy(b_q),
Y_scale,
Y_zp,
)
W_prepack = qfc_prepack(W_q)
# Quantized FC operator with prepacked weight
Y_q = qfc(X_q, W_prepack, b_q, Y_scale, Y_zp)

# Y_q_ref_real = _dequantize(Y_q_ref, Y_scale, Y_zp)
# Y_q_real = Y_q.dequantize()

# Assert equal
np.testing.assert_equal(Y_q_ref, Y_q.numpy())
np.testing.assert_equal(Y_q_ref, Y_q.int_repr().numpy())

# Reference quantized result from PyTorch Linear operator
W_fp32 = _dequantize(W_q, W_scale, W_zp).astype(np.float)
X_fp32 = _dequantize(X_q, X_scale, X_zp).astype(np.float)
b_fp32 = _dequantize(b_q, W_scale * X_scale, 0).astype(np.float)
Y_fp32_ref = F.linear(torch.from_numpy(X_fp32), torch.from_numpy(W_fp32),
torch.from_numpy(b_fp32)).numpy()
Y_q_ref2 = _quantize(Y_fp32_ref, Y_scale, Y_zp)
W_fp32 = W_q.dequantize().to(dtype=torch.float)
X_fp32 = X_q.dequantize().to(dtype=torch.float)
b_fp32 = torch.from_numpy(_dequantize(b_q.numpy(), W_scale * X_scale, 0).astype(np.float)).to(dtype=torch.float)
Y_fp32_ref = F.linear(X_fp32, W_fp32, b_fp32)
Y_q_ref2 = Y_fp32_ref.quantize_linear(Y_scale, Y_zp)

# Assert equal
np.testing.assert_equal(Y_q_ref2, Y_q.numpy())
np.testing.assert_equal(Y_q_ref2.int_repr().numpy(), Y_q.int_repr().numpy())


if __name__ == "__main__":
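The _quantize and _dequantize helpers used in the test are defined earlier in test_quantized.py; a sketch of the standard affine mapping they are assumed to implement (uint8 clamping assumed for _quantize):

import numpy as np

def _dequantize_ref(q, scale, zero_point):
    # Affine dequantization: real value = scale * (q - zero_point).
    return scale * (q.astype(np.float32) - zero_point)

def _quantize_ref(x, scale, zero_point):
    # Affine quantization to uint8: round(x / scale) + zero_point, clamped to [0, 255].
    return np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)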
