From 410cde0425a751c722e5c5cc8c1aeb84ccbdc5d4 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 1 Mar 2017 04:54:02 -0800 Subject: [PATCH] Add --printoptions flag to inspect_checkpoint, to control how numpy formats tensor values. For example, this allows --printoptions threshold=1000000 to print all the values in a large tensor instead of ellipsis in place of most of them. Change: 148876650 --- tensorflow/python/tools/inspect_checkpoint.py | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/tensorflow/python/tools/inspect_checkpoint.py b/tensorflow/python/tools/inspect_checkpoint.py index e218fd06ab4bf5..a6bda5d3053034 100644 --- a/tensorflow/python/tools/inspect_checkpoint.py +++ b/tensorflow/python/tools/inspect_checkpoint.py @@ -20,8 +20,11 @@ import argparse import sys +import numpy as np + from tensorflow.python import pywrap_tensorflow from tensorflow.python.platform import app +from tensorflow.python.platform import flags FLAGS = None @@ -58,6 +61,38 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors): "with SNAPPY.") +def parse_numpy_printoption(kv_str): + """Sets a single numpy printoption from a string of the form 'x=y'. + + See documentation on numpy.set_printoptions() for details about what values + x and y can take. x can be any option listed there other than 'formatter'. + + Args: + kv_str: A string of the form 'x=y', such as 'threshold=100000' + + Raises: + argparse.ArgumentTypeError: If the string couldn't be used to set any + nump printoption. + """ + k_v_str = kv_str.split("=", 1) + if len(k_v_str) != 2 or not k_v_str[0]: + raise argparse.ArgumentTypeError("'%s' is not in the form k=v." % kv_str) + k, v_str = k_v_str + printoptions = np.get_printoptions() + if k not in printoptions: + raise argparse.ArgumentTypeError("'%s' is not a valid printoption." % k) + v_type = type(printoptions[k]) + if v_type is type(None): + raise argparse.ArgumentTypeError( + "Setting '%s' from the command line is not supported." % k) + try: + v = (v_type(v_str) if v_type is not bool + else flags.BooleanParser().Parse(v_str)) + except ValueError as e: + raise argparse.ArgumentTypeError(e.message) + np.set_printoptions(**{k: v}) + + def main(unused_argv): if not FLAGS.file_name: print("Usage: inspect_checkpoint --file_name=checkpoint_file_name " @@ -87,5 +122,10 @@ def main(unused_argv): type="bool", default=False, help="If True, print the values of all the tensors.") + parser.add_argument( + "--printoptions", + nargs="*", + type=parse_numpy_printoption, + help="Argument for numpy.set_printoptions(), in the form 'k=v'.") FLAGS, unparsed = parser.parse_known_args() app.run(main=main, argv=[sys.argv[0]] + unparsed)