Skip to content

Commit

Permalink
[quant] Add testing coverage for 4-bit embedding_bag sparse lookup op (
Browse files Browse the repository at this point in the history
…pytorch#47328)

Summary:
Pull Request resolved: pytorch#47328

Extend the tests to cover the case of pruned weights with a mapping table.
Support for 8-bit sparse lookup will follow.

Test Plan:
python test/test_quantization.py TestQuantizedEmbeddingOps

Imported from OSS

Reviewed By: qizzzh

Differential Revision: D24719910

fbshipit-source-id: d31db6304f446104ee8c7b10b902accd2919a513
  • Loading branch information
supriyar authored and facebook-github-bot committed Nov 5, 2020
1 parent f19637e commit 433b55b
Showing 1 changed file with 47 additions and 18 deletions.
65 changes: 47 additions & 18 deletions test/quantization/test_quantized_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2969,7 +2969,7 @@ def embedding_bag_rowwise_offsets_run(
embedding_dim, num_offsets,
use_32bit_indices, use_32bit_offsets,
enable_per_sample_weights,
include_last_offset, atol, rtol):
include_last_offset, prune_weights, sparsity, atol, rtol):
pt_op = torch.ops.quantized.embedding_bag_byte_rowwise_offsets
pt_prepack_op = torch.ops.quantized.embedding_bag_byte_prepack
if bit_rate == 4:
Expand Down Expand Up @@ -3024,17 +3024,44 @@ def get_reference_result(
return embedding_bag(indices, offsets,
per_sample_weights=per_sample_weights)

mapping_table = np.zeros(num_embeddings, dtype=np.int32)
pruned_weights = weights
if prune_weights and bit_rate == 4:
# Prune and generate mapping table
num_compressed_rows = 0
unpruned_ids = []
for i in range(num_embeddings):
if np.random.uniform() < sparsity:
mapping_table[i] = -1
q_weights[i, :] = 0
weights[i, :] = 0
else:
mapping_table[i] = num_compressed_rows
num_compressed_rows += 1
unpruned_ids.append(i)
q_weights = q_weights[unpruned_ids]
pruned_weights = weights[unpruned_ids]
result = pt_op(q_weights,
indices.int() if use_32bit_indices else indices,
offsets.int() if use_32bit_offsets else offsets,
mode=0,
pruned_weights=prune_weights,
per_sample_weights=per_sample_weights,
compressed_indices_mapping=torch.tensor(mapping_table),
include_last_offset=include_last_offset)
else:
result = pt_op(q_weights,
indices.int() if use_32bit_indices else indices,
offsets.int() if use_32bit_offsets else offsets,
mode=0,
pruned_weights=prune_weights,
per_sample_weights=per_sample_weights,
include_last_offset=include_last_offset)

reference_result = get_reference_result(
num_embeddings, embedding_dim, include_last_offset, weights,
per_sample_weights, indices, offsets)
result = pt_op(
q_weights,
indices.int() if use_32bit_indices else indices,
offsets.int() if use_32bit_offsets else offsets,
mode=0,
per_sample_weights=per_sample_weights,
include_last_offset=include_last_offset,
)

torch.testing.assert_allclose(reference_result, result, atol=atol,
rtol=rtol)

Expand All @@ -3048,15 +3075,16 @@ def get_reference_result(
qdtype = torch.quint8
op = torch.ops.quantized.embedding_bag_byte
obs = PerChannelMinMaxObserver(dtype=qdtype, qscheme=torch.per_channel_affine_float_qparams, ch_axis=0)
obs(weights)
obs(pruned_weights)
# Get the scale and zero point for the weight tensor
qparams = obs.calculate_qparams()

# Quantize the weights to 8bits
qweight = torch.quantize_per_channel(weights, qparams[0], qparams[1], axis=0, dtype=qdtype)
qweight = torch.quantize_per_channel(pruned_weights, qparams[0], qparams[1], axis=0, dtype=qdtype)
packed_weight = torch.ops.quantized.embedding_bag_prepack(qweight)
result = op(packed_weight, indices, offsets, mode=0,
pruned_weights=prune_weights,
per_sample_weights=per_sample_weights,
compressed_indices_mapping=torch.tensor(mapping_table),
include_last_offset=include_last_offset)
torch.testing.assert_allclose(reference_result, result, atol=atol, rtol=rtol)

Expand All @@ -3080,8 +3108,8 @@ def test_embedding_bag_byte(self, num_embeddings,
self.embedding_bag_rowwise_offsets_run(
8, num_embeddings, embedding_dim, num_offsets,
use_32bit_indices, use_32bit_offsets,
enable_per_sample_weights, include_last_offset,
atol=0.005, rtol=1e-3)
enable_per_sample_weights, include_last_offset, prune_weights=False,
sparsity=0, atol=0.005, rtol=1e-3)

""" Tests the correctness of the embedding_bag_4bit quantized operator """
@given(num_embeddings=st.integers(10, 100),
Expand All @@ -3090,19 +3118,20 @@ def test_embedding_bag_byte(self, num_embeddings,
use_32bit_indices=st.booleans(),
use_32bit_offsets=st.booleans(),
enable_per_sample_weights=st.booleans(),
include_last_offset=st.booleans())
include_last_offset=st.booleans(),
sparsity=st.sampled_from([0.0, 0.5, 0.7]))
def test_embedding_bag_4bit(self, num_embeddings,
embedding_dim, num_offsets,
use_32bit_indices,
use_32bit_offsets,
enable_per_sample_weights,
include_last_offset):
include_last_offset, sparsity):
self.embedding_bag_rowwise_offsets_run(4, num_embeddings,
embedding_dim, num_offsets,
use_32bit_indices, use_32bit_offsets,
enable_per_sample_weights,
include_last_offset, atol=0.1,
rtol=1e-2)
include_last_offset, True, sparsity=sparsity,
atol=0.1, rtol=1e-2)

""" Tests the correctness of the quantized embedding lookup operator """
@given(num_embeddings=st.integers(10, 100),
Expand Down

0 comments on commit 433b55b

Please sign in to comment.