Skip to content

Commit

Permalink
Intra-op parallel microbenchmarks for PT (#19997)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch/pytorch#19997
ghimport-source-id: 420d4a68a1ef879beee2734adba8abb575e0b0ab

Differential Revision: D15231375

Pulled By: ilia-cher

fbshipit-source-id: ce7248ea2ebb54d25c9d831c6e3f23f3534557dd
  • Loading branch information
Ilia Cherniavskii authored and facebook-github-bot committed May 7, 2019
1 parent 481b6d0 commit 19e6886
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 8 deletions.
4 changes: 2 additions & 2 deletions benchmarks/operator_benchmark/benchmark_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ def PyTorchOperatorTestCase(test_name, op_type, input_shapes, op_args, run_mode)
tensor_shape = list(shape)
if not is_contig:
tensor_shape = [s * 2 for s in tensor_shape]
if dtype in [torch.float32, torch.float64]:
if dtype in [torch.float32, torch.float64]: # skip float16
input = torch.rand(tensor_shape, dtype=dtype)
elif dtype in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]:
elif not dtype.is_floating_point:
input = torch.randint(low=0, high=100, size=tensor_shape, dtype=dtype)
else:
input = torch.ones(tensor_shape, dtype=dtype)
Expand Down
24 changes: 18 additions & 6 deletions benchmarks/operator_benchmark/benchmark_test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def generate_test(configs, map_config, ops, OperatorTestCase):
continue
shapes.update(item)
assert run_mode is not None, "Missing mode in configs"
shapes_args = map_config(**shapes)
if shapes_args is not None:
for op in ops:
for op in ops:
shapes_args = map_config(test_name=op[0], **shapes)
if shapes_args is not None:
OperatorTestCase(
test_name=op[0],
op_type=op[1],
Expand All @@ -50,7 +50,7 @@ def generate_c2_test(configs, c2_map_func, c2_ops):
generate_test(configs, c2_map_func, c2_ops, Caffe2OperatorTestCase)


def map_c2_config_add(M, N, K):
def map_c2_config_add(test_name, M, N, K):
input_one = (M, N, K)
input_two = (M, N, K)
input_shapes = [input_one, input_two]
Expand All @@ -60,7 +60,7 @@ def map_c2_config_add(M, N, K):
map_pt_config_add = map_c2_config_add


def map_c2_config_matmul(M, N, K, trans_a, trans_b, contig, dtype):
def map_c2_config_matmul(test_name, M, N, K, trans_a, trans_b, contig, dtype):
if not contig or dtype != torch.float32:
return None
input_one = (N, M) if trans_a else (M, N)
Expand All @@ -70,9 +70,21 @@ def map_c2_config_matmul(M, N, K, trans_a, trans_b, contig, dtype):
return (input_shapes, args)


def map_pt_config_matmul(M, N, K, trans_a, trans_b, contig, dtype):
def map_pt_config_matmul(test_name, M, N, K, trans_a, trans_b, contig, dtype):
    """Translate one matmul configuration into ``(input_shapes, args)``.

    Configurations that ask for a transposed operand are rejected: if either
    ``trans_a`` or ``trans_b`` is truthy the function returns ``None``.
    Otherwise the operands are shaped ``(M, N)`` and ``(N, K)`` and the
    ``contig`` / ``dtype`` settings are passed through in the args dict.
    ``test_name`` is accepted but not consulted.
    """
    no_transpose = not trans_a and not trans_b
    if no_transpose:
        lhs_shape = (M, N)
        rhs_shape = (N, K)
        return ([lhs_shape, rhs_shape], dict(contig=contig, dtype=dtype))
    return None


def map_pt_config_intraop(test_name, N, M, contig, dtype):
    """Translate one intra-op configuration into ``(input_shapes, args)``.

    Returns ``None`` for dtype/op combinations that are filtered out:
    'bitor' and 'cbitor' with a floating-point dtype, and 'tanh', 'sigmoid'
    and 'sumall' with a non-floating-point dtype. Any other combination
    yields two ``(N, M)`` input shapes plus the ``contig`` / ``dtype`` args.
    """
    integer_only_ops = ['bitor', 'cbitor']
    float_only_ops = ['tanh', 'sigmoid', 'sumall']

    if test_name in integer_only_ops and dtype.is_floating_point:
        return None
    if test_name in float_only_ops and not dtype.is_floating_point:
        return None

    operand_shape = (N, M)
    return ([operand_shape, operand_shape], {'contig': contig, 'dtype': dtype})
3 changes: 3 additions & 0 deletions benchmarks/operator_benchmark/benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,14 @@ def generate_configs(**configs):
results = configs['sample_func'](*result)
return results


def is_caffe2_enabled(framework_arg):
    """Report whether ``'Caffe2'`` occurs in ``framework_arg``.

    The check uses the ``in`` operator, so a string argument is tested for
    the substring and a list argument for an equal element.
    """
    caffe2_requested = 'Caffe2' in framework_arg
    return caffe2_requested


def is_pytorch_enabled(framework_arg):
    """Report whether ``'PyTorch'`` occurs in ``framework_arg``.

    The check uses the ``in`` operator, so a string argument is tested for
    the substring and a list argument for an equal element.
    """
    pytorch_requested = 'PyTorch' in framework_arg
    return pytorch_requested


def get_requested_frameworks(framework_arg):
    """Split a comma-separated string into trimmed, non-empty names.

    Surrounding whitespace is stripped from every piece, and pieces that are
    empty after stripping are dropped. Order is preserved.
    """
    requested = []
    for piece in framework_arg.split(','):
        name = piece.strip()
        if name:
            requested.append(name)
    return requested
92 changes: 92 additions & 0 deletions benchmarks/operator_benchmark/ops/intraop_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

from operator_benchmark import benchmark_core, benchmark_runner
from operator_benchmark.benchmark_test_generator import *

import torch


"""Microbenchmarks for PyTorch CPU intra-op parallelism.
Tests the following functions:
- bitor, cbitor
- tensor-scalar and tensor-tensor element-wise function, integer-only
- tanh and sigmoid
- unary ops
- sumall
- basic reduction function
"""

# Benchmark configuration passed to generate_configs: candidate sizes for N
# and M, two dtypes (float32 and int32), contiguous and non-contiguous inputs,
# and the 'short' run mode.
# NOTE(review): sample_func=cross_product presumably expands these lists into
# every combination -- confirm against benchmark_utils.
config = generate_configs(
N=[128, 1024, 4096],
M=[128, 1024, 4096],
dtype=[torch.float32, torch.int32],
contig=[True, False],
mode=['short'],
sample_func=cross_product
)


# Build a torch.jit.ScriptModule whose forward(a, b, iterations) calls the
# in-place bitwise OR a.__ior__(...) and returns `a`. The right-hand operand is
# the tensor `b` when tensor_arg is truthy, otherwise the literal 42.
# NOTE(review): indentation is stripped in this view, so the nesting of the
# template below (which statements sit inside the `for` loop) cannot be read
# off here -- confirm against the repository file before editing the string.
def torch_or(tensor_arg):
# TorchScript source template; `{}` is filled in with the right-hand operand.
jit_ior_loop_code = """\
def forward(self, a, b, iterations):
# type: (Tensor, Tensor, int)
for _ in range(iterations):
a.__ior__({})
return a
"""
# Attach the filled-in template to a fresh ScriptModule via define().
jit_ior_loop = torch.jit.ScriptModule()
jit_ior_loop.define(jit_ior_loop_code.format("b" if tensor_arg else "42"))

# Print the module's `.code` each time a module is built, for inspection.
print("torch_or(", tensor_arg, "):\n", jit_ior_loop.code)
return jit_ior_loop


# Build a torch.jit.ScriptModule whose forward(a, b, iterations) calls the
# zero-argument tensor method named by op_str on `a` and returns `a`. `b` is
# part of the signature but is not referenced by the template.
# NOTE(review): indentation is stripped in this view, so the nesting of the
# template below (which statements sit inside the `for` loop) cannot be read
# off here -- confirm against the repository file before editing the string.
def torch_unary(op_str):
# TorchScript source template; `{}` is filled in with the method name.
jit_op_loop_code = """\
def forward(self, a, b, iterations):
# type: (Tensor, Tensor, int)
for _ in range(iterations):
a.{}()
return a
"""
# Attach the filled-in template to a fresh ScriptModule via define().
jit_op_loop = torch.jit.ScriptModule()
jit_op_loop.define(jit_op_loop_code.format(op_str))

# Print the module's `.code` each time a module is built, for inspection.
print("torch_unary(", op_str, "):\n", jit_op_loop.code)
return jit_op_loop


# Scripted reduction benchmark: accumulates float(torch.sum(a)) into a Python
# float over `iterations` loop passes and returns the total. `b` is part of the
# signature but is never read. The function also writes to a[0][0], so the
# caller's tensor is modified.
# NOTE(review): indentation is stripped in this view; the `a[0][0] += 0.01`
# update is presumably inside the loop so each pass sums different data --
# confirm against the repository file.
@torch.jit.script
def torch_sumall(a, b, iterations):
# type: (Tensor, Tensor, int)
result = 0.0
for _ in range(iterations):
result += float(torch.sum(a))
a[0][0] += 0.01
return result

# Runs at import time: print the scripted function's `.code` for inspection.
print("torch_sumall:\n", torch_sumall.code)

@benchmark_core.register_test
def test_th_intraop():
    """Hand the intra-op operator list to ``generate_pt_test``.

    Each entry pairs a test name with the callable to benchmark:
    'bitor' ORs with the literal 42, 'cbitor' ORs with the second tensor,
    'tanh' and 'sigmoid' use the in-place ``tanh_`` / ``sigmoid_`` methods,
    and 'sumall' uses the scripted ``torch_sumall``. The names are the ones
    ``map_pt_config_intraop`` filters on.
    """
    named_ops = [
        ('bitor', torch_or(False)),
        ('cbitor', torch_or(True)),
        ('tanh', torch_unary('tanh_')),
        ('sigmoid', torch_unary('sigmoid_')),
        ('sumall', torch_sumall),
    ]
    generate_pt_test([config], map_pt_config_intraop, named_ops)


# Script entry point: when this file is executed directly, delegate to
# benchmark_runner.main().
if __name__ == "__main__":
benchmark_runner.main()

0 comments on commit 19e6886

Please sign in to comment.