Add support for multi-thread/stream execution in feed_lower_benchmark (pytorch#1474)

Summary: Pull Request resolved: pytorch#1474

Reviewed By: jianyuh

Differential Revision: D41427487

fbshipit-source-id: fbc4336986d7c0ec0f542fb48d0b08729c0b3c23
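
For context, a minimal usage sketch of the new parameters (hypothetical tensors and op; the import path, the iters keyword, and the chosen values are assumptions for illustration, not part of this commit):

    import torch
    from bench_utils import benchmark_torch_function  # import path assumed

    x = torch.randn(4096, 4096, device="cuda")

    # Run the op from 4 Python threads, each on its own CUDA stream.
    # copy_f_for_multi_thread_test=True deep-copies the callable per thread
    # so that stateful modules are not shared across threads.
    elapsed_s, out = benchmark_torch_function(
        torch.matmul,
        (x, x),
        iters=100,
        num_threads=4,
        copy_f_for_multi_thread_test=True,
        name="matmul_multi_stream",
    )
    print(f"effective time per call: {elapsed_s * 1.0e6:.1f} us")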
Jiecao Yu authored and facebook-github-bot committed Nov 29, 2022
1 parent e749c95 commit 13ea863
Showing 1 changed file with 65 additions and 1 deletion.
66 changes: 65 additions & 1 deletion fbgemm_gpu/bench/bench_utils.py
@@ -3,8 +3,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging
import statistics
import threading
import time
from typing import Callable, List, Optional, Tuple

@@ -29,14 +31,17 @@ def benchmark_torch_function(
num_warmups: int = 2,
device: str = "cuda",
name: str = "",
num_threads: int = 1,
copy_f_for_multi_thread_test: bool = False,
) -> Tuple[float, torch.Tensor]:
logging.info(f"Start to benchmark {name}...")
if device != "" and device != "cuda":
torch.cuda.set_device(device)
for _ in range(num_warmups):
output = f(*args)

if torch.cuda.is_available():
assert num_threads > 0
if torch.cuda.is_available() and (num_threads == 1):
cache = torch.empty(
int(flush_gpu_cache_size_mb * 1024 * 1024 // 4),
dtype=torch.float,
@@ -58,6 +63,65 @@
[s.elapsed_time(e) for s, e in zip(start_event, end_event)]
)
elapsed_time = torch.mean(times).item() * 1.0e-3
elif torch.cuda.is_available() and (num_threads > 1):
cache = torch.empty(
int(flush_gpu_cache_size_mb * 1024 * 1024 // 4),
dtype=torch.float,
device=device,
)
duration_ms_list: List[float] = []

f_list = [f]
        # deep-copy f for each extra thread if requested, so threads do not share state
for _ in range(num_threads - 1):
f_list.append(copy.deepcopy(f) if copy_f_for_multi_thread_test else f)

@torch.inference_mode()
# pyre-ignore[53]
def forward(idx: int) -> None:
stream = torch.cuda.Stream()
f_temp = f_list[idx]
start_event = [
torch.cuda.Event(enable_timing=True)
for i in range(iters // num_threads)
]
end_event = [
torch.cuda.Event(enable_timing=True)
for i in range(iters // num_threads)
]
torch.cuda.synchronize(device)
with torch.cuda.stream(stream):
for i in range(iters // num_threads):
# flush the cache
if flush_gpu_cache_size_mb:
cache.zero_()
start_event[i].record()
with torch.cuda.nvtx.range(f"RunCudaModule_{name}"):
_ = f_temp(*args)
end_event[i].record()
torch.cuda.synchronize(device)
times = torch.tensor(
[s.elapsed_time(e) for s, e in zip(start_event, end_event)]
)
duration_ms = torch.sum(times).item()
duration_ms_list.append(duration_ms)

threads = [
threading.Thread(target=forward, args=(idx,)) for idx in range(num_threads)
]
for thread in threads:
thread.start()
for thread in threads:
thread.join()
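        # sum(duration_ms_list) is the total GPU time across all threads;
        # dividing by num_threads and iters yields the effective per-call time
        # in seconds, assuming the per-thread streams overlap fully.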
elapsed_time = sum(duration_ms_list) * 1.0e-3 / num_threads / iters

torch.cuda.synchronize(device)
        if copy_f_for_multi_thread_test:
            # drop the per-thread deep copies of f, then release the cached HBM
            # (deleting by index while the list shrinks would raise IndexError)
            del f_list[1:]
            torch.cuda.empty_cache()

else:
start_time = time.time()
for _ in range(iters):
… (remaining unchanged lines of the file omitted)
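
The multi-thread path added above uses a standard pattern: each Python thread creates a private torch.cuda.Stream so that kernels launched from different threads can overlap on the device, and per-iteration GPU time is measured with CUDA events recorded on that stream. A standalone sketch of the pattern (assumes a CUDA-capable device; torch.mm stands in for the benchmarked callable):

    import threading
    import torch

    def run_in_stream(n_iters: int) -> None:
        x = torch.randn(512, 512, device="cuda")
        stream = torch.cuda.Stream()     # private stream for this thread
        with torch.cuda.stream(stream):  # kernels enqueue to this stream
            for _ in range(n_iters):
                torch.mm(x, x)
        torch.cuda.synchronize()         # wait for this thread's kernels

    threads = [threading.Thread(target=run_in_stream, args=(10,)) for _ in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()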
