Refactor TBE generate requests return type (pytorch#2411)
Summary:
X-link: facebookresearch/param#95

Pull Request resolved: pytorch#2411

This diff refactors the return type of `generate_requests` (the TBE
random input generator).  Prior to this diff, `generate_requests`
returned a list of (indices, offsets, optional per-sample-weights)
tuples (`List[Tuple[Tensor, Tensor, Optional[Tensor]]]`).  Adding
another return value to the tuple would require updating every
request-tuple unpacking site and its typing to satisfy Pyre.  This
diff therefore adds `TBERequest`, a wrapper around the return values
of `generate_requests` that lets the user access each return value
individually or as a tuple of arbitrary length.

Reviewed By: q10

Differential Revision: D54710610

fbshipit-source-id: fec2e926ff186ea233d4dea109208a68339e3388
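
For context on the wrapper the summary describes: only `unpack_3()` and the three values it returns are visible in this commit's diff. Below is a minimal sketch of what such a wrapper could look like; the field names, the dataclass form, and the `unpack_2()` variant are illustrative assumptions, not the actual fbgemm_gpu definition.

```python
# Hedged sketch of a TBERequest-style wrapper, based only on what this
# commit's diff shows. Field names and unpack_2() are assumptions; the
# real fbgemm_gpu class may define more fields and unpack variants.
from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class TBERequest:
    """Wraps the per-request values produced by `generate_requests` so that
    new fields can be added without touching every tuple-unpacking call site."""

    indices: torch.Tensor
    offsets: torch.Tensor
    per_sample_weights: Optional[torch.Tensor] = None

    def unpack_2(self) -> Tuple[torch.Tensor, torch.Tensor]:
        # For callers that only need indices and offsets (assumed variant).
        return (self.indices, self.offsets)

    def unpack_3(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        # The three-value form used by the benchmark helpers updated in this diff.
        return (self.indices, self.offsets, self.per_sample_weights)
```

Call sites then read `indices, offsets, weights = req.unpack_3()` instead of unpacking a bare tuple, as the bench_utils.py changes below show.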
sryap authored and facebook-github-bot committed Mar 11, 2024
1 parent dc3a268 commit 415b72e
Showing 8 changed files with 163 additions and 82 deletions.
20 changes: 11 additions & 9 deletions fbgemm_gpu/bench/bench_utils.py
@@ -23,6 +23,7 @@
     generate_requests,  # noqa: F401
     get_device,  # noqa: F401
     round_up,  # noqa: F401
+    TBERequest,
 )
 from torch import nn
 
@@ -142,7 +143,7 @@ def forward(idx: int) -> None:
 
 
 def benchmark_requests(
-    requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[torch.Tensor]]],
+    requests: List[TBERequest],
     func: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], torch.Tensor],
     flush_gpu_cache_size_mb: int = 0,
     check_median: bool = False,
@@ -163,7 +164,7 @@ def benchmark_requests(
     num_warmups = num_warmups + 1 if num_warmups >= 0 else 1
 
     if num_warmups > 0:
-        indices, offsets, weights = requests[0]
+        indices, offsets, weights = requests[0].unpack_3()
         for _ in range(num_warmups):
             out = func(indices, offsets, weights)
             if bwd_only:
@@ -176,7 +177,8 @@ def benchmark_requests(
         torch.cuda.synchronize()
         start_event = torch.cuda.Event(enable_timing=True)
         end_event = torch.cuda.Event(enable_timing=True)
-    for it, (indices, offsets, weights) in enumerate(requests):
+    for it, req in enumerate(requests):
+        indices, offsets, weights = req.unpack_3()
         if bwd_only:
             # Run forward before profiling if does backward only
             out = func(indices, offsets, weights)
@@ -217,7 +219,7 @@
 
 
 def benchmark_requests_refer(
-    requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[torch.Tensor]]],
+    requests: List[TBERequest],
     T: int,
     B: int,
     L: int,
@@ -242,7 +244,8 @@ def benchmark_requests_refer(
     torch.cuda.synchronize()
     start_event = torch.cuda.Event(enable_timing=True)
     end_event = torch.cuda.Event(enable_timing=True)
-    for indices, _, weights in requests:
+    for req in requests:
+        indices, _, weights = req.unpack_3()
         indices_list = indices.view(T, B, L).split(1)
 
         if weighted:
@@ -308,7 +311,7 @@ def benchmark_requests_refer(
 
 
 def benchmark_pipelined_requests(
-    requests: List[Tuple[torch.IntTensor, torch.IntTensor, Optional[torch.Tensor]]],
+    requests: List[TBERequest],
     func1: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], None],
     func2: Callable[[torch.Tensor, torch.Tensor, Optional[torch.Tensor]], None],
     flush_gpu_cache_size_mb: int = 0,
@@ -323,9 +326,8 @@ def benchmark_pipelined_requests(
         (torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True))
         for _ in requests
     ]
-    for (indices, offsets, indices_weights), start_event, end_event in zip(
-        requests, start_events, end_events
-    ):
+    for req, start_event, end_event in zip(requests, start_events, end_events):
+        indices, offsets, indices_weights = req.unpack_3()
         if flush_gpu_cache_size_mb:
             _ = torch.rand(
                 flush_gpu_cache_size_mb * 1024 * 1024 // 4,