Skip to content

Commit

Permalink
math_benchmark: add --set_env flag
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 515417422
  • Loading branch information
cota authored and jax authors committed Mar 9, 2023
1 parent 3656053 commit 6f1d829
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion benchmarks/math_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,22 @@
import jax
import jax.numpy as jnp
import numpy as np
import os
import sys

from google_benchmark import Counter
from absl import app
from absl import flags


FLAGS = flags.FLAGS
flags.DEFINE_multi_string(
"set_env", None,
"Specifies additional environment variables to be injected into the "
"environment (via --set_env=variable=value or --set_env=variable). "
"Using this flag is useful when running on remote machines where we do not "
"have direct control of the environment except for passing argument flags.")

def math_benchmark(*args):
def decorator(func):
for test_case in args[0]:
Expand Down Expand Up @@ -127,6 +139,16 @@ def jax_binary_op(state, **kwargs):
state.iterations, Counter.kIsRate
)

def main(argv):
if FLAGS.set_env:
for env_str in FLAGS.set_env:
# Stop matching at the first '=' since we want to capture
# --set_env='FOO=--foo_a=1 --foo_b=2' all as part of FOO.
env_list = env_str.split('=', 1)
if len(env_list) == 2:
os.environ[env_list[0]] = env_list[1];
benchmark.run_benchmarks()

if __name__ == '__main__':
benchmark.main()
sys.argv = benchmark.initialize(sys.argv)
app.run(main)

0 comments on commit 6f1d829

Please sign in to comment.