[RELAND] Add __torch_function__ benchmarks (pytorch#36138)
Summary:
Re-land of pytorch#35530 and pytorch#34645
Pull Request resolved: pytorch#36138

Differential Revision: D20893770

Pulled By: ezyang

fbshipit-source-id: 75ab688a086f5fb87412a853df5246c0c39704ca
hameerabbasi authored and facebook-github-bot committed Apr 10, 2020
1 parent 3aeb2b1 commit 7c825ba
Showing 7 changed files with 187 additions and 0 deletions.
14 changes: 14 additions & 0 deletions .jenkins/pytorch/test.sh
@@ -218,6 +218,18 @@ test_custom_script_ops() {
  fi
}

test_torch_function_benchmark() {
  echo "Testing __torch_function__ benchmarks"
  pushd benchmarks/overrides_benchmark
  python bench.py -n 1 -m 2
  python pyspybench.py Tensor -n 1
  python pyspybench.py SubTensor -n 1
  python pyspybench.py WithTorchFunction -n 1
  python pyspybench.py SubWithTorchFunction -n 1
  popd
  assert_git_not_dirty
}

test_xla() {
export XLA_USE_XRT=1 XRT_DEVICE_MAP="CPU:0;/job:localservice/replica:0/task:0/device:XLA_CPU:0"
# Issue #30717: randomize the port of XLA/gRPC workers is listening on to reduce flaky tests.
@@ -286,6 +298,7 @@ elif [[ "${BUILD_ENVIRONMENT}" == *-test2 || "${JOB_BASE_NAME}" == *-test2 ]]; then
  test_aten
  test_libtorch
  test_custom_script_ops
  test_torch_function_benchmark
elif [[ "${BUILD_ENVIRONMENT}" == *-bazel-* ]]; then
  test_bazel
else
@@ -295,4 +308,5 @@ else
  test_aten
  test_libtorch
  test_custom_script_ops
  test_torch_function_benchmark
fi
40 changes: 40 additions & 0 deletions benchmarks/overrides_benchmark/README.md
@@ -0,0 +1,40 @@
# `__torch_function__` micro-benchmarks

This benchmark suite provides a systematic way to measure the overhead of `__torch_function__`.

## Getting started
### Initial Setup
Install `py-spy` with:

```bash
pip install py-spy
```

Note that more extensive documentation on using `py-spy` is available in `CONTRIBUTING.md`.

### Running the benchmark
Run one of the following commands in the terminal, with the working directory set to `${PYTORCH_CLONE_DIR}/benchmarks/overrides_benchmark`:

```bash
# Benchmark all the cases
python bench.py

# Generate a flame graph for each case.
py-spy record -o tensor.svg --native -- python pyspybench.py Tensor
py-spy record -o subtensor.svg --native -- python pyspybench.py SubTensor
py-spy record -o overridden.svg --native -- python pyspybench.py WithTorchFunction
py-spy record -o suboverridden.svg --native -- python pyspybench.py SubWithTorchFunction
```

Here is a brief overview of what the results should look like if the benchmarks are run correctly:

* Overhead for `torch` functions when run on `torch.Tensor` objects is on the order of 2 μs.
* `__torch_function__` should add zero overhead for `torch.Tensor` inputs, a small overhead for subclasses of `torch.Tensor`, and a couple of microseconds for `Tensor`-likes with `__torch_function__` (a minimal sketch of such a class follows this list).
* Changes to the dispatch mechanism may shift these numbers by amounts on the order of 100 ns, which are hard to detect above the noise but nonetheless important.
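
For concreteness, here is a minimal sketch of the last case: a `Tensor`-like that is not a `torch.Tensor` subclass but still participates in dispatch via `__torch_function__`. The actual benchmark classes live in `common.py`; the class below is illustrative only.

```python
import torch

class TensorLike:
    """Illustrative Tensor-like: wraps a tensor and intercepts torch functions."""

    def __init__(self, data):
        self._tensor = data if isinstance(data, torch.Tensor) else torch.tensor(data)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Unwrap TensorLike arguments, call the real torch function, re-wrap.
        unwrapped = tuple(a._tensor if isinstance(a, TensorLike) else a for a in args)
        return TensorLike(func(*unwrapped, **kwargs))

# torch.add(TensorLike(1), TensorLike(2)) now routes through __torch_function__.
```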

## Reporting benchmark results
When modifying any of the machinery around `__torch_function__`, run the benchmark on both the feature branch and the point where it diverges from `master`. For each of these:

* Run `bench.py`, and include the output in your result.
* For each case where `bench.py` shows a regression, run the commands described above, prefixing the output SVG filename (the argument to the `-o` switch) with `base-` or `branch-` depending on which commit you are benchmarking; see the example after this list.
* For each SVG, open it in the browser, take a screenshot, and include it in your result. Also attach a ZIP file containing all of the SVGs.
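
For example, if `bench.py` showed a regression in the `SubWithTorchFunction` case, the recording commands might look as follows (illustrative; substitute the cases that actually regressed):

```bash
# Checked out at the point of divergence from master:
py-spy record -o base-suboverridden.svg --native -- python pyspybench.py SubWithTorchFunction

# Checked out on the feature branch:
py-spy record -o branch-suboverridden.svg --native -- python pyspybench.py SubWithTorchFunction
```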
67 changes: 67 additions & 0 deletions benchmarks/overrides_benchmark/bench.py
@@ -0,0 +1,67 @@
import torch
import time
import argparse

from common import SubTensor, WithTorchFunction, SubWithTorchFunction

NUM_REPEATS = 1000
NUM_REPEAT_OF_REPEATS = 1000


def bench(t1, t2):
    bench_times = []
    for _ in range(NUM_REPEAT_OF_REPEATS):
        time_start = time.time()
        for _ in range(NUM_REPEATS):
            torch.add(t1, t2)
        bench_times.append(time.time() - time_start)

    # Normalize by the number of calls per measurement so the reported figures
    # are per-call, regardless of the value passed via --nreps.
    bench_time = float(torch.min(torch.Tensor(bench_times))) / NUM_REPEATS
    bench_std = float(torch.std(torch.Tensor(bench_times))) / NUM_REPEATS

    return bench_time, bench_std


def main():
    global NUM_REPEATS
    global NUM_REPEAT_OF_REPEATS

    parser = argparse.ArgumentParser(
        description="Run the __torch_function__ benchmarks."
    )
    parser.add_argument(
        "--nreps",
        "-n",
        type=int,
        default=NUM_REPEATS,
        help="The number of repeats for one measurement.",
    )
    parser.add_argument(
        "--nrepreps",
        "-m",
        type=int,
        default=NUM_REPEAT_OF_REPEATS,
        help="The number of measurements.",
    )
    args = parser.parse_args()

    NUM_REPEATS = args.nreps
    NUM_REPEAT_OF_REPEATS = args.nrepreps

    types = torch.Tensor, SubTensor, WithTorchFunction, SubWithTorchFunction

    for t in types:
        tensor_1 = t(1)
        tensor_2 = t(2)

        bench_min, bench_std = bench(tensor_1, tensor_2)
        print(
            "Type {0} had a minimum time of {1} us"
            " and a standard deviation of {2} us.".format(
                t.__name__, (10 ** 6 * bench_min), (10 ** 6) * bench_std
            )
        )


if __name__ == "__main__":
    main()
31 changes: 31 additions & 0 deletions benchmarks/overrides_benchmark/common.py
@@ -0,0 +1,31 @@
import torch

NUM_REPEATS = 1000
NUM_REPEAT_OF_REPEATS = 1000


class SubTensor(torch.Tensor):
    pass


class WithTorchFunction:
    def __init__(self, data, requires_grad=False):
        if isinstance(data, torch.Tensor):
            self._tensor = data
            return

        self._tensor = torch.tensor(data, requires_grad=requires_grad)

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        return WithTorchFunction(args[0]._tensor + args[1]._tensor)


class SubWithTorchFunction(torch.Tensor):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        return args[0] + args[1]
28 changes: 28 additions & 0 deletions benchmarks/overrides_benchmark/pyspybench.py
@@ -0,0 +1,28 @@
import torch
import argparse
from common import SubTensor, WithTorchFunction, SubWithTorchFunction # noqa: F401

Tensor = torch.Tensor  # alias so the "Tensor" case can be looked up in globals() below

NUM_REPEATS = 1000000

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run torch.add for a given class a given number of times."
    )
    parser.add_argument(
        "tensor_class", metavar="TensorClass", type=str, help="The class to benchmark."
    )
    parser.add_argument(
        "--nreps", "-n", type=int, default=NUM_REPEATS, help="The number of repeats."
    )
    args = parser.parse_args()

    # Look up the requested class by name; valid names are Tensor, SubTensor,
    # WithTorchFunction, and SubWithTorchFunction.
    TensorClass = globals()[args.tensor_class]
    NUM_REPEATS = args.nreps

    t1 = TensorClass(1)
    t2 = TensorClass(2)

    for _ in range(NUM_REPEATS):
        torch.add(t1, t2)
3 changes: 3 additions & 0 deletions torch/_overrides.py
@@ -13,6 +13,9 @@
https://www.numpy.org/neps/nep-0018-array-function-protocol.html
)
If changing this file in a way that can affect ``__torch_function__`` overhead,
please run the benchmarks in ``benchmarks/overrides_benchmark`` and report the
results. See the instructions in the ``README.md`` in that directory.
"""

import __future__
4 changes: 4 additions & 0 deletions torch/csrc/utils/python_arg_parser.cpp
@@ -215,6 +215,10 @@ auto handle_torch_function(PythonArgs &r, PyObject* args, PyObject* kwargs, PyOb
 * precedence.
 *
 * 'obj' is an object to check for a __torch_function__ implementation
 *
 * If changing this file in a way that can affect the __torch_function__
 * overhead, please run the benchmarks in 'benchmarks/overrides_benchmark'
 * and report the results. See the instructions in the 'README.md' in that
 * directory.
 *
 */

