Skip to content

Commit

Permalink
Improve bounds check indices benchmark (pytorch#3283)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#3283

X-link: facebookresearch/FBGEMM#380

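Improve the bounds_check_indices benchmark:

- Add VBE (variable batch size) support via a new `--batch-sizes` option
- Add an option to export a profiler trace (`--export-trace`, with the output path set by `--trace-url`)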

Reviewed By: spcyppt

Differential Revision: D65010206

fbshipit-source-id: 97992304d99da534b652e3dcceba5f67f2711147
  • Loading branch information
sryap authored and facebook-github-bot committed Oct 28, 2024
1 parent cdd102f commit 9972612
Showing 1 changed file with 85 additions and 25 deletions.
110 changes: 85 additions & 25 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import statistics
import tempfile
from contextlib import nullcontext
from itertools import accumulate
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

Expand Down Expand Up @@ -2689,9 +2690,30 @@ def pruned_array( # noqa C901
@click.option("--warmup-runs", default=0)
@click.option("--num-embeddings", default=int(1e5))
@click.option("--num-tables", default=32)
@click.option("--bounds-check-mode", type=int, default=BoundsCheckMode.WARNING.value)
@click.option(
"--bounds-check-mode",
type=int,
default=BoundsCheckMode.WARNING.value,
help=f"Available modes: FATAL={BoundsCheckMode.FATAL.value}, "
f"WARNING={BoundsCheckMode.WARNING.value}, "
f"IGNORE={BoundsCheckMode.IGNORE.value}, "
f"NONE={BoundsCheckMode.NONE.value}",
)
@click.option("--requests_data_file", type=str, default=None)
@click.option("--tables", type=str, default=None)
@click.option(
"--batch-sizes",
type=str,
default="",
help="A list of batch sizes for the variable batch size case (VBE). "
"The list is comma separated, i.e., 512,128,4",
)
@click.option("--export-trace", is_flag=True, default=False)
@click.option(
"--trace-url",
type=str,
default="bounds_check_indices_trace_{ospid}.json",
)
def bounds_check_indices( # noqa C901
bag_size: int,
batch_size: int,
Expand All @@ -2702,43 +2724,81 @@ def bounds_check_indices( # noqa C901
bounds_check_mode: int,
requests_data_file: Optional[str],
tables: Optional[str],
batch_sizes: str,
export_trace: bool,
trace_url: str,
) -> None:
np.random.seed(42)
torch.manual_seed(42)
B = batch_size
L = bag_size
E = num_embeddings
T = num_tables

requests = generate_requests(
iters,
B,
T,
L,
E,
requests_data_file=requests_data_file,
tables=tables,
)
is_vbe = len(batch_sizes) > 0
if is_vbe:
Bs = [int(B) for B in batch_sizes.split(",")]
assert (
len(Bs) == T
), "The number of batch sizes must be the same as the number of tables"
B_offsets = torch.tensor([0] + list(accumulate(Bs)))
max_B = max(Bs)
total_B = int(B_offsets[-1].item())
requests = generate_requests(
iters,
total_B,
1,
L,
E,
requests_data_file=requests_data_file,
tables=tables,
)
B_offsets = B_offsets.to(get_device()).to(torch.int)
else:
B = batch_size
Bs = [B] * T
B_offsets = None
max_B = -1
total_B = B * T
requests = generate_requests(
iters,
B,
T,
L,
E,
requests_data_file=requests_data_file,
tables=tables,
)

warning = torch.tensor([0]).long().to(get_device())
rows_per_table = torch.tensor([E for _ in range(T)]).long().to(get_device())
# forward
time_per_iter = benchmark_requests(
requests,
lambda indices, offsets, _: torch.ops.fbgemm.bounds_check_indices(
rows_per_table,
indices.long(),
offsets.long(),
BoundsCheckMode(bounds_check_mode),
warning,
),
num_warmups=warmup_runs,
)

def _kineto_trace_handler(p: profile) -> None:
p.export_chrome_trace(trace_url.format(ospid=os.getpid()))

# pyre-ignore[3]
def context_factory(on_trace_ready: Callable[[profile], None]):
return profile(on_trace_ready=on_trace_ready) if export_trace else nullcontext()

with context_factory(lambda p: _kineto_trace_handler(p)):
# forward
time_per_iter = benchmark_requests(
requests,
lambda indices, offsets, _: torch.ops.fbgemm.bounds_check_indices(
rows_per_table,
indices.long(),
offsets.long(),
BoundsCheckMode(bounds_check_mode),
warning,
B_offsets=B_offsets,
max_B=max_B,
),
num_warmups=warmup_runs,
)

logging.info(
f"Bounds Check Indices: B: {B}, "
f"Bounds Check Indices: Bs: {Bs}, "
f"E: {E}, T: {T}, L: {L}, "
f"BW: {(8 * B * T * L + 8 * (B * T + 1)) / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950
f"BW: {(8 * total_B * L + 8 * (total_B + 1)) / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950
f"T: {time_per_iter * 1.0e6:.0f}us"
)

Expand Down

0 comments on commit 9972612

Please sign in to comment.