Add --export_dir and --baseline_dir flags to benchmark.py. (jax-ml#2677)

`--export_dir` allows saving benchmark results to CSV files, and
`--baseline_dir` allows comparing results to a baseline exported via
`--export_dir`.
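
For context, a minimal sketch of how a driver script might exercise the new flags. The script name, the `prepare` function, and the parameter values are hypothetical, and the import path assumes the repository root is on PYTHONPATH:

# my_bench.py -- hypothetical driver script.
# Record a baseline, then compare a later run against it:
#   python my_bench.py --export_dir=/tmp/baseline
#   python my_bench.py --baseline_dir=/tmp/baseline
from absl import app

from benchmarks.benchmark import benchmark_suite


def prepare(n):
  xs = list(range(n))
  return lambda: sum(xs)  # the callable that gets timed


def main(_):
  # Writes /tmp/baseline/sum.csv under --export_dir; reads the same file
  # back under --baseline_dir and appends a mean/baseline column.
  benchmark_suite(prepare, [{"n": 10}, {"n": 1000}], name="sum")


if __name__ == "__main__":
  app.run(main)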
skye authored Apr 13, 2020
1 parent afefc92 commit 8c2901c
Showing 1 changed file with 54 additions and 5 deletions.
benchmarks/benchmark.py
@@ -14,16 +14,26 @@
 """A simple Python microbenchmarking library."""
 
 from collections import OrderedDict
+import csv
 from numbers import Number
+import os
 import time
 from typing import Any, Optional, Union, Callable, List, Dict
 
+from absl import flags
 import numpy as onp
 from tabulate import tabulate
 
 from jax.util import safe_zip
 
+FLAGS = flags.FLAGS
+flags.DEFINE_string(
+    "export_dir", None,
+    "If set, will save results as CSV files in the specified directory.")
+flags.DEFINE_string(
+    "baseline_dir", None,
+    "If set, include comparison to baseline in results. Baselines should be "
+    "generated with --export_dir and benchmark names are matched to filenames.")
 
 def benchmark(f: Callable[[], Any], iters: Optional[int] = None,
               warmup: Optional[int] = None, name: Optional[str] = None,
@@ -97,14 +107,53 @@ def benchmark_suite(prepare: Callable[..., Callable], params_list: List[Dict],
     times.append(benchmark(f, name=subname,
                            target_total_secs=target_total_secs))
 
-  print("---------Benchmark summary for %s---------" % name)
   param_names = list(params_list[0].keys())
-  print(tabulate([tuple(map(_param_str, params.values())) +
-                  (t.mean(), _pstd(t), t.mean() / times[0].mean())
-                  for params, t in safe_zip(params_list, times)],
-                 param_names + ["mean", "%std", "relative"]))
+  data_header = param_names + ["mean", "%std", "relative"]
+  data = [list(map(_param_str, params.values())) +
+          [t.mean(), _pstd(t), t.mean() / times[0].mean()]
+          for params, t in safe_zip(params_list, times)]
+
+  if FLAGS.baseline_dir:
+    mean_idx = len(param_names)
+    means = _get_baseline_means(FLAGS.baseline_dir, name)
+    assert len(means) == len(data), (means, data)
+    data_header.append("mean/baseline")
+    for idx, mean in enumerate(means):
+      data[idx].append(data[idx][mean_idx] / mean)
+
+  print("---------Benchmark summary for %s---------" % name)
+  print(tabulate(data, data_header))
   print()
+
+  if FLAGS.export_dir:
+    filename = _export_results(data_header, data, FLAGS.export_dir, name)
+    print("Wrote %s results to %s" % (name, filename))
+    print()
+
+
+def _get_baseline_means(baseline_dir, name):
+  baseline_dir = os.path.expanduser(baseline_dir)
+  filename = os.path.join(baseline_dir, name + ".csv")
+  if not os.path.exists(filename):
+    raise FileNotFoundError("Can't find baseline file: %s" % filename)
+  with open(filename, newline="") as csvfile:
+    reader = csv.reader(csvfile)
+    header = next(reader)
+    mean_idx = header.index("mean")
+    return [float(row[mean_idx]) for row in reader]
+
+
+def _export_results(data_header, data, export_dir, name):
+  assert "mean" in data_header  # For future comparisons via _get_baseline_means
+  export_dir = os.path.expanduser(export_dir)
+  os.makedirs(export_dir, exist_ok=True)
+  filename = os.path.join(export_dir, name + ".csv")
+  with open(filename, "w", newline="") as csvfile:
+    writer = csv.writer(csvfile)
+    writer.writerow(data_header)
+    writer.writerows(data)
+  return filename
 
 
 def _param_str(param):
   if callable(param):
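For reference, a sketch of the CSV round-trip the two helpers implement; the parameter column `n`, the file paths, and all numbers below are invented:

# A suite named "sum" run with --export_dir=/tmp/baseline writes
# /tmp/baseline/sum.csv via _export_results:
#
#   n,mean,%std,relative
#   10,1.2e-05,3.1,1.0
#   1000,9.8e-04,2.4,81.7
#
# A later run with --baseline_dir=/tmp/baseline reads back only the "mean"
# column via _get_baseline_means, then appends mean/baseline to each row:
baseline_mean = 1.2e-05  # mean for n=10 in the baseline CSV
current_mean = 1.0e-05   # mean for n=10 in the current run
print(current_mean / baseline_mean)  # ~0.83, i.e. faster than the baseline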
