Support batched embeddings for 8Bit embedding bag quantization (pytorch#55343)

Summary:
Pull Request resolved: pytorch#55343

Add support for N-dimensioned batches of 2D embedding bags to qembeddingbag_byte_prepack and qembeddingbag_byte_unpack.

This is already supported in C2 via caffe2::Fused8BitRowwiseQuantizedToFloat and caffe2::FloatToFused8BitRowwiseQuantized; this change adds the same support to the PyTorch operators.
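
As a rough illustration of the new behavior (a minimal sketch; the shapes follow the example in the code comments below, and the ops are invoked through torch.ops.quantized as in the tests):

import numpy as np
import torch

# A batch of two 10x3 embedding tables (an N-dimensioned input).
weights = torch.from_numpy(
    np.random.random_sample((2, 10, 3)).astype(np.float32))

# Pack: each row gains 8 trailing bytes (fp32 scale + fp32 zero_point),
# so the last dimension grows from 3 to 11.
packed = torch.ops.quantized.embedding_bag_byte_prepack(weights)
assert packed.size() == torch.Size([2, 10, 11])

# Unpack: recovers a float tensor with the original shape.
unpacked = torch.ops.quantized.embedding_bag_byte_unpack(packed)
assert unpacked.size() == torch.Size([2, 10, 3])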

Test Plan: buck test //caffe2/test:quantization  -- test_embedding_bag_byte

Reviewed By: radkris-git

Differential Revision: D27480917

fbshipit-source-id: 9878751c6cee8a55909fe58a3e8c222ea31c20bb
b-koopman authored and facebook-github-bot committed Apr 12, 2021
1 parent 80d04f9 commit db394ef
Showing 3 changed files with 126 additions and 25 deletions.
91 changes: 83 additions & 8 deletions aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
@@ -125,16 +125,89 @@ namespace {
// Note - This is a temporary pack function for embedding bag which quantizes
// and packs the float weight tensor. In the next step it will be replaced by a
// quantize and pack function once we support FP scale and FP zero_point
//
// Python example examining a packed 8bit zero_point and scale:
//
// >> x = torch.from_numpy(np.array([[[10, 20], [30, 40]],[[50, 60], [70, 80]]], dtype=np.float32))
// >> x_packed = torch.ops.quantized.embedding_bag_byte_prepack(x)
//
// # Pull out and examine packed scales, zero_points and values
// >> zero_points = x_packed[:,:,-4:].numpy()
// >> scales = x_packed[:,:,-8:-4].numpy()
// >> values = x_packed[:,:,:-8].numpy()
//
// >> zero_points
// array([[[ 0, 0, 32, 65],
// [ 0, 0, 240, 65]],
//
// [[ 0, 0, 72, 66],
// [ 0, 0, 140, 66]]], dtype=uint8)
//
// >> scales
// array([[[161, 160, 32, 61],
// [161, 160, 32, 61]],
//
// [[161, 160, 32, 61],
// [161, 160, 32, 61]]], dtype=uint8)
// >> values
// array([[[ 0, 255],
// [ 0, 255]],
//
// [[ 0, 255],
// [ 0, 255]]], dtype=uint8)
//
// # Convert 4 byte packed scales and zero_points to float
// # and apply against values in order to recover unquantized values.
// def bytes2float(arr):
// packed_hex = bytearray(arr)
// return struct.unpack('f', packed_hex)
//
// >> float_zero_points = np.apply_along_axis(bytes2float, 2, zero_points)
// >> float_zero_points
// array([[[10.],
// [30.]],
//
// [[50.],
// [70.]]])
// >> float_scales = np.apply_along_axis(bytes2float, 2, scales)
// >> float_scales
// array([[[0.03921569],
// [0.03921569]],
//
// [[0.03921569],
// [0.03921569]]])
// >> values * float_scales + float_zero_points
// array([[[10. , 20.00000035],
// [30. , 40.00000035]],
//
// [[50. , 60.00000035],
// [70. , 80.00000035]]])
Tensor qembeddingbag_byte_prepack(const Tensor& weight) {
int64_t embedding_rows = weight.size(0);
int64_t embedding_cols = weight.size(1);
// The "last" dimension of an N-Dimensioned batch of embedding bags is
// quantization channel. E.g. for a 2D embedding bag, this has
// [ row, col ] dimensions, for batched of embedding bags, dimensions might be
// [ batch, row, col ].
//
// Python Batched Embedding Example:
// weights = torch.from_numpy((np.random.random_sample((
// 2, 10, 3)).squeeze() + 1).astype(np.float32))
// assert(weights.size() == torch.Size([2, 10, 3]))
// # NOTE: 8 bytes (columns) are added due to fp32 zero_point and scales
// packed_weights = torch.ops.quantized.embedding_bag_byte_prepack(weights)
// assert(packed_weights.size() == torch.Size([2, 10, 11]))

const auto weight_sizes = weight.sizes();
const auto cols_dim = weight_sizes.size() - 1;
const int32_t embedding_rows = c10::size_to_dim_(cols_dim, weight_sizes);
const int32_t embedding_cols = weight_sizes[cols_dim];
// Add 8 bytes per column to store FP32 scale and zero_point per row.
const int32_t output_columns = embedding_cols + 2 * sizeof(float);
Tensor weight_contig = weight.contiguous(weight.suggest_memory_format());

const float* weight_data = weight_contig.data_ptr<float>();
std::vector<int64_t> output_shape = {
embedding_rows,
embedding_cols +
8}; // extra 8 bytes to store FP scale and zero_point per row.
// Adjust output dimensions to account for FP32 scale and zero_points.
std::vector<int64_t> output_shape = weight_sizes.vec();
output_shape[cols_dim] = output_columns;

// Allocate output packed weights
auto output = at::empty(
@@ -144,16 +217,17 @@ Tensor qembeddingbag_byte_prepack(const Tensor& weight) {
auto* output_data = output.data_ptr<uint8_t>();

#ifdef USE_FBGEMM

at::parallel_for(
0, embedding_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
for (int64_t row = start_idx; row < end_idx; ++row) {
fbgemm::FloatToFused8BitRowwiseQuantizedSBFloat(
weight_data + row * embedding_cols, 1,
embedding_cols, output_data + row * output_shape[1]);
embedding_cols, output_data + row * output_columns);
}
});

#else
size_t output_columns = output_shape[1];
constexpr float kEpsilon = 1e-8f;
for (std::size_t row = 0; row < embedding_rows; ++row) {
const float* input_row = weight_data + row * embedding_cols;
@@ -180,6 +254,7 @@ Tensor qembeddingbag_byte_prepack(const Tensor& weight) {
return output;
}

// TODO: Extend support to N-D batched embeddings, similar to qembeddingbag_byte_prepack
Tensor _qembeddingbag_nbit_prepack_helper(
const Tensor& weight,
int bit_width,
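
The layout walkthrough in the comment at the top of qembeddingbag_byte_prepack can be exercised end to end. A self-contained sketch of that decode step (assuming NumPy and the prepack op shown above; per row, the packed layout is the uint8 values, then 4 bytes of fp32 scale, then 4 bytes of fp32 zero_point):

import struct

import numpy as np
import torch

x = torch.from_numpy(np.array(
    [[[10, 20], [30, 40]], [[50, 60], [70, 80]]], dtype=np.float32))
x_packed = torch.ops.quantized.embedding_bag_byte_prepack(x)

values = x_packed[:, :, :-8].numpy()
scales = x_packed[:, :, -8:-4].numpy()
zero_points = x_packed[:, :, -4:].numpy()

def bytes2float(arr):
    # Reinterpret 4 packed bytes as a single fp32 value (native byte order).
    return struct.unpack('f', bytearray(arr))

float_scales = np.apply_along_axis(bytes2float, 2, scales)
float_zero_points = np.apply_along_axis(bytes2float, 2, zero_points)

# Recover (approximately) the original float values.
print(values * float_scales + float_zero_points)
# ~[[[10, 20], [30, 40]], [[50, 60], [70, 80]]]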
32 changes: 24 additions & 8 deletions aten/src/ATen/native/quantized/cpu/qembeddingbag_unpack.cpp
@@ -90,15 +90,31 @@ namespace native {
namespace {

Tensor qembeddingbag_byte_unpack(const Tensor& packed_weight) {
const auto input_rows = packed_weight.size(0);
const auto input_columns = packed_weight.size(1);

// The "last" dimension of an N-Dimensioned batch of embedding bags is
// quantization channel. E.g. for a 2D embedding bag, this has
// [ row, col ] dimensions, for batched of embedding bags, dimensions might be
// [ batch, row, col ].
//
// Python Batched Embedding Example:
// weights = torch.from_numpy((np.random.random_sample((
// 2, 10, 3)).squeeze() + 1).astype(np.float32))
// assert(weights.size() == torch.Size([2, 10, 3]))
// # NOTE: 8 bytes (columns) are added due to fp32 zero_point and scales
// packed_weights = torch.ops.quantized.embedding_bag_byte_prepack(weights)
// assert(packed_weights.size() == torch.Size([2, 10, 11]))
// unpacked_weights = torch.ops.quantized.embedding_bag_byte_unpack(packed_weights)
// assert(unpacked_weights.size() == torch.Size([2, 10, 3]))
const auto packed_weight_sizes = packed_weight.sizes();
const auto col_dim = packed_weight_sizes.size() - 1;
const int32_t input_rows = c10::size_to_dim_(col_dim, packed_weight_sizes);
const int32_t input_columns = packed_weight_sizes[col_dim];
// The last 2 values are used to store the FP32 scale and zero_point values
// per row.
int output_columns = input_columns - 2 * sizeof(float);
const int32_t output_columns = input_columns - 2 * sizeof(float);
const auto* input_data = packed_weight.data_ptr<uint8_t>();

const auto* input = packed_weight.data_ptr<uint8_t>();
std::vector<int64_t> output_shape = {input_rows, output_columns};
std::vector<int64_t> output_shape = packed_weight_sizes.vec();
output_shape[col_dim] = output_columns;
at::Tensor output = at::empty(
output_shape,
packed_weight.options().dtype(kFloat),
@@ -110,15 +126,15 @@ Tensor qembeddingbag_byte_unpack(const Tensor& packed_weight) {
0, input_rows, 1, [&](int32_t start_idx, int32_t end_idx) {
for (int64_t row = start_idx; row < end_idx; ++row) {
fbgemm::Fused8BitRowwiseQuantizedSBFloatToFloat(
input + row * input_columns,
input_data + row * input_columns,
1,
input_columns,
output_data + row * output_columns);
}
});
#else
for (std::size_t row = 0; row < input_rows; ++row) {
const std::uint8_t* input_row = input + row * input_columns;
const std::uint8_t* input_row = input_data + row * input_columns;
const float* input_row_scale_zp =
reinterpret_cast<const float*>(input_row + output_columns);
float* output_row = output_data + row * output_columns;
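
For intuition, the fallback (non-FBGEMM) dequantization path above reads each row's trailing fp32 scale and zero_point and applies value * scale + zero_point. A rough NumPy sketch of that per-row step (a simplification, not the exact kernel):

import numpy as np

def byte_unpack_reference(packed):
    # packed: uint8 array of shape [rows, cols + 8]; the last 8 bytes of each
    # row hold the fp32 scale followed by the fp32 zero_point.
    rows, packed_cols = packed.shape
    cols = packed_cols - 8
    out = np.empty((rows, cols), dtype=np.float32)
    for r in range(rows):
        scale, zero_point = np.frombuffer(
            packed[r, cols:].tobytes(), dtype=np.float32)
        out[r] = packed[r, :cols].astype(np.float32) * scale + zero_point
    return out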
28 changes: 19 additions & 9 deletions test/quantization/test_quantized_op.py
@@ -3100,9 +3100,10 @@ def test_qlinear_unpack(self, W, use_channelwise):

@unittest.skipIf(sys.platform == "darwin", "Known test failure on Mac.")
class TestQuantizedEmbeddingOps(TestCase):
def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate, optimized_qparams):
def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embedding_dim, bit_rate, optimized_qparams,
num_batches):
weights = torch.from_numpy((np.random.random_sample((
num_embeddings, embedding_dim)) + 1).astype(np.float32))
num_batches, num_embeddings, embedding_dim)).squeeze() + 1).astype(np.float32))
qtype = torch.quint8
if bit_rate == 8:
w_packed = pack_fn(weights)
@@ -3111,16 +3112,24 @@ def _test_embedding_bag_unpack_fn(self, pack_fn, unpack_fn, num_embeddings, embe
w_unpacked = unpack_fn(w_packed)

if bit_rate == 8 or bit_rate == 4:
obs_weights = weights
# Combine 3D embeddings (e.g. stacked combination of embeddings)
# in a dimension orthogonal to channels.
if(num_batches > 1):
stacked_shape = list(weights.size())
stacked_shape[1] *= stacked_shape[0]
obs_weights = weights.reshape(stacked_shape[1:])

# Check numerics of prepack function that accepts qtensor as input.
# We use min-max observer to mimic the quantization performed in the original function.
obs = PerChannelMinMaxObserver(dtype=torch.quint8, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
obs(weights)
obs(obs_weights)
# Get the scale and zero point for the weight tensor
qparams = obs.calculate_qparams()
if bit_rate == 4:
qtype = torch.quint4x2
# Quantize the weights to 8bits
qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=qtype)
qweight = torch.quantize_per_channel(obs_weights, qparams[0], qparams[1], axis=0, dtype=qtype)
real_packed_weight = torch.ops.quantized.embedding_bag_prepack(qweight)
self.assertEqual(isinstance(real_packed_weight, torch._C.ScriptObject), True)
unpacked_weight = torch.ops.quantized.embedding_bag_unpack(real_packed_weight)
@@ -3175,12 +3184,13 @@ def get_c2_weights(weights, engine_str):

""" Tests the correctness of the embedding_bag_8bit pack/unpack op against C2 """
@given(num_embeddings=st.integers(10, 100),
embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),)
def test_embedding_bag_byte_unpack(self, num_embeddings, embedding_dim):
embedding_dim=st.integers(5, 50).filter(lambda x: x % 4 == 0),
num_batches=st.integers(1, 5))
def test_embedding_bag_byte_unpack(self, num_embeddings, embedding_dim, num_batches):
pack_fn = torch.ops.quantized.embedding_bag_byte_prepack
unpack_fn = torch.ops.quantized.embedding_bag_byte_unpack

self._test_embedding_bag_unpack_fn(pack_fn, unpack_fn, num_embeddings, embedding_dim, 8, False)
self._test_embedding_bag_unpack_fn(pack_fn, unpack_fn, num_embeddings, embedding_dim, 8, False, num_batches)

""" Tests the correctness of the embedding_bag_4bit pack/unpack op against C2 """
@given(num_embeddings=st.integers(10, 100),
@@ -3190,7 +3200,7 @@ def test_embedding_bag_4bit_unpack(self, num_embeddings, embedding_dim, optimize
pack_fn = torch.ops.quantized.embedding_bag_4bit_prepack
unpack_fn = torch.ops.quantized.embedding_bag_4bit_unpack

self._test_embedding_bag_unpack_fn(pack_fn, unpack_fn, num_embeddings, embedding_dim, 4, optimized_qparams)
self._test_embedding_bag_unpack_fn(pack_fn, unpack_fn, num_embeddings, embedding_dim, 4, optimized_qparams, 1)

""" Tests the correctness of the embedding_bag_2bit pack/unpack op against C2 """
@given(num_embeddings=st.integers(10, 100),
@@ -3200,7 +3210,7 @@ def test_embedding_bag_2bit_unpack(self, num_embeddings, embedding_dim, optimize
pack_fn = torch.ops.quantized.embedding_bag_2bit_prepack
unpack_fn = torch.ops.quantized.embedding_bag_2bit_unpack

self._test_embedding_bag_unpack_fn(pack_fn, unpack_fn, num_embeddings, embedding_dim, 2, optimized_qparams)
self._test_embedding_bag_unpack_fn(pack_fn, unpack_fn, num_embeddings, embedding_dim, 2, optimized_qparams, 1)


def embedding_bag_rowwise_offsets_run(
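
The batch-flattening step added to _test_embedding_bag_unpack_fn above folds a [batch, rows, cols] weight into a [batch * rows, cols] tensor, so the per-channel observer and torch.quantize_per_channel, which operate on 2D weights along axis 0, see every embedding row. A small sketch of just that reshape:

import torch

weights = torch.randn(2, 10, 3)        # [batch, rows, cols]
stacked_shape = list(weights.size())
stacked_shape[1] *= stacked_shape[0]   # fold batch into the row dimension
obs_weights = weights.reshape(stacked_shape[1:])
assert obs_weights.size() == torch.Size([20, 3])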
