Add FP32 <-> FP16/BF16 conversion benchmark (pytorch#1166)
Summary:
Pull Request resolved: pytorch#1166

bf16_quant_pytorch is faster than bf16_quant_fbgemm

bf16_dequant_pytorch is faster than bf16_dequant_fbgemm

Reviewed By: jasonjk-park

Differential Revision: D37303359

fbshipit-source-id: 9b66737f843ad1eb1b427598a24dd87d02aec5e3
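
For context on the summary's comparison, the two BF16 paths can be reproduced standalone. This is a minimal sketch, assuming fbgemm_gpu is installed (importing it registers the torch.ops.fbgemm operators); the FBGEMM op stores BF16 bit patterns in a half-typed tensor, which is why the native PyTorch path reinterprets its result with .view(torch.half):

```python
import torch

import fbgemm_gpu  # noqa: F401  # assumed install; registers torch.ops.fbgemm.*

x = torch.rand(1024, 256).float()

# FBGEMM path: dedicated op pair. The quantized output is a half-typed
# tensor holding BF16 bit patterns.
q_fbgemm = torch.ops.fbgemm.FloatToBfloat16Quantized(x.contiguous())
d_fbgemm = torch.ops.fbgemm.Bfloat16QuantizedToFloat(q_fbgemm)

# Native PyTorch path: cast to bfloat16, then bit-reinterpret as half so the
# storage matches the FBGEMM op's output; reverse the view to dequantize.
q_pytorch = x.bfloat16().view(torch.half)
d_pytorch = q_pytorch.view(torch.bfloat16).float()

# The round trips should agree up to any rounding-mode difference (at most
# one BF16 ulp) between the op and PyTorch's round-to-nearest-even cast.
print((d_fbgemm - d_pytorch).abs().max())
```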
jianyuh authored and facebook-github-bot committed Jun 23, 2022
1 parent 38993ce commit 243010f
Showing 1 changed file with 38 additions and 5 deletions.
fbgemm_gpu/bench/quantize_ops_benchmark.py
@@ -47,11 +47,17 @@ def bench_impl(
"int2_quant": 0.0,
"fp8_143_quant": 0.0,
"fp8_152_quant": 0.0,
"fp16_quant": 0.0,
"bf16_quant_fbgemm": 0.0,
"bf16_quant_pytorch": 0.0,
"int8_dequant": 0.0,
"int4_dequant": 0.0,
"int2_dequant": 0.0,
"fp8_143_dequant": 0.0,
"fp8_152_dequant": 0.0,
"fp16_dequant": 0.0,
"bf16_dequant_fbgemm": 0.0,
"bf16_dequant_pytorch": 0.0,
}

benchmark = functools.partial(
@@ -62,6 +68,8 @@ def bench_impl(
     )
 
     input_data = torch.rand(num_rows, num_columns).float()
+    if torch.cuda.is_available():
+        input_data = input_data.cuda()
 
     quant_data_8bit = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(input_data)
     quant_data_4bit = torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf(
@@ -76,9 +84,11 @@ def bench_impl(
     quant_data_fp8_152 = torch.ops.fbgemm.FloatToHFP8Quantized(
         input_data, 5, 30, (2 - 2 ** (-2))
     )
-
-    if torch.cuda.is_available():
-        input_data = input_data.cuda()
+    quant_data_fp16 = input_data.half()
+    quant_data_bf16_fbgemm = torch.ops.fbgemm.FloatToBfloat16Quantized(
+        input_data.contiguous()
+    )
+    quant_data_bf16_pytorch = input_data.bfloat16().view(torch.half)
 
     average_time["int8_quant"], _ = benchmark(
         torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized,
@@ -88,7 +98,6 @@ def bench_impl(
         torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf,
         (input_data, 4),
     )
-
     average_time["int2_quant"], _ = benchmark(
         torch.ops.fbgemm.FloatToFusedNBitRowwiseQuantizedSBHalf,
         (input_data, 2),
@@ -101,12 +110,23 @@ def bench_impl(
         torch.ops.fbgemm.FloatToHFP8Quantized,
         (input_data, 5, 30, (2 - 2 ** (-2))),
     )
+    average_time["fp16_quant"], _ = benchmark(
+        lambda tensor: tensor.half(),
+        (input_data,),
+    )
+    average_time["bf16_quant_fbgemm"], _ = benchmark(
+        torch.ops.fbgemm.FloatToBfloat16Quantized,
+        (input_data,),
+    )
+    average_time["bf16_quant_pytorch"], _ = benchmark(
+        lambda tensor: tensor.bfloat16().view(torch.half),
+        (input_data,),
+    )
 
     average_time["int8_dequant"], _ = benchmark(
         torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat,
         (quant_data_8bit,),
     )
-
     average_time["int4_dequant"], _ = benchmark(
         torch.ops.fbgemm.FusedNBitRowwiseQuantizedSBHalfToFloat,
         (quant_data_4bit, 4),
@@ -123,6 +143,19 @@ def bench_impl(
         torch.ops.fbgemm.HFP8QuantizedToFloat,
         (quant_data_fp8_152, 5, 30),
     )
+    average_time["fp16_dequant"], _ = benchmark(
+        lambda tensor: tensor.float(),
+        (quant_data_fp16,),
+    )
+    average_time["bf16_dequant_fbgemm"], _ = benchmark(
+        torch.ops.fbgemm.Bfloat16QuantizedToFloat,
+        (quant_data_bf16_fbgemm,),
+    )
+    average_time["bf16_dequant_pytorch"], _ = benchmark(
+        lambda tensor: tensor.view(torch.bfloat16).float(),
+        (quant_data_bf16_pytorch,),
+    )
+
     logging.info(f"-------------- ncols={num_columns}, nrows={num_rows}-------------")
     for k, t_time in average_time.items():
         logging.info(f"{k} time per iter: {t_time * 1.0e6:.0f}us")
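For readers skimming the diff: benchmark is a functools.partial over a timing helper defined earlier in the file, returning an (average seconds per call, output) pair. Below is a minimal sketch of what such a helper typically does; the name, signature, and iteration count are illustrative assumptions, not the actual fbgemm_gpu implementation:

```python
import time
from typing import Any, Callable, Tuple

import torch


def benchmark_torch_function(  # illustrative stand-in for the file's helper
    f: Callable[..., Any], args: Tuple[Any, ...], iters: int = 100
) -> Tuple[float, Any]:
    """Return (average seconds per call of f(*args), last output)."""
    output = f(*args)  # warmup call: triggers lazy init and caching
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        output = f(*args)
    if torch.cuda.is_available():
        # CUDA kernels launch asynchronously; wait for completion before
        # reading the clock, or the timing only covers launch overhead.
        torch.cuda.synchronize()
    return (time.time() - start) / iters, output
```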
