From bdb709541107d5adc733d0b7f35e7ebf24ffb1c4 Mon Sep 17 00:00:00 2001
From: Shintaro Iwasaki
Date: Mon, 9 Dec 2024 17:10:02 -0800
Subject: [PATCH] Add a benchmark for VBE (#3464)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/547

Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/3464

This commit adds a new benchmark for VBE, which evaluates the VBE kernel in
isolation. The existing "vbe" benchmark is renamed to
`benchmark-tbe-input-compression`, since what it actually evaluates is the
performance of compressed and uncompressed TBE implementations.

Reviewed By: sryap

Differential Revision: D66797784

fbshipit-source-id: a9eb13d588053156f6d4f07263d1e37d32daa9e6
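
Example usage (the flag values shown are simply the defaults defined in the
code below; assumes the script is run from fbgemm_gpu/bench):

  python split_table_batched_embeddings_benchmark.py vbe \
      --batch-sizes 128000,1280 --embedding-dims 1024,16 \
      --bag-sizes 5,2 --nums-embeddings 10000,1000000 \
      --num-tables 2 --iters 100

The renamed benchmark is invoked the same way via the
`benchmark-tbe-input-compression` subcommand.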
+ """ + + fwd_times = [] + bwd_times = [] + + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + for indices, offsets in requests: + # forward + start_event.record() + out = func(indices, offsets) + end_event.record() + torch.cuda.synchronize() + it_time = start_event.elapsed_time(end_event) * 1.0e-3 + fwd_times.append(it_time) + + grad = torch.rand_like(out) + start_event.record() + # backward + out.backward(grad) + end_event.record() + torch.cuda.synchronize() + it_time = start_event.elapsed_time(end_event) * 1.0e-3 + bwd_times.append(it_time) + + fwd_time_sec = statistics.median(fwd_times) + bwd_time_sec = statistics.median(bwd_times) + + return fwd_time_sec, bwd_time_sec + + def fill_random_scale_bias( emb: nn.Module, T: int, diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py index b67837dc24..f439ed6780 100644 --- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py +++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py @@ -70,6 +70,7 @@ if open_source: # pyre-ignore[21] from bench_utils import ( + benchmark_eval_compression, benchmark_pipelined_requests, benchmark_requests, benchmark_requests_refer, @@ -79,6 +80,7 @@ ) else: from fbgemm_gpu.bench.bench_utils import ( + benchmark_eval_compression, benchmark_pipelined_requests, benchmark_requests, benchmark_requests_refer, @@ -3375,7 +3377,7 @@ def _to_offsets(lengths: torch.Tensor) -> torch.Tensor: @click.option("--num-tables", default=20) @click.option("--compressed-tables", default=10) @click.option("--iters", default=100) -def vbe( +def benchmark_tbe_input_compression( batch_size: int, compressed_batch_size: int, embedding_dim: int, @@ -3470,7 +3472,7 @@ def vbe( for _ in range(iters) ] - out = benchmark_vbe( + out = benchmark_eval_compression( requests, compressed_requests, baseline_func=lambda indices, offsets: emb.forward( @@ -3493,5 +3495,146 @@ def vbe( ) +@cli.command() +@click.option("--batch-sizes", default="128000,1280") +@click.option("--embedding-dims", default="1024,16") +@click.option("--bag-sizes", default="5,2") +@click.option("--nums-embeddings", default="10000,1000000") +@click.option("--num-tables", default=2) +@click.option("--iters", default=100) +def vbe( + batch_sizes: str, + embedding_dims: str, + bag_sizes: str, + nums_embeddings: str, + num_tables: int, + iters: int, +) -> None: + """ + A benchmark function to evaluate variable batch-size table-batched + embedding (VBE) kernels for both forward and backward. Unlike TBE, + batch sizes can be specified per table for VBE. + + Args: + batch_sizes (str): + A comma separated list of batch sizes for each table. + + embedding_dims (str): + A comma separated list of embedding dimensions for each table. + + bag_sizes (str): + A comma separated list of bag sizes for each table. + + num_embeddings (str): + A comma separated list of number of embeddings for each table. + + num_tables (int): + The number of tables. + + iters (int): + The number of iterations to run the benchmark for. + """ + + torch.manual_seed(42) + Bs = [int(v) for v in batch_sizes.split(",")] + Ds = [int(v) for v in embedding_dims.split(",")] + Ls = [int(v) for v in bag_sizes.split(",")] + Es = [int(v) for v in nums_embeddings.split(",")] + T = num_tables + + # All these variables must have the same length. 
diff --git a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
index b67837dc24..f439ed6780 100644
--- a/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
+++ b/fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py
@@ -70,6 +70,7 @@
 if open_source:
     # pyre-ignore[21]
     from bench_utils import (
+        benchmark_eval_compression,
         benchmark_pipelined_requests,
         benchmark_requests,
         benchmark_requests_refer,
@@ -79,6 +80,7 @@
     )
 else:
     from fbgemm_gpu.bench.bench_utils import (
+        benchmark_eval_compression,
         benchmark_pipelined_requests,
         benchmark_requests,
         benchmark_requests_refer,
@@ -3375,7 +3377,7 @@ def _to_offsets(lengths: torch.Tensor) -> torch.Tensor:
 @click.option("--num-tables", default=20)
 @click.option("--compressed-tables", default=10)
 @click.option("--iters", default=100)
-def vbe(
+def benchmark_tbe_input_compression(
     batch_size: int,
     compressed_batch_size: int,
     embedding_dim: int,
@@ -3470,7 +3472,7 @@
         for _ in range(iters)
     ]

-    out = benchmark_vbe(
+    out = benchmark_eval_compression(
         requests,
         compressed_requests,
         baseline_func=lambda indices, offsets: emb.forward(
@@ -3493,5 +3495,146 @@
     )


+@cli.command()
+@click.option("--batch-sizes", default="128000,1280")
+@click.option("--embedding-dims", default="1024,16")
+@click.option("--bag-sizes", default="5,2")
+@click.option("--nums-embeddings", default="10000,1000000")
+@click.option("--num-tables", default=2)
+@click.option("--iters", default=100)
+def vbe(
+    batch_sizes: str,
+    embedding_dims: str,
+    bag_sizes: str,
+    nums_embeddings: str,
+    num_tables: int,
+    iters: int,
+) -> None:
+    """
+    A benchmark function to evaluate variable batch-size table-batched
+    embedding (VBE) kernels for both the forward and backward passes.
+    Unlike TBE, VBE allows the batch size to be specified per table.
+
+    Args:
+        batch_sizes (str):
+            A comma-separated list of batch sizes, one per table.
+
+        embedding_dims (str):
+            A comma-separated list of embedding dimensions, one per table.
+
+        bag_sizes (str):
+            A comma-separated list of bag sizes, one per table.
+
+        nums_embeddings (str):
+            A comma-separated list of the numbers of embeddings, one per
+            table.
+
+        num_tables (int):
+            The number of tables.
+
+        iters (int):
+            The number of iterations to run the benchmark for.
+    """
+
+    torch.manual_seed(42)
+    Bs = [int(v) for v in batch_sizes.split(",")]
+    Ds = [int(v) for v in embedding_dims.split(",")]
+    Ls = [int(v) for v in bag_sizes.split(",")]
+    Es = [int(v) for v in nums_embeddings.split(",")]
+    T = num_tables
+
+    # All of these lists must have one entry per table.
+    assert T == len(Bs)
+    assert T == len(Ds)
+    assert T == len(Ls)
+    assert T == len(Es)
+
+    optimizer = OptimType.EXACT_ROWWISE_ADAGRAD
+    managed_option = (
+        EmbeddingLocation.DEVICE
+        if torch.cuda.is_available()
+        else EmbeddingLocation.HOST
+    )
+    pooling_mode = PoolingMode.SUM
+
+    emb = SplitTableBatchedEmbeddingBagsCodegen(
+        [
+            (
+                E,
+                D,
+                managed_option,
+                ComputeDevice.CUDA,
+            )
+            for E, D in zip(Es, Ds)
+        ],
+        optimizer=optimizer,
+        learning_rate=0.1,
+        eps=0.1,
+        weights_precision=SparseType.FP32,
+        stochastic_rounding=False,
+        output_dtype=SparseType.FP32,
+        pooling_mode=pooling_mode,
+        bounds_check_mode=BoundsCheckMode(BoundsCheckMode.NONE.value),
+    ).to(get_device())
+
+    lengths_list: List[torch.Tensor] = []
+    num_values_per_table: List[int] = []
+    for t, B in enumerate(Bs):
+        L = Ls[t]
+        # Draw bag lengths uniformly at random from [0, 2L), so the
+        # average bag length is approximately L.
+        lengths_list.append(
+            torch.randint(
+                low=0, high=2 * L, size=(B,), dtype=torch.int64, device=get_device()
+            )
+        )
+
+        # num_values is used later.
+        # Note: sum().tolist() returns a scalar value for a 0-dim tensor.
+        # pyre-ignore
+        num_values: int = torch.sum(lengths_list[-1]).tolist()
+        num_values_per_table.append(num_values)
+
+    lengths = torch.cat(lengths_list, 0)
+
+    # Convert lengths into offsets.
+    offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
+
+    # Set up values (embedding indices).
+    values_list: List[torch.Tensor] = []
+    for t, E in enumerate(Es):
+        # Assume the indices are uniformly distributed over [0, E).
+        values_list.append(
+            torch.randint(
+                low=0,
+                high=E,
+                size=(num_values_per_table[t],),
+                dtype=torch.int32,
+                device=get_device(),
+            )
+        )
+    values = torch.cat(values_list, 0)
+
+    requests = [
+        (
+            values,
+            offsets,
+        )
+        for _ in range(iters)
+    ]
+
+    fwd_time_sec, bwd_time_sec = benchmark_vbe(
+        requests,
+        func=lambda indices, offsets: emb.forward(
+            indices.long(),
+            offsets.long(),
+            batch_size_per_feature_per_rank=[[B] for B in Bs],
+        ),
+    )
+    logging.info(
+        f"T: {T}, Bs: {Bs}, Ds: {Ds}, Ls: {Ls}, Es: {Es}\n"
+        f"fwd: {fwd_time_sec * 1.0e6:.0f}us, bwd: {bwd_time_sec * 1.0e6:.0f}us"
+    )
+
+
 if __name__ == "__main__":
     cli()
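
The input layout the new `vbe` command constructs can be summarized as
follows: per-table length vectors are concatenated into one lengths tensor
and converted to a single offsets tensor, while
`batch_size_per_feature_per_rank` tells the kernel where each table's batch
begins. A toy sketch with made-up sizes (two tables on a single rank; not
part of the patch):

    import torch

    # Two tables with per-table batch sizes 2 and 3, i.e.
    # batch_size_per_feature_per_rank=[[2], [3]].
    lengths = torch.tensor([1, 3,        # table 0: bags of 1 and 3 indices
                            2, 0, 4])    # table 1: bags of 2, 0, and 4
    # torch.ops.fbgemm.asynchronous_complete_cumsum(lengths) produces a
    # "complete" cumulative sum (a leading 0 plus a running total), so bag
    # i spans indices[offsets[i]:offsets[i + 1]]. Reproduced here with
    # plain torch ops:
    offsets = torch.cat(
        [torch.zeros(1, dtype=lengths.dtype), torch.cumsum(lengths, 0)]
    )
    print(offsets.tolist())  # [0, 1, 4, 6, 6, 10]
    # The flattened indices tensor then holds sum(lengths) == 10 values,
    # table 0's indices first, followed by table 1's.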