Allow skipping dequantization after P2P. (pytorch#947)
Summary:
Pull Request resolved: pytorch#947

This makes it possible to profile the full INT8 scenario, in which the GEMM is also INT8.

Also, fix a few pytype bugs.

Reviewed By: jianyuh

Differential Revision: D34132882

fbshipit-source-id: c4bb477d902d97b8266bfa31705971d6edc4e849
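
A minimal sketch of the control flow this change adds (paraphrased from the diff below; the helper name maybe_dequantize is hypothetical and the final cast is a stand-in, not the real dequantization op): after the quantized all-to-one copy, the new skip_dequantization flag returns the pooled INT8 rows as-is instead of dequantizing them back to FP16 via torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloatMixedDim.

import torch

def maybe_dequantize(pooled_quantized_result: torch.Tensor,
                     skip_dequantization: bool) -> torch.Tensor:
    # New path: with skip_dequantization the fused INT8 rows are returned
    # untouched, so the downstream INT8 GEMM can be profiled end to end.
    if skip_dequantization:
        return pooled_quantized_result
    # Existing path: dequantize back to FP16. The benchmark itself calls
    # torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloatMixedDim with per-table
    # dimension metadata; a plain dtype cast stands in here to keep the sketch
    # self-contained (it does not undo the row-wise quantization).
    return pooled_quantized_result.to(torch.float16)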
caogao authored and facebook-github-bot committed Feb 23, 2022
1 parent d838f62 commit d791702
Showing 1 changed file with 22 additions and 10 deletions.
32 changes: 22 additions & 10 deletions fbgemm_gpu/bench/merge_embeddings_benchmark.py
@@ -235,16 +235,16 @@ def print_p2p_bandwidth(


def benchmark(
all_to_one_only,
all_to_one_only: bool,
num_ads: int,
embedding_dimension: int,
ads_tables: int,
iters: int = 10,
p2p_bw: bool = False,
dst_device: int = 0,
data_type: str = "FP16",
include_quantization: bool = False,
mode: str = "P2P",
skip_dequantization: bool = False,
num_of_embeddings: int = 10000,
pooling_factor: int = 25,
) -> str:
@@ -296,6 +296,7 @@ def pool_func_with_quantization(
include_quantization,
include_tbe,
fused_tbe,
skip_dequantization,
data_type,
):
if include_tbe:
@@ -340,6 +341,9 @@ def pool_func_with_quantization(
quantized, batch_indices.size(0), batch_indices.device
)

if skip_dequantization:
return pooled_quantized_result

PooledEmbeddingDequantizeDataTypeFP16 = 1
if data_type == "INT8":
return torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloatMixedDim(
@@ -367,6 +371,7 @@ def pool_func_with_quantization(
include_quantization,
include_tbe,
fused_tbe,
skip_dequantization,
data_type,
)
t = benchmark_torch_function(
@@ -376,6 +381,7 @@ def pool_func_with_quantization(
include_quantization,
include_tbe,
fused_tbe,
skip_dequantization,
data_type,
),
)
@@ -385,6 +391,7 @@ def pool_func_with_quantization(
include_quantization,
include_tbe,
fused_tbe,
skip_dequantization,
data_type,
)
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
@@ -426,22 +433,25 @@ def pool_func_with_quantization(
type=click.Choice(["P2P", "P2P_QUANT", "P2P_TBE", "P2P_FUSED_TBE"]),
default="P2P",
)
# For quantized communication, controls whether to dequantize back to FP16 at the end.
@click.option("--skip_dequantization", is_flag=True, default=False)
@click.option("--num_of_embeddings", default=100000, type=int)
@click.option("--pooling_factor", default=25, type=int)
@click.option("--sweep", is_flag=True, default=False)
def main(
all_to_one_only,
num_ads,
embedding_dimension,
ads_tables,
all_to_one_only: bool,
num_ads: int,
embedding_dimension: int,
ads_tables: int,
iters: int,
p2p_bw: bool,
dst_device: int,
data_type: str,
mode: bool,
num_of_embeddings: str,
mode: str,
skip_dequantization: bool,
num_of_embeddings: int,
pooling_factor: int,
sweep,
sweep: bool,
) -> None:
csv_header = (
"mode, data_type, num_ads, embedding_dimension, ads_tables, num_gpus, "
@@ -461,7 +471,7 @@ def handler(signum, frame):
num_ads *= 8 // num_gpu
for embedding_dimension in [16, 64, 112, 304]:
for ads_tables in [25, 50, 100, 400, 800]:
if num_ads * embedding_dimension * ads_tables > 1228800000:
if num_ads * embedding_dimension * ads_tables > 983040000:
continue # Skip tests that are too large
signal.signal(signal.SIGTERM, handler)
signal.alarm(600)
@@ -479,6 +489,7 @@ def handler(signum, frame):
dst_device,
data_type,
mode,
skip_dequantization,
num_of_embeddings,
pooling_factor,
)
@@ -501,6 +512,7 @@ def handler(signum, frame):
dst_device,
data_type,
mode,
skip_dequantization,
num_of_embeddings,
pooling_factor,
)

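A hypothetical smoke test for the new flag (not part of this commit): call the updated benchmark() directly with keyword arguments matching the signature above. This assumes fbgemm_gpu is installed on a multi-GPU host and that fbgemm_gpu/bench is on sys.path; from the command line the same path is reached with --mode P2P_QUANT --skip_dequantization. The argument values are illustrative only.

from merge_embeddings_benchmark import benchmark  # assumes fbgemm_gpu/bench is on sys.path

result_row = benchmark(
    all_to_one_only=False,
    num_ads=1024,
    embedding_dimension=64,
    ads_tables=50,
    iters=10,
    p2p_bw=False,
    dst_device=0,
    data_type="INT8",
    include_quantization=False,
    mode="P2P_QUANT",
    skip_dequantization=True,  # new flag: keep the pooled result in INT8
    num_of_embeddings=100000,
    pooling_factor=25,
)
print(result_row)  # benchmark() is annotated to return a result string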