Add a benchmark for VBE (pytorch#3464)
Summary:
X-link: facebookresearch/FBGEMM#547

Pull Request resolved: pytorch#3464

This commit adds a new benchmark for VBE that evaluates the VBE kernel in isolation.

The existing "vbe" benchmark is renamed to `benchmark-tbe-input-compression`, since it actually compares the performance of compressed and uncompressed versions of the TBE implementation.
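For reference, a hypothetical invocation of the new command (assuming the script path and click option names introduced in this diff):

    python fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py vbe \
        --batch-sizes 128000,1280 --embedding-dims 1024,16 --bag-sizes 5,2 \
        --nums-embeddings 10000,1000000 --num-tables 2 --iters 100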

Reviewed By: sryap

Differential Revision: D66797784

fbshipit-source-id: a9eb13d588053156f6d4f07263d1e37d32daa9e6
shintaro-iwasaki authored and facebook-github-bot committed Dec 10, 2024
1 parent 9210796 commit bdb7095
Showing 2 changed files with 203 additions and 6 deletions.
62 changes: 58 additions & 4 deletions fbgemm_gpu/bench/bench_utils.py
@@ -389,7 +389,7 @@ def benchmark_pipelined_requests(


@dataclass
-class VBEBenchmarkOutput:
+class EvalCompressionBenchmarkOutput:
avg: float
fwd: float
bwd: float
@@ -399,14 +399,14 @@ class VBEBenchmarkOutput:
compressed_bwd: float


-def benchmark_vbe(
+def benchmark_eval_compression(
baseline_requests: List[Tuple[torch.Tensor, torch.Tensor]],
compressed_requests: List[Tuple[torch.Tensor, torch.Tensor]],
baseline_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
compressed_func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
reindex: torch.Tensor,
embedding_dim: int,
-) -> VBEBenchmarkOutput:
+) -> EvalCompressionBenchmarkOutput:
times = []
fwd_times = []
bwd_times = []
@@ -485,11 +485,65 @@ def benchmark_vbe(
reindex = statistics.median(reindex_times)
compressed_bwd = statistics.median(bwd_times)

-return VBEBenchmarkOutput(
+return EvalCompressionBenchmarkOutput(
avg, fwd, bwd, compressed_avg, compressed_fwd, reindex, compressed_bwd
)


def benchmark_vbe(
requests: List[Tuple[torch.Tensor, torch.Tensor]],
func: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> Tuple[float, float]:
"""
A benchmark function to return the average execution time in seconds of
forward and backward of VBE kernels.
Args:
requests (List[Tuple[torch.Tensor, torch.Tensor]]):
A list of requests. Each request is a tuple
of indices and offsets.
func (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]):
A function that takes in indices and offsets
and returns the output of the VBE kernel.
Returns:
Tuple[float, float]:
A tuple of average execution time in seconds of forward and
backward of VBE kernels.
"""

fwd_times = []
bwd_times = []

torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

for indices, offsets in requests:
# forward
start_event.record()
out = func(indices, offsets)
end_event.record()
torch.cuda.synchronize()
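# elapsed_time() returns milliseconds; scale by 1.0e-3 to get seconds.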
it_time = start_event.elapsed_time(end_event) * 1.0e-3
fwd_times.append(it_time)

grad = torch.rand_like(out)
# backward
start_event.record()
out.backward(grad)
end_event.record()
torch.cuda.synchronize()
it_time = start_event.elapsed_time(end_event) * 1.0e-3
bwd_times.append(it_time)

fwd_time_sec = statistics.median(fwd_times)
bwd_time_sec = statistics.median(bwd_times)

return fwd_time_sec, bwd_time_sec
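
# Illustrative usage sketch (editor's example, not part of this diff; it
# assumes a CUDA device and uses a plain nn.EmbeddingBag stand-in rather
# than an FBGEMM TBE module):
#
#     import torch
#     emb = torch.nn.EmbeddingBag(1000, 64, mode="sum").cuda()
#     requests = [
#         (
#             torch.randint(0, 1000, (512,), device="cuda"),
#             torch.arange(0, 512, 4, device="cuda"),  # 128 bags of size 4
#         )
#         for _ in range(10)
#     ]
#     fwd_s, bwd_s = benchmark_vbe(requests, func=lambda i, o: emb(i, o))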


def fill_random_scale_bias(
emb: nn.Module,
T: int,
147 changes: 145 additions & 2 deletions fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -70,6 +70,7 @@
if open_source:
# pyre-ignore[21]
from bench_utils import (
benchmark_eval_compression,
benchmark_pipelined_requests,
benchmark_requests,
benchmark_requests_refer,
@@ -79,6 +80,7 @@
)
else:
from fbgemm_gpu.bench.bench_utils import (
benchmark_eval_compression,
benchmark_pipelined_requests,
benchmark_requests,
benchmark_requests_refer,
@@ -3375,7 +3377,7 @@ def _to_offsets(lengths: torch.Tensor) -> torch.Tensor:
@click.option("--num-tables", default=20)
@click.option("--compressed-tables", default=10)
@click.option("--iters", default=100)
-def vbe(
+def benchmark_tbe_input_compression(
batch_size: int,
compressed_batch_size: int,
embedding_dim: int,
@@ -3470,7 +3472,7 @@ def vbe(
for _ in range(iters)
]

-out = benchmark_vbe(
+out = benchmark_eval_compression(
requests,
compressed_requests,
baseline_func=lambda indices, offsets: emb.forward(
@@ -3493,5 +3495,146 @@ def vbe(
)


@cli.command()
@click.option("--batch-sizes", default="128000,1280")
@click.option("--embedding-dims", default="1024,16")
@click.option("--bag-sizes", default="5,2")
@click.option("--nums-embeddings", default="10000,1000000")
@click.option("--num-tables", default=2)
@click.option("--iters", default=100)
def vbe(
batch_sizes: str,
embedding_dims: str,
bag_sizes: str,
nums_embeddings: str,
num_tables: int,
iters: int,
) -> None:
"""
A benchmark function to evaluate variable batch-size table-batched
embedding (VBE) kernels for both forward and backward. Unlike TBE,
batch sizes can be specified per table for VBE.
Args:
batch_sizes (str):
A comma separated list of batch sizes for each table.
embedding_dims (str):
A comma separated list of embedding dimensions for each table.
bag_sizes (str):
A comma separated list of bag sizes for each table.
nums_embeddings (str):
A comma separated list of number of embeddings for each table.
num_tables (int):
The number of tables.
iters (int):
The number of iterations to run the benchmark for.
"""

torch.manual_seed(42)
Bs = [int(v) for v in batch_sizes.split(",")]
Ds = [int(v) for v in embedding_dims.split(",")]
Ls = [int(v) for v in bag_sizes.split(",")]
Es = [int(v) for v in nums_embeddings.split(",")]
T = num_tables

# All these variables must have the same length.
assert T == len(Bs)
assert T == len(Ds)
assert T == len(Ls)
assert T == len(Es)

optimizer = OptimType.EXACT_ROWWISE_ADAGRAD
managed_option = (
EmbeddingLocation.DEVICE
if torch.cuda.is_available()
else EmbeddingLocation.HOST
)
pooling_mode = PoolingMode.SUM

emb = SplitTableBatchedEmbeddingBagsCodegen(
[
(
E,
D,
managed_option,
ComputeDevice.CUDA,
)
for E, D in zip(Es, Ds)
],
optimizer=optimizer,
learning_rate=0.1,
eps=0.1,
weights_precision=SparseType.FP32,
stochastic_rounding=False,
output_dtype=SparseType.FP32,
pooling_mode=pooling_mode,
bounds_check_mode=BoundsCheckMode(BoundsCheckMode.NONE.value),
).to(get_device())

lengths_list: List[torch.Tensor] = []
num_values_per_table: List[int] = []
for t, B in enumerate(Bs):
L = Ls[t]
# Draw bag lengths uniformly at random from [0, 2L); the expected
# bag length is then L.
lengths_list.append(
torch.randint(
low=0, high=2 * L, size=(B,), dtype=torch.int64, device=get_device()
)
)

# num_values is used later.
# Note: .tolist() on a 0-dim tensor returns a Python scalar.
# pyre-ignore
num_values: int = torch.sum(lengths_list[-1]).tolist()
num_values_per_table.append(num_values)

lengths = torch.cat(lengths_list, 0)

# Convert lengths into offsets.
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
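# Editor's example (hypothetical values): lengths [2, 0, 3] become
# offsets [0, 2, 2, 5], i.e. a zero-prepended cumulative sum, so bag b
# spans values[offsets[b]:offsets[b + 1]].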

# Set up values.
values_list: List[torch.Tensor] = []
for t, E in enumerate(Es):
# Assume that the index distribution is uniform over [0, E).
values_list.append(
torch.randint(
low=0,
high=E,
size=(num_values_per_table[t],),
dtype=torch.int32,
device=get_device(),
)
)
values = torch.cat(values_list, 0)

requests = [
(
values,
offsets,
)
for _ in range(iters)
]
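
# Editor's note (illustrative, not part of this diff):
# batch_size_per_feature_per_rank below is a per-feature list of per-rank
# batch sizes; with the default options it is [[128000], [1280]], i.e. a
# single rank with a different batch size for each of the two tables.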

fwd_time_sec, bwd_time_sec = benchmark_vbe(
requests,
func=lambda indices, offsets: emb.forward(
indices.long(),
offsets.long(),
batch_size_per_feature_per_rank=[[B] for B in Bs],
),
)
logging.info(
f"T: {T}, Bs: {Bs}, Ds: {Ds}, Ls: {Ls}, Es: {Es}\n"
f"fwd: {fwd_time_sec * 1.0e6:.0f}us, bwd: {bwd_time_sec * 1.0e6:.0f}us"
)
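# Editor's note: the reported fwd/bwd numbers are medians over `iters`
# iterations, converted to microseconds here.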


if __name__ == "__main__":
cli()
