diff --git a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py
index bc68ae8db4..303625d172 100644
--- a/fbgemm_gpu/fbgemm_gpu/quantize_comm.py
+++ b/fbgemm_gpu/fbgemm_gpu/quantize_comm.py
@@ -23,10 +23,14 @@
     fp32_to_bf16_with_clamp,
     fp32_to_fp16_with_clamp,
     fp32_to_hfp8_with_clamp,
+    fp32_to_mx4,
     hfp8_to_fp32,
+    mx4_to_fp32,
+    RoundingMode,
 )
 from fbgemm_gpu.split_embedding_configs import SparseType
+
 from torch.autograd.profiler import record_function  # usort:skip
 
 from dataclasses import dataclass
 
@@ -58,6 +62,7 @@ class QuantizationContext:
     row_dim: int = ROW_DIM_DEFAULT
     row_dim_quant: int = -1
     mx_group_size: int = MX_GROUP_SIZE_DEFAULT
+    rounding_mode: RoundingMode = RoundingMode.ceil
 
 
 def _quantize_tensor(
@@ -100,15 +105,8 @@ def _quantize_tensor(
         return input_quant_all2all
     elif comm_precision == SparseType.MX4:
         mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
-        quantized_output = torch.ops.fbgemm.quantize_mx(
-            input=input_tensor,
-            scale_bits=8,
-            elem_ebits=2,
-            elem_mbits=3,
-            elem_max_norm=6.0,
-            mx_group_size=mx_group_size,
-        )
-        return quantized_output
+        rounding_mode = ctx.rounding_mode if ctx is not None else RoundingMode.ceil
+        return fp32_to_mx4(input_tensor, mx_group_size, rounding_mode=rounding_mode)
     else:
         raise ValueError(f"comm_precision={comm_precision} is not supported")
 
@@ -149,11 +147,7 @@ def _dequantize_tensor(
         return dequant_tensor.view(-1)
     elif comm_precision == SparseType.MX4:
         mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
-        dequant_tensor = torch.ops.fbgemm.dequantize_mx(
-            input=quantized_tensor,
-            mx_group_size=mx_group_size,
-        )
-        return dequant_tensor.view(-1)
+        return mx4_to_fp32(quantized_tensor, mx_group_size)
     else:
         raise ValueError(f"comm_precision={comm_precision} is not supported")
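
The new call sites quantize and dequantize MX4 through the fp32_to_mx4 / mx4_to_fp32 helpers with an explicit rounding mode instead of invoking torch.ops.fbgemm.quantize_mx directly. A minimal round-trip sketch of that path is below; it is not part of the diff. It assumes fbgemm_gpu is installed, that these helpers accept the call pattern used above, and that CPU tensors are supported by the fallback path; the group size of 32 and the tensor shape are illustrative only.

    import torch
    from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32, RoundingMode

    # Illustrative group size; quantize_comm falls back to MX_GROUP_SIZE_DEFAULT
    # when no QuantizationContext is supplied.
    group_size = 32
    x = torch.randn(4 * group_size, dtype=torch.float32)

    # Quantize to MX4 with ceil rounding, mirroring the new default
    # rounding_mode on QuantizationContext.
    q = fp32_to_mx4(x, group_size, rounding_mode=RoundingMode.ceil)

    # Dequantize back to fp32, mirroring the new _dequantize_tensor branch.
    y = mx4_to_fp32(q, group_size)
    assert y.numel() == x.numel()

Threading rounding_mode through QuantizationContext (rather than hard-coding it at the call site) lets callers pick a rounding strategy per communication pattern while keeping ceil as the default behavior.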