forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmath_benchmark.py
153 lines (139 loc) · 4.05 KB
/
math_benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Microbenchmarks for floating point operations."""
import functools
import google_benchmark as benchmark
import jax
import jax.numpy as jnp
import numpy as np
import os
import sys
from google_benchmark import Counter
from absl import app
from absl import flags
# Repeatable flag: every --set_env occurrence contributes one string to the
# list in _SET_ENV.value. main() copies entries of the form NAME=VALUE into
# os.environ before running the benchmarks.
_SET_ENV = flags.DEFINE_multi_string(
    name="set_env",
    default=None,
    help=(
        "Specifies additional environment variables to be injected into the "
        "environment (via --set_env=variable=value or --set_env=variable). "
        "Using this flag is useful when running on remote machines where we "
        "do not have direct control of the environment except for passing "
        "argument flags."
    ),
)
def math_benchmark(*args):
  """Returns a decorator that registers one benchmark per test case.

  Args:
    *args: ``args[0]`` is an iterable of dicts describing the test cases.
      Each dict must have a ``'name'`` key, which is appended to the decorated
      function's name to form the registered benchmark name. The whole dict,
      ``'name'`` included, is forwarded to the decorated function as keyword
      arguments. Any further positional arguments are ignored.

  Returns:
    A decorator. When applied, it registers every test case with
    google_benchmark. It returns the last registered wrapper, or the
    undecorated function if the test-case iterable was empty.

  Raises:
    TypeError: if called without a test-case iterable.
  """
  if not args:
    raise TypeError('math_benchmark() requires an iterable of test cases')
  test_cases = args[0]

  def decorator(func):
    # Fall back to func itself so an empty case list does not leave the
    # return value unbound.
    registered = func
    for test_case in test_cases:
      # test_case is bound as a default argument so each wrapper keeps its own
      # case instead of every closure seeing the loop's final value.
      @benchmark.register(name=f"{func.__name__}_{test_case['name']}")
      @functools.wraps(func)
      def wrapper(state, test_case=test_case):
        return func(state, **test_case)
      registered = wrapper
    return registered

  return decorator
@math_benchmark(
    [
        {
            'name': f'{op.__name__}_{shape}_{dtype}',
            'shape': shape,
            'dtype': dtype,
            'op': op,
        }
        for op in [
            jnp.exp,
            jnp.exp2,
            jnp.expm1,
            jnp.log,
            jnp.log2,
            jnp.log1p,
            jnp.tanh,
        ]
        for shape in [2**i for i in range(10, 15, 2)]
        for dtype in ['float32']
    ]
)
def jax_unary(state, **kwargs):
  """Times a jitted elementwise unary ``op`` on a 1-D array of random values.

  Args:
    state: google_benchmark state object that drives the timing loop.
    **kwargs: one test case built by the decorator above:
      ``shape`` (an int, the array length), ``dtype``, and ``op`` (the jnp
      function under test). ``name`` is also passed but not used here.
  """
  shape = kwargs['shape']
  dtype = kwargs['dtype']
  op = kwargs['op']
  # Transfer the input to the device once, up front. A host numpy array
  # passed to a jitted function is transferred on every call, which would
  # put a host-to-device copy inside the timed loop.
  input0 = jax.device_put(np.random.random(shape).astype(dtype))
  f_jitted = jax.jit(op)
  # Warm-up call, so that tracing/compilation happens outside the timed loop.
  f_jitted(input0).block_until_ready()
  while state:
    f_jitted(input0).block_until_ready()
  # Report elements processed per second.
  state.counters['items_per_second'] = Counter(
      input0.size * state.iterations, Counter.kIsRate
  )
@math_benchmark(
    [
        {
            'name': f'{op.__name__}_{mkn[0]}x{mkn[1]}x{mkn[2]}_{dtype}',
            'mkn': mkn,
            'dtype': dtype,
            'op': op,
        }
        for op in [
            jnp.dot,
        ]
        for mkn in [[2**i, 2**i, 2**i] for i in range(4, 11, 1)] +
        [
            [1, 2, 256],
            [1, 8, 256],
            [1, 18, 300],
            [1, 37, 256],
            [1, 91, 256],
            [1, 111, 256],
            [1, 192, 192],
            [1, 226, 256],
            [1, 256, 192],
            [1, 256, 256],
            [1, 512, 512],
            [1, 300, 18],
            [21, 24, 1],
            [21, 120, 1],
            [10, 10, 10],
            [100, 100, 100],
            [18, 1, 300],
            [18, 300, 1],
            [300, 1, 18],
            [300, 18, 1],
        ]
        for dtype in ['float32']
    ]
)
def jax_binary_op(state, **kwargs):
  """Times a jitted binary ``op`` on an (m, k) and a (k, n) random array.

  Args:
    state: google_benchmark state object that drives the timing loop.
    **kwargs: one test case built by the decorator above:
      ``mkn`` (the three sizes ``[m, k, n]``), ``dtype``, and ``op`` (the jnp
      function under test). ``name`` is also passed but not used here.
  """
  m, k, n = kwargs['mkn']
  dtype = kwargs['dtype']
  op = kwargs['op']
  # Transfer both operands to the device once. Host numpy arrays passed to a
  # jitted function are transferred on every call, which would put a
  # host-to-device copy inside the timed loop.
  a = jax.device_put(np.random.random([m, k]).astype(dtype))
  b = jax.device_put(np.random.random([k, n]).astype(dtype))
  f_jitted = jax.jit(op)
  # Warm-up call, so that tracing/compilation happens outside the timed loop.
  f_jitted(a, b).block_until_ready()
  while state:
    f_jitted(a, b).block_until_ready()
  # Report calls of ``op`` per second. Unlike jax_unary, this counter is not
  # scaled by the problem size.
  state.counters['items_per_second'] = Counter(
      state.iterations, Counter.kIsRate
  )
def main(argv):
  """Applies --set_env overrides to os.environ, then runs all benchmarks.

  Args:
    argv: the argument list that absl's app.run passes in; not used.
  """
  for assignment in _SET_ENV.value or ():
    # Split at the first '=' only, so that
    # --set_env='FOO=--foo_a=1 --foo_b=2' assigns everything after 'FOO='
    # to FOO.
    var_name, separator, var_value = assignment.partition('=')
    # Entries that contain no '=' at all are skipped.
    if separator:
      os.environ[var_name] = var_value
  benchmark.run_benchmarks()
if __name__ == '__main__':
  # Hand the command line to google_benchmark first. The argv it returns
  # becomes sys.argv, which app.run then parses for the absl flags.
  _remaining_argv = benchmark.initialize(sys.argv)
  sys.argv = _remaining_argv
  app.run(main)