[RELAND] Add `__torch_function__` benchmarks (pytorch#36138)

Summary: Re-land of pytorch#35530 and pytorch#34645

Pull Request resolved: pytorch#36138
Differential Revision: D20893770
Pulled By: ezyang
fbshipit-source-id: 75ab688a086f5fb87412a853df5246c0c39704ca
1 parent: 3aeb2b1 · commit: 7c825ba
7 changed files with 187 additions and 0 deletions.
**`benchmarks/overrides_benchmark/README.md`** (new file)
# `__torch_function__` micro-benchmarks

This benchmark suite provides a systematic way to measure the performance of `__torch_function__` overhead.

## Getting started
### Initial Setup
Install `py-spy` with:

```bash
pip install py-spy
```

Note that more extensive documentation on using `py-spy` is available in `CONTRIBUTING.md`.

### Running the benchmark
Run one of the following commands in the terminal, with the working directory being `${PYTORCH_CLONE_DIR}/benchmarks/overrides_benchmark`:

```bash
# Benchmark all the cases
python bench.py

# Flame graph pertaining to each case
py-spy record -o tensor.svg --native -- python pyspybench.py Tensor
py-spy record -o subtensor.svg --native -- python pyspybench.py SubTensor
py-spy record -o overridden.svg --native -- python pyspybench.py WithTorchFunction
py-spy record -o suboverridden.svg --native -- python pyspybench.py SubWithTorchFunction
```

Here is a brief overview of what the results should look like, if run correctly:

* Overhead for `torch` functions when run on `torch.Tensor` objects is on the order of 2 μs.
* `__torch_function__` should add zero overhead for `torch.Tensor` inputs, a small overhead for subclasses of `torch.Tensor`, and a couple of microseconds for `Tensor`-likes with `__torch_function__`.
* Changing the dispatching mechanism may result in changes that are on the order of 100 ns. These are hard to detect due to noise, but important.

## Reporting benchmark results
When modifying any of the machinery around `__torch_function__`, run the benchmark for both the feature branch and the point where it diverges from `master`. For each of these:

* Run `bench.py`, and include the output in your result.
* For each case where `bench.py` shows a regression, run the commands described above, prefixing the output SVG filename (the input to the `-o` switch) with `base-` or `branch-` depending on the commit you are running the benchmark on.
* For each SVG, open it in the browser, take a screenshot, and include it in your result. Also include a ZIP file containing all the SVGs thus produced.
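The base-vs-branch recording workflow above can be scripted. The sketch below is a hypothetical helper, not part of this commit: the `svg_name` and `record_all` names are invented for illustration, and the generated filenames are simply lowercased case names rather than the exact names (`tensor.svg`, `overridden.svg`, …) suggested earlier. It assumes `py-spy` is installed and `pyspybench.py` is in the working directory.

```shell
#!/usr/bin/env bash
set -euo pipefail

svg_name() {
  # Map a label and a benchmark case to an SVG filename,
  # e.g. `svg_name base Tensor` yields `base-tensor.svg`.
  local label=$1 bench_case=$2
  printf '%s-%s.svg' "$label" "$(echo "$bench_case" | tr '[:upper:]' '[:lower:]')"
}

record_all() {
  # Record one flame graph per benchmark case, prefixed with the label.
  local label=$1
  for bench_case in Tensor SubTensor WithTorchFunction SubWithTorchFunction; do
    py-spy record -o "$(svg_name "$label" "$bench_case")" --native -- \
      python pyspybench.py "$bench_case"
  done
}

# Usage: check out the merge-base commit and run `record_all base`,
# then check out the feature branch and run `record_all branch`.
```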
**`benchmarks/overrides_benchmark/bench.py`** (new file)

```python
import argparse
import time

import torch

from common import SubTensor, WithTorchFunction, SubWithTorchFunction

NUM_REPEATS = 1000
NUM_REPEAT_OF_REPEATS = 1000


def bench(t1, t2):
    bench_times = []
    for _ in range(NUM_REPEAT_OF_REPEATS):
        time_start = time.time()
        for _ in range(NUM_REPEATS):
            torch.add(t1, t2)
        bench_times.append(time.time() - time_start)

    # Convert the total time for NUM_REPEATS calls into a per-call time.
    bench_time = float(torch.min(torch.Tensor(bench_times))) / NUM_REPEATS
    bench_std = float(torch.std(torch.Tensor(bench_times))) / NUM_REPEATS

    return bench_time, bench_std


def main():
    global NUM_REPEATS
    global NUM_REPEAT_OF_REPEATS

    parser = argparse.ArgumentParser(
        description="Run the __torch_function__ benchmarks."
    )
    parser.add_argument(
        "--nreps",
        "-n",
        type=int,
        default=NUM_REPEATS,
        help="The number of repeats for one measurement.",
    )
    parser.add_argument(
        "--nrepreps",
        "-m",
        type=int,
        default=NUM_REPEAT_OF_REPEATS,
        help="The number of measurements.",
    )
    args = parser.parse_args()

    NUM_REPEATS = args.nreps
    NUM_REPEAT_OF_REPEATS = args.nrepreps

    types = torch.Tensor, SubTensor, WithTorchFunction, SubWithTorchFunction

    for t in types:
        tensor_1 = t(1)
        tensor_2 = t(2)

        bench_min, bench_std = bench(tensor_1, tensor_2)
        print(
            "Type {0} had a minimum time of {1} us"
            " and a standard deviation of {2} us.".format(
                t.__name__, (10 ** 6 * bench_min), (10 ** 6) * bench_std
            )
        )


if __name__ == "__main__":
    main()
```
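The timing strategy in `bench` — run the operation in a tight inner loop, time many such batches, and keep the minimum batch time to suppress scheduler and background-load noise — can be sketched without torch. The function below, `min_batch_time`, is a hypothetical illustration of the technique and is not part of the commit:

```python
import time


def min_batch_time(fn, num_repeats=1000, num_batches=100):
    """Return the best-case per-call time of fn, in seconds.

    Times num_batches batches of num_repeats calls each and keeps the
    minimum batch time: the minimum is far less sensitive to transient
    background load than the mean, so it better approximates the true
    cost of a single call.
    """
    batch_times = []
    for _ in range(num_batches):
        start = time.perf_counter()
        for _ in range(num_repeats):
            fn()
        batch_times.append(time.perf_counter() - start)
    return min(batch_times) / num_repeats


# Example: measure a trivial pure-Python operation.
per_call = min_batch_time(lambda: sum(range(10)), num_repeats=100, num_batches=10)
assert per_call > 0.0
```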
**`benchmarks/overrides_benchmark/common.py`** (new file)

```python
import torch

NUM_REPEATS = 1000
NUM_REPEAT_OF_REPEATS = 1000


class SubTensor(torch.Tensor):
    pass


class WithTorchFunction:
    def __init__(self, data, requires_grad=False):
        if isinstance(data, torch.Tensor):
            self._tensor = data
            return

        self._tensor = torch.tensor(data, requires_grad=requires_grad)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        return WithTorchFunction(args[0]._tensor + args[1]._tensor)


class SubWithTorchFunction(torch.Tensor):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        return args[0] + args[1]
```
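The classes above rely on PyTorch's `__torch_function__` dispatch: when any argument to a `torch` function defines `__torch_function__`, the call is handed to that method instead of the default implementation. A simplified pure-Python model of that protocol can make the mechanism concrete. Note that `dispatched_add` and `MyNumber` below are hypothetical illustrations, not PyTorch's actual dispatcher:

```python
def dispatched_add(a, b):
    """Toy model of __torch_function__-style dispatch for one function.

    If either argument's type defines __torch_function__, delegate to the
    first one found, passing the dispatched function, the involved types,
    and the arguments. Otherwise fall back to the default implementation.
    """
    for arg in (a, b):
        handler = getattr(type(arg), "__torch_function__", None)
        if handler is not None:
            return handler(arg, dispatched_add, (type(a), type(b)), (a, b))
    return a + b  # default implementation


class MyNumber:
    def __init__(self, value):
        self.value = value

    def __torch_function__(self, func, types, args=(), kwargs=None):
        # Unwrap the arguments, apply the default operation, and re-wrap.
        return MyNumber(args[0].value + args[1].value)


print(dispatched_add(1, 2))                            # fallback path: 3
print(dispatched_add(MyNumber(1), MyNumber(2)).value)  # overridden path: 3
```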
**`benchmarks/overrides_benchmark/pyspybench.py`** (new file)

```python
import argparse

import torch

from common import SubTensor, WithTorchFunction, SubWithTorchFunction  # noqa: F401

Tensor = torch.Tensor

NUM_REPEATS = 1000000

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run torch.add for a given class a given number of times."
    )
    parser.add_argument(
        "tensor_class", metavar="TensorClass", type=str, help="The class to benchmark."
    )
    parser.add_argument(
        "--nreps", "-n", type=int, default=NUM_REPEATS, help="The number of repeats."
    )
    args = parser.parse_args()

    TensorClass = globals()[args.tensor_class]
    NUM_REPEATS = args.nreps

    t1 = TensorClass(1)
    t2 = TensorClass(2)

    for _ in range(NUM_REPEATS):
        torch.add(t1, t2)
```