Skip to content

Commit

Permalink
Disallows casting of activations with positive distribution to int8; …
Browse files Browse the repository at this point in the history
…bug fix.

PiperOrigin-RevId: 370162367
  • Loading branch information
shivaniag authored and copybara-github committed Apr 23, 2021
1 parent 412e0ef commit 0782952
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 19 deletions.
29 changes: 20 additions & 9 deletions aqt/jax/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,21 +743,32 @@ def quantized_dot(*,
weight_scale = jnp.array(1.0, dtype=SCALE_DTYPE)

# Use metadata context to annotate op metadata with quantization info
lhs_prec = None if act_hparams is None else act_hparams.prec
rhs_prec = None if weight_params is None else weight_params.prec
act_prec = None if act_hparams is None else act_hparams.prec
act_has_symm_distribution = act_hparams is not None and (
act_hparams.input_distribution
== QuantOps.ActHParams.InputDistribution.symmetric)
weight_prec = None if weight_params is None else weight_params.prec

# To decide whether to use an integer-domain dot operation, we first check
# if the static quantization parameters are compatible with it by seeing if
# they request that both inputs be quantized 8bits or less. Then check if
# the dynamic parameters are compatible with it. ie, in a training run with
# quantization enabled, are we past the activation start step yet.
if lhs_prec is None or rhs_prec is None or lhs_prec > 8 or rhs_prec > 8:
use_int8_to_int32_dot = False
else:
# is_act_quantized might be an instance of a Jax tracer instead of a
# Python boolean since it is generally computed from a dynamic input to a
# JITted Jax function. Thus we use '&' instead of 'and'.
use_int8_to_int32_dot = prefer_int8_to_int32_dot & is_weight_quantized & is_act_quantized

# We also do not use int8_to_int32_dot if activation has positive
# distribution and prec=8, since we would not be able to fit uint8 range in
# int8.
# TODO(shivaniagrawal): A proper solution for this would be to have mixed
# dot(uint8, int8) -> int32 in XLA.
weight_fits_in_int8 = is_weight_quantized and (weight_prec is not None and
weight_prec <= 8)
# is_act_quantized might be an instance of a Jax tracer instead of a
# Python boolean since it is generally computed from a dynamic input to a
# JITted Jax function. Thus we use '&' instead of 'and'.
act_prec_fits_int8 = act_prec is not None and (
(act_prec == 8 and act_has_symm_distribution) or (act_prec < 8))
act_fits_in_int8 = is_act_quantized & act_prec_fits_int8
use_int8_to_int32_dot = prefer_int8_to_int32_dot & weight_fits_in_int8 & act_fits_in_int8

metadata_context = contextlib.suppress()
with metadata_context:
Expand Down
26 changes: 16 additions & 10 deletions aqt/jax/quantization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,17 +566,22 @@ def assert_is_integer_in_range(self, x, *, prec, distribution):
f'and {distribution} distribution.')

@parameterized.parameters(
dict(act_distribution='symmetric', prefer_int8_to_int32_dot=True),
dict(act_distribution='positive', prefer_int8_to_int32_dot=True),
dict(act_distribution='symmetric', prefer_int8_to_int32_dot=False))
dict(act_distribution='symmetric', prefer_int8_to_int32_dot=True, prec=4),
dict(act_distribution='symmetric', prefer_int8_to_int32_dot=True, prec=8),
dict(act_distribution='positive', prefer_int8_to_int32_dot=True, prec=4),
dict(act_distribution='positive', prefer_int8_to_int32_dot=True, prec=8),
dict(
act_distribution='symmetric', prefer_int8_to_int32_dot=False, prec=4))
@mock.patch.object(jax.lax, 'dot_general')
def test_lax_dot_has_integer_inputs_in_quantized_dot(
self, mock_dot_general, act_distribution, prefer_int8_to_int32_dot):
weight_params = QuantOps.WeightParams(prec=4, axis=(0,))
def test_lax_dot_has_integer_inputs_in_quantized_dot(self, mock_dot_general,
act_distribution,
prefer_int8_to_int32_dot,
prec):
weight_params = QuantOps.WeightParams(prec=prec, axis=(0,))
act_params = QuantOps.ActHParams(
input_distribution=act_distribution,
bounds=jnp.array([[3.0, 1.5]]),
prec=4)
prec=prec)
act = self.lhs
if act_distribution == 'positive':
act = jnp.abs(act)
Expand All @@ -599,10 +604,11 @@ def test_lax_dot_has_integer_inputs_in_quantized_dot(
prefer_int8_to_int32_dot=prefer_int8_to_int32_dot)
act_inputs, weight_inputs = mock_dot_general.call_args[0]
self.assert_is_integer_in_range(
act_inputs, prec=4, distribution=act_distribution)
act_inputs, prec=prec, distribution=act_distribution)
self.assert_is_integer_in_range(
weight_inputs, prec=4, distribution='symmetric')
if prefer_int8_to_int32_dot:
weight_inputs, prec=prec, distribution='symmetric')
if prefer_int8_to_int32_dot and not (act_distribution == 'positive' and
prec == 8):
expected_input_dtype = jnp.int8
else:
expected_input_dtype = jnp.float32
Expand Down

0 comments on commit 0782952

Please sign in to comment.