From 99726128b8ab52e380c81f99a9292005dade5b13 Mon Sep 17 00:00:00 2001
From: Sarunya Pumma
Date: Mon, 28 Oct 2024 16:43:27 -0700
Subject: [PATCH] Improve bounds check indices benchmark (#3283)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3283

X-link: https://github.com/facebookresearch/FBGEMM/pull/380

Improve bounds_check_indices benchmark:
- Add VBE support
- Add an option to export trace

Reviewed By: spcyppt

Differential Revision: D65010206

fbshipit-source-id: 97992304d99da534b652e3dcceba5f67f2711147
---
 ...plit_table_batched_embeddings_benchmark.py | 110 ++++++++++++++----
 1 file changed, 85 insertions(+), 25 deletions(-)

diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
index 4517351dcb..585eee82db 100644
--- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
+++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -16,6 +16,7 @@
 import statistics
 import tempfile
 from contextlib import nullcontext
+from itertools import accumulate
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional

@@ -2689,9 +2690,30 @@ def pruned_array(  # noqa C901
 @click.option("--warmup-runs", default=0)
 @click.option("--num-embeddings", default=int(1e5))
 @click.option("--num-tables", default=32)
-@click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.WARNING.value)
+@click.option(
+    "--bounds-check-mode",
+    type=int,
+    default=BoundsCheckMode.WARNING.value,
+    help=f"Available modes: FATAL={BoundsCheckMode.FATAL.value}, "
+    f"WARNING={BoundsCheckMode.WARNING.value}, "
+    f"IGNORE={BoundsCheckMode.IGNORE.value}, "
+    f"NONE={BoundsCheckMode.NONE.value}",
+)
 @click.option("--requests_data_file", type=str, default=None)
 @click.option("--tables", type=str, default=None)
+@click.option(
+    "--batch-sizes",
+    type=str,
+    default="",
+    help="A list of batch sizes for the variable batch size case (VBE). "
+    "The list is comma separated, i.e., 512,128,4",
+)
+@click.option("--export-trace", is_flag=True, default=False)
+@click.option(
+    "--trace-url",
+    type=str,
+    default="bounds_check_indices_trace_{ospid}.json",
+)
 def bounds_check_indices(  # noqa C901
     bag_size: int,
     batch_size: int,
@@ -2702,43 +2724,81 @@ def bounds_check_indices(  # noqa C901
     bounds_check_mode: int,
     requests_data_file: Optional[str],
     tables: Optional[str],
+    batch_sizes: str,
+    export_trace: bool,
+    trace_url: str,
 ) -> None:
     np.random.seed(42)
     torch.manual_seed(42)

-    B = batch_size
     L = bag_size
     E = num_embeddings
     T = num_tables

-    requests = generate_requests(
-        iters,
-        B,
-        T,
-        L,
-        E,
-        requests_data_file=requests_data_file,
-        tables=tables,
-    )
+    is_vbe = len(batch_sizes) > 0
+    if is_vbe:
+        Bs = [int(B) for B in batch_sizes.split(",")]
+        assert (
+            len(Bs) == T
+        ), "The number of batch sizes must be the same as the number of tables"
+        B_offsets = torch.tensor([0] + list(accumulate(Bs)))
+        max_B = max(Bs)
+        total_B = int(B_offsets[-1].item())
+        requests = generate_requests(
+            iters,
+            total_B,
+            1,
+            L,
+            E,
+            requests_data_file=requests_data_file,
+            tables=tables,
+        )
+        B_offsets = B_offsets.to(get_device()).to(torch.int)
+    else:
+        B = batch_size
+        Bs = [B] * T
+        B_offsets = None
+        max_B = -1
+        total_B = B * T
+        requests = generate_requests(
+            iters,
+            B,
+            T,
+            L,
+            E,
+            requests_data_file=requests_data_file,
+            tables=tables,
+        )

     warning = torch.tensor([0]).long().to(get_device())
     rows_per_table = torch.tensor([E for _ in range(T)]).long().to(get_device())
-    # forward
-    time_per_iter = benchmark_requests(
-        requests,
-        lambda indices, offsets, _: torch.ops.fbgemm.bounds_check_indices(
-            rows_per_table,
-            indices.long(),
-            offsets.long(),
-            BoundsCheckMode(bounds_check_mode),
-            warning,
-        ),
-        num_warmups=warmup_runs,
-    )
+
+    def _kineto_trace_handler(p: profile) -> None:
+        p.export_chrome_trace(trace_url.format(ospid=os.getpid()))
+
+    # pyre-ignore[3]
+    def context_factory(on_trace_ready: Callable[[profile], None]):
+        return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext()
+
+    with context_factory(lambda p: _kineto_trace_handler(p)):
+        # forward
+        time_per_iter = benchmark_requests(
+            requests,
+            lambda indices, offsets, _: torch.ops.fbgemm.bounds_check_indices(
+                rows_per_table,
+                indices.long(),
+                offsets.long(),
+                BoundsCheckMode(bounds_check_mode),
+                warning,
+                B_offsets=B_offsets,
+                max_B=max_B,
+            ),
+            num_warmups=warmup_runs,
+        )

     logging.info(
-        f"Bounds Check Indices: B: {B}, "
+        f"Bounds Check Indices: Bs: {Bs}, "
         f"E: {E}, T: {T}, L: {L}, "
-        f"BW: {(8 * B * T * L + 8 * (B * T + 1)) / time_per_iter / 1.0e9: .2f} GB/s, "  # noqa: B950
+        f"BW: {(8 * total_B * L + 8 * (total_B + 1)) / time_per_iter / 1.0e9: .2f} GB/s, "  # noqa: B950
         f"T: {time_per_iter * 1.0e6:.0f}us"
     )