Skip to content

Commit

Permalink
Fixed several saver_test errors in open source and with python3.
Browse files Browse the repository at this point in the history
Fixes tensorflow#1035
Change: 114439527
  • Loading branch information
sherrym authored and tensorflower-gardener committed Feb 11, 2016
1 parent fe398ae commit 73b9dd1
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 19 deletions.
2 changes: 1 addition & 1 deletion google/protobuf
40 changes: 29 additions & 11 deletions tensorflow/python/training/saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,9 +1048,9 @@ def _get_kind_name(item):
Returns:
The string representation of the kind in CollectionDef.
"""
if isinstance(item, six.string_types) or isinstance(item, bytes):
if isinstance(item, (six.string_types, six.binary_type)):
kind = "bytes_list"
elif isinstance(item, (int, long)):
elif isinstance(item, six.integer_types):
kind = "int64_list"
elif isinstance(item, float):
kind = "float_list"
Expand Down Expand Up @@ -1091,10 +1091,15 @@ def _add_collection_def(meta_graph_def, key):
kind = _get_kind_name(collection_list[0])
if kind == "node_list":
getattr(col_def, kind).value.extend([x.name for x in collection_list])
elif kind == "bytes_list":
# NOTE(opensource): This force conversion is to work around the fact
# that Python3 distinguishes between bytes and strings.
getattr(col_def, kind).value.extend(
[compat.as_bytes(x) for x in collection_list])
else:
getattr(col_def, kind).value.extend([x for x in collection_list])
except Exception as e: # pylint: disable=broad-except
logging.warning("Error encountered when adding %s:\n"
logging.warning("Error encountered when serializing %s.\n"
"Type is unsupported, or the types of the items don't "
"match field type in CollectionDef.\n%s" % (key, str(e)))
if key in meta_graph_def.collection_def:
Expand Down Expand Up @@ -1176,20 +1181,27 @@ def _read_meta_graph_file(filename):
IOError: If the file doesn't exist, or cannot be successfully parsed.
"""
meta_graph_def = meta_graph_pb2.MetaGraphDef()
# First try to read it as a binary file.
if not gfile.Exists(filename):
raise IOError("File %s does not exist." % filename)
# First try to read it as a binary file.
with gfile.FastGFile(filename, "rb") as f:
file_content = f.read()
try:
meta_graph_def.ParseFromString(file_content)
return meta_graph_def
except Exception: # pylint: disable=broad-except
try:
# Next try to read it as a text file.
text_format.Merge(file_content, meta_graph_def)
except text_format.ParseError as e:
raise IOError("Cannot parse file %s: %s." % (filename, str(e)))
return meta_graph_def
pass

# Next try to read it as a text file.
with gfile.FastGFile(filename, "r") as f:
file_content = f.read()
try:
text_format.Merge(file_content, meta_graph_def)
return meta_graph_def
except text_format.ParseError as e:
raise IOError("Cannot parse file %s: %s." % (filename, str(e)))

return None


def _import_meta_graph_def(meta_graph_def):
Expand All @@ -1208,7 +1220,7 @@ def _import_meta_graph_def(meta_graph_def):
importer.import_graph_def(meta_graph_def.graph_def, name="")

# Restores all the other collections.
for key, col_def in meta_graph_def.collection_def.iteritems():
for key, col_def in meta_graph_def.collection_def.items():
kind = col_def.WhichOneof("kind")
if kind is None:
logging.error("Cannot identify data type for collection %s. Skipping."
Expand All @@ -1228,6 +1240,12 @@ def _import_meta_graph_def(meta_graph_def):
for value in field.value:
col_op = ops.get_default_graph().as_graph_element(value)
ops.add_to_collection(key, col_op)
elif kind == "int64_list":
# NOTE(opensource): This force conversion is to work around the fact
# that Python2 distinguishes between int and long, while Python3 has
# only int.
for value in field.value:
ops.add_to_collection(key, int(value))
else:
for value in field.value:
ops.add_to_collection(key, value)
Expand Down
14 changes: 7 additions & 7 deletions tensorflow/tensorboard/BUILD
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# Description:
# TensorBoard, a dashboard for investigating TensorFlow

package(default_visibility = ["//third_party/tensorflow:internal"])
package(default_visibility = ["//tensorflow:internal"])

filegroup(
name = "tensorboard_frontend",
srcs = [
"dist/index.html",
"dist/tf-tensorboard.html",
"//third_party/tensorflow/tensorboard/bower:bower",
"//tensorflow/tensorboard/bower:bower",
"TAG",
] + glob(["lib/**/*"]),
)
Expand All @@ -18,8 +18,8 @@ py_library(
srcs = ["backend/tensorboard_handler.py"],
deps = [
":float_wrapper",
"//third_party/tensorflow/python:platform",
"//third_party/tensorflow/python:summary",
"//tensorflow/python:platform",
"//tensorflow/python:summary",
],
srcs_version = "PY2AND3",
)
Expand All @@ -36,7 +36,7 @@ py_test(
srcs = ["backend/float_wrapper_test.py"],
deps = [
":float_wrapper",
"//third_party/tensorflow/python:platform_test",
"//tensorflow/python:platform_test",
],
srcs_version = "PY2AND3",
)
Expand All @@ -47,8 +47,8 @@ py_binary(
data = [":tensorboard_frontend"],
deps = [
":tensorboard_handler",
"//third_party/tensorflow/python:platform",
"//third_party/tensorflow/python:summary",
"//tensorflow/python:platform",
"//tensorflow/python:summary",
],
srcs_version = "PY2AND3",
)

0 comments on commit 73b9dd1

Please sign in to comment.