Skip to content

Commit

Permalink
[DEBUG] modify some objects that is deprecated
Browse files Browse the repository at this point in the history
amanda/src/amanda/conversion/tf.py:
tf.MetaGraphDef/NameAttrList/Session/NodeDef/GraphDef/AttrValue/train.Saver
     =>tf.compat.v1.MetaGraphDef/...
tf.train.SessionRunHook
     =>tf.estimator.SessionRunHook

amanda/conversion/tensorflow_updater.py:
tf.Session
     =>tf.compat.v1.Session

Signed-off-by: shuoyuedge <[email protected]>
  • Loading branch information
YushuoEdge committed Mar 6, 2022
1 parent fc81f2f commit 390a8dc
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 28 deletions.
2 changes: 1 addition & 1 deletion src/amanda/conversion/tensorflow_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def register_import_hook() -> None:

def register_intercepts() -> None:
intercepts.register(
tf.Session.run, intercepts.to_handler(session_run_wrapper), key="amanda"
tf.compat.v1.Session.run, intercepts.to_handler(session_run_wrapper), key="amanda"
)
intercepts.register(
tf.estimator.Estimator.train, intercepts.to_handler(train_wrapper), key="amanda"
Expand Down
54 changes: 27 additions & 27 deletions src/amanda/conversion/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,13 @@ def __post_init__(self):
self.register_dtype_name(dtype.name, _tf_serde)
for proto_type in [
SaverDef,
tf.MetaGraphDef.MetaInfoDef,
tf.compat.v1.MetaGraphDef.MetaInfoDef,
AssetFileDef,
SavedObjectGraph,
SignatureDef,
VersionDef,
tensor_pb2.TensorProto,
tf.NameAttrList,
tf.compat.v1.NameAttrList,
]:
serde = ProtoToDictSerde(
proto_type,
Expand All @@ -170,7 +170,7 @@ def __post_init__(self):
def import_from_graph(
tf_graph: tf.Graph,
saver_def: SaverDef = None,
session: tf.Session = None,
session: tf.compat.v1.Session = None,
) -> Graph:
graph_def = tf_graph.as_graph_def()
graph = import_from_graph_def(graph_def)
Expand Down Expand Up @@ -212,13 +212,13 @@ def import_from_graph(
return graph


def init_op_attrs(op: Op, node: tf.NodeDef):
def init_op_attrs(op: Op, node: tf.compat.v1.NodeDef):
op.namespace = tf_namespace()
if node.HasField("experimental_debug_info"):
op.attrs["experimental_debug_info"] = node.experimental_debug_info


def update_initialized_variables(graph: Graph, tf_graph: tf.Graph, session: tf.Session):
def update_initialized_variables(graph: Graph, tf_graph: tf.Graph, session: tf.compat.v1.Session):
with tf_graph.as_default():
all_variables = tf.global_variables()
uninitialized_variables = session.run(
Expand Down Expand Up @@ -252,15 +252,15 @@ def lower_name(name: str):
return name.lower()


def import_from_graph_def(graph_def: Union[tf.GraphDef, str, bytes, Path]) -> Graph:
graph_def = to_proto(graph_def, tf.GraphDef)
def import_from_graph_def(graph_def: Union[tf.compat.v1.GraphDef, str, bytes, Path]) -> Graph:
graph_def = to_proto(graph_def, tf.compat.v1.GraphDef)
graph = create_graph(
namespace=tf_namespace(),
)
tf_graph = tf.Graph()
name_to_node = {node.name: node for node in graph_def.node}

def add_op(node: tf.NodeDef):
def add_op(node: tf.compat.v1.NodeDef):
if graph.get_op(node.name) is not None:
return
input_tensors: List[TFTensor] = []
Expand Down Expand Up @@ -319,13 +319,13 @@ def add_op(node: tf.NodeDef):

def import_from_pbtxt(file: Union[str, Path]) -> Graph:
file = Path(file)
graph_def = tf.GraphDef()
graph_def = tf.compat.v1.GraphDef()
google.protobuf.text_format.Parse(file.read_text(), graph_def)
return import_from_graph_def(graph_def)


def extract_meta_graph_fields(graph, meta_graph):
new_meta_info_def = tf.MetaGraphDef.MetaInfoDef()
new_meta_info_def = tf.compat.v1.MetaGraphDef.MetaInfoDef()
new_meta_info_def.CopyFrom(meta_graph.meta_info_def)
new_meta_info_def.stripped_op_list.Clear()
graph.attrs["meta_info_def"] = new_meta_info_def
Expand Down Expand Up @@ -353,16 +353,16 @@ def construct_meta_graph_fields(graph, meta_graph, graph_def):


def import_from_meta_graph(
meta_graph: Union[tf.MetaGraphDef, str, bytes, Path],
meta_graph: Union[tf.compat.v1.MetaGraphDef, str, bytes, Path],
checkpoint: Union[str, Path] = None,
session: tf.Session = None,
session: tf.compat.v1.Session = None,
) -> Graph:
meta_graph = to_proto(meta_graph, tf.MetaGraphDef)
meta_graph = to_proto(meta_graph, tf.compat.v1.MetaGraphDef)
with tf.Graph().as_default() as tf_graph, ExitStack() as exit_stack:
saver = tf.train.import_meta_graph(meta_graph)
if checkpoint is not None:
if session is None:
session = tf.Session()
session = tf.compat.v1.Session()
exit_stack.enter_context(session)
saver.restore(session, str(checkpoint))
graph = import_from_graph(tf_graph, saver.as_saver_def(), session)
Expand All @@ -378,7 +378,7 @@ def import_from_checkpoint(path: Union[str, Path]) -> Graph:
def import_from_saved_model(path: Union[str, Path], tags: List[str]) -> Graph:
path = str(path)
with tf.Graph().as_default() as tf_graph:
with tf.Session() as session:
with tf.compat.v1.Session() as session:
meta_graph = tf.saved_model.load(session, tags, path)
graph = import_from_graph(tf_graph, meta_graph.saver_def, session)
extract_meta_graph_fields(graph, meta_graph)
Expand Down Expand Up @@ -460,7 +460,7 @@ def get_dtypes(tf_graph, node_def):
]


def from_attr_proto(attr_value: tf.AttrValue) -> Any:
def from_attr_proto(attr_value: tf.compat.v1.AttrValue) -> Any:
field_name = attr_value.WhichOneof("value")
if field_name == "s":
return attr_value.s
Expand Down Expand Up @@ -511,11 +511,11 @@ def get_tensor_name_by_port(port: OutputPort, compact: bool = False) -> str:


def export_to_graph(
graph: Graph, session: tf.Session = None
) -> Tuple[tf.Graph, tf.train.Saver, tf.Session]:
graph: Graph, session: tf.compat.v1.Session = None
) -> Tuple[tf.Graph, tf.compat.v1.train.Saver, tf.compat.v1.Session]:
def export_fn(tf_graph, session):
if "saver_def" in graph.attrs:
meta_graph = tf.MetaGraphDef()
meta_graph = tf.compat.v1.MetaGraphDef()
graph_def = export_to_graph_def(graph)
meta_graph.graph_def.CopyFrom(graph_def)
meta_graph.saver_def.CopyFrom(graph.attrs["saver_def"])
Expand Down Expand Up @@ -591,11 +591,11 @@ def as_op(initializer):
with tf.Graph().as_default() as tf_graph:
session_config = tf.ConfigProto()
session_config.gpu_options.allow_growth = True
session = tf.Session(config=session_config)
session = tf.compat.v1.Session(config=session_config)
return export_fn(tf_graph, session)


def export_to_graph_def(graph: Graph) -> tf.GraphDef:
def export_to_graph_def(graph: Graph) -> tf.compat.v1.GraphDef:
if not graph.namespace.belong_to(tf_namespace()):
raise MismatchNamespaceError(expect=tf_namespace(), actual=graph.namespace)
tf_graph = tf.Graph()
Expand Down Expand Up @@ -651,7 +651,7 @@ def export_to_checkpoint(graph: Graph, path: Union[str, Path]) -> None:
tf_graph, saver, session = export_to_graph(graph)
with session, tf_graph.as_default():
if saver is None:
saver = tf.train.Saver()
saver = tf.compat.v1.train.Saver()
saver.save(session, str(path))


Expand Down Expand Up @@ -683,7 +683,7 @@ def to_attrs_proto(
attr_protos = {}
attr_defs = {attr_def.name: attr_def for attr_def in op_def.attr}
for key, value in attrs.items():
attr_value = tf.AttrValue()
attr_value = tf.compat.AttrValue()
if key in attr_defs:
attr_def = attr_defs[key]
elif value is None:
Expand All @@ -706,7 +706,7 @@ def to_attrs_proto(
attr_def.type = "shape"
elif isinstance(value, tensor_pb2.TensorProto):
attr_def.type = "tensor"
elif isinstance(value, tf.NameAttrList):
elif isinstance(value, tf.compat.v1.NameAttrList):
attr_def.type = "func"
elif isinstance(value, list) and len(value) == 0:
attr_value.list.SetInParent()
Expand Down Expand Up @@ -808,7 +808,7 @@ def to_attrs_proto(
elif attr_def.type == "list(tensor)":
attr_value.list.tensor.extend([_MakeTensor(x, key) for x in value])
elif attr_def.type == "func":
if isinstance(value, tf.NameAttrList):
if isinstance(value, tf.compat.v1.NameAttrList):
attr_value.func.CopyFrom(value)
elif isinstance(value, compat.bytes_or_text_types):
attr_value.func.name = value
Expand All @@ -822,7 +822,7 @@ def to_attrs_proto(
return attr_protos


class AmandaHook(tf.train.SessionRunHook):
class AmandaHook(tf.estimator.SessionRunHook):
def __init__(self, context: EventContext):
self.context = context

Expand Down Expand Up @@ -1469,7 +1469,7 @@ def end(self, is_enabled: bool) -> None:
self.end_ops_list.insert(0, set(op.name for op in tf_graph.get_operations()))


class FuncSessionHook(tf.train.SessionRunHook):
class FuncSessionHook(tf.estimator.SessionRunHook):
def __init__(
self,
after_create_session=None,
Expand Down

0 comments on commit 390a8dc

Please sign in to comment.