Add rounding_mode in quantize_comm front-end (pytorch#2859)
Summary:
Pull Request resolved: pytorch#2859

Add rounding_mode in quantize_comm and switch to calling the MX4
wrappers (fp32_to_mx4 / mx4_to_fp32) directly.

Reviewed By: jwfromm

Differential Revision: D59704964

fbshipit-source-id: fdabb08b350f29a33cabb69994e097c743a00835
spcyppt authored and facebook-github-bot committed Jul 17, 2024
1 parent 4276231 commit e08af85
Showing 1 changed file with 8 additions and 14 deletions: fbgemm_gpu/fbgemm_gpu/quantize_comm.py
@@ -23,10 +23,14 @@
fp32_to_bf16_with_clamp,
fp32_to_fp16_with_clamp,
fp32_to_hfp8_with_clamp,
+ fp32_to_mx4,
hfp8_to_fp32,
+ mx4_to_fp32,
+ RoundingMode,
)

from fbgemm_gpu.split_embedding_configs import SparseType

from torch.autograd.profiler import record_function # usort:skip
from dataclasses import dataclass

@@ -58,6 +62,7 @@ class QuantizationContext:
row_dim: int = ROW_DIM_DEFAULT
row_dim_quant: int = -1
mx_group_size: int = MX_GROUP_SIZE_DEFAULT
+ rounding_mode: RoundingMode = RoundingMode.ceil


def _quantize_tensor(
@@ -100,15 +105,8 @@ def _quantize_tensor(
return input_quant_all2all
elif comm_precision == SparseType.MX4:
mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
- quantized_output = torch.ops.fbgemm.quantize_mx(
-     input=input_tensor,
-     scale_bits=8,
-     elem_ebits=2,
-     elem_mbits=3,
-     elem_max_norm=6.0,
-     mx_group_size=mx_group_size,
- )
- return quantized_output
+ rounding_mode = ctx.rounding_mode if ctx is not None else RoundingMode.ceil
+ return fp32_to_mx4(input_tensor, mx_group_size, rounding_mode=rounding_mode)
else:
raise ValueError(f"comm_precision={comm_precision} is not supported")

@@ -149,11 +147,7 @@ def _dequantize_tensor(
return dequant_tensor.view(-1)
elif comm_precision == SparseType.MX4:
mx_group_size = ctx.mx_group_size if ctx is not None else MX_GROUP_SIZE_DEFAULT
- dequant_tensor = torch.ops.fbgemm.dequantize_mx(
-     input=quantized_tensor,
-     mx_group_size=mx_group_size,
- )
- return dequant_tensor.view(-1)
+ return mx4_to_fp32(quantized_tensor, mx_group_size)
else:
raise ValueError(f"comm_precision={comm_precision} is not supported")

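For context, a minimal round-trip sketch of the path this change enables. It assumes fbgemm_gpu is installed and that QuantizationContext, fp32_to_mx4, and mx4_to_fp32 keep the signatures shown in the diff; the tensor shape, group size, and tolerance are illustrative, not part of this commit.

import torch

from fbgemm_gpu.quantize_comm import QuantizationContext
from fbgemm_gpu.quantize_utils import fp32_to_mx4, mx4_to_fp32, RoundingMode

# Front end: the context now carries a rounding mode next to the group size.
ctx = QuantizationContext(mx_group_size=32, rounding_mode=RoundingMode.ceil)

# Back end: _quantize_tensor now delegates to the MX4 wrappers instead of
# calling torch.ops.fbgemm.quantize_mx with hard-coded element parameters.
x = torch.randn(4, 32, dtype=torch.float32)  # 128 elements, divisible by the group size
q = fp32_to_mx4(x, ctx.mx_group_size, rounding_mode=ctx.rounding_mode)
y = mx4_to_fp32(q, ctx.mx_group_size)

# MX4 is a lossy 4-bit format, so only approximate equality is expected.
print(torch.allclose(x.view(-1), y, atol=0.5))

One practical effect of routing through fp32_to_mx4 is that callers can now choose the rounding mode per QuantizationContext, whereas the old torch.ops.fbgemm.quantize_mx call pinned scale_bits, elem_ebits, elem_mbits, and elem_max_norm in this file.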