Skip to content

Commit

Permalink
Add --printoptions flag to inspect_checkpoint, to control how numpy f…
Browse files Browse the repository at this point in the history
…ormats tensor values. For example, this allows

  --printoptions threshold=1000000
to print all the values in a large tensor instead of ellipsis in place of most of them.
Change: 148876650
  • Loading branch information
tensorflower-gardener committed Mar 1, 2017
1 parent 1617ffb commit 410cde0
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions tensorflow/python/tools/inspect_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,11 @@
import argparse
import sys

import numpy as np

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import app
from tensorflow.python.platform import flags

FLAGS = None

Expand Down Expand Up @@ -58,6 +61,38 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
"with SNAPPY.")


def parse_numpy_printoption(kv_str):
"""Sets a single numpy printoption from a string of the form 'x=y'.
See documentation on numpy.set_printoptions() for details about what values
x and y can take. x can be any option listed there other than 'formatter'.
Args:
kv_str: A string of the form 'x=y', such as 'threshold=100000'
Raises:
argparse.ArgumentTypeError: If the string couldn't be used to set any
nump printoption.
"""
k_v_str = kv_str.split("=", 1)
if len(k_v_str) != 2 or not k_v_str[0]:
raise argparse.ArgumentTypeError("'%s' is not in the form k=v." % kv_str)
k, v_str = k_v_str
printoptions = np.get_printoptions()
if k not in printoptions:
raise argparse.ArgumentTypeError("'%s' is not a valid printoption." % k)
v_type = type(printoptions[k])
if v_type is type(None):
raise argparse.ArgumentTypeError(
"Setting '%s' from the command line is not supported." % k)
try:
v = (v_type(v_str) if v_type is not bool
else flags.BooleanParser().Parse(v_str))
except ValueError as e:
raise argparse.ArgumentTypeError(e.message)
np.set_printoptions(**{k: v})


def main(unused_argv):
if not FLAGS.file_name:
print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
Expand Down Expand Up @@ -87,5 +122,10 @@ def main(unused_argv):
type="bool",
default=False,
help="If True, print the values of all the tensors.")
parser.add_argument(
"--printoptions",
nargs="*",
type=parse_numpy_printoption,
help="Argument for numpy.set_printoptions(), in the form 'k=v'.")
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

0 comments on commit 410cde0

Please sign in to comment.