JAX and Flax quantization libraries providing "what you serve is what you train" quantization for convolutions and matmuls.

The `jax/imagenet` directory contains a quantized ResNet model.
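The core idea behind "what you serve is what you train" is to expose the model to quantization rounding error during training, so the trained weights already compensate for it at serving time. A minimal sketch of symmetric fake quantization around a matmul, assuming per-tensor 8-bit scaling (the function names here are illustrative, not the library's API):

```python
import jax.numpy as jnp
from jax import lax


def fake_quantize(x, bits=8):
    """Symmetric per-tensor fake quantization: snap values to an integer
    grid in the forward pass while staying in floating point."""
    max_abs = jnp.max(jnp.abs(x))
    scale = max_abs / (2 ** (bits - 1) - 1)
    return jnp.round(x / scale) * scale


def quantized_matmul(acts, weights, bits=8):
    # Quantize both operands before the contraction, so training sees the
    # same rounding error that integer serving hardware would introduce.
    return lax.dot(fake_quantize(acts, bits), fake_quantize(weights, bits))
```

At 8 bits the result stays close to the float matmul, but the rounding error is now part of the training signal.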
- `quantization.quantized_dot`: `lax.dot` with optionally quantized weights and activations.
- `quantization.quantized_dynamic_dot_general`: `lax.dot_general` with optionally quantized dynamic inputs.
- `quantization.quantized_sum`: Sums a tensor while quantizing intermediate accumulations.
- `quantization.dot_general_aqt`: Adds quantization to `lax.dot_general`, with an option to use integer dot.
- `flax_layers.DenseAqt`: Adds quantization to the Flax Dense module.
- `flax_layers.ConvAqt`: Adds quantization to the Flax Conv module.
- `flax_layers.EmbedAqt`: Adds quantization to the Flax Embed module.
- `flax_layers.LayerNormAqt`: Adds quantization support to the Flax LayerNorm layer.
- `flax_attention.MultiHeadDotProductAttentionAqt`: Adds quantization to Flax multi-head dot-product attention.
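The "integer dot" option mentioned above can be pictured as follows: quantize both inputs to `int8`, run the contraction in integer arithmetic with an `int32` accumulator, then rescale back to float. This is a hedged sketch of that pattern, not the library's `dot_general_aqt` signature:

```python
import jax.numpy as jnp
from jax import lax


def int8_dot(lhs, rhs):
    # Per-tensor symmetric scales mapping the max magnitude to 127.
    lhs_scale = jnp.max(jnp.abs(lhs)) / 127.0
    rhs_scale = jnp.max(jnp.abs(rhs)) / 127.0
    lhs_q = jnp.round(lhs / lhs_scale).astype(jnp.int8)
    rhs_q = jnp.round(rhs / rhs_scale).astype(jnp.int8)
    # Accumulate in int32 to avoid overflow, then dequantize the result.
    acc = lax.dot(lhs_q, rhs_q, preferred_element_type=jnp.int32)
    return acc.astype(jnp.float32) * (lhs_scale * rhs_scale)
```

Running the contraction on integer operands is what lets hardware with fast `int8` matmul units (e.g. TPUs or GPUs with integer tensor cores) accelerate serving.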