Skip to content

Commit

Permalink
[export] Use meta val from the old nodes in run_decompositions(). (py…
Browse files Browse the repository at this point in the history
…torch#111225)

Summary: fall back to the old nodes when meta val is missing.

Test Plan: buck2 run //executorch/examples/portable/scripts:export -- --model_name=emformer_predict

Differential Revision: D50278439

Pull Request resolved: pytorch#111225
Approved by: https://github.com/larryliu0820
  • Loading branch information
zhxchen17 authored and pytorchmergebot committed Oct 14, 2023
1 parent ac02531 commit 11ac4ac
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions torch/export/exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,8 +687,12 @@ def _get_placeholders(gm):
for old_node, new_node in zip(old_outputs, new_outputs)
}

def make_argument_spec(node) -> ArgumentSpec:
val = node.meta["val"]
def make_argument_spec(old_node, node) -> ArgumentSpec:
if "val" not in node.meta:
assert len(node.users) == 0
val = old_node.meta["val"]
else:
val = node.meta["val"]
if isinstance(val, torch.Tensor):
return TensorArgument(name=node.name)
elif isinstance(val, torch.SymInt):
Expand Down Expand Up @@ -719,15 +723,15 @@ def make_argument_spec(node) -> ArgumentSpec:
grad_user_inputs={},
loss_output=None,
inputs=[
make_argument_spec(node)
for node in gm.graph.nodes
make_argument_spec(old_placeholders[i], node)
for i, node in enumerate(gm.graph.nodes)
if node.op == "placeholder"
],
outputs=[
make_argument_spec(node)
for node in pytree.tree_flatten(
next(iter(reversed(gm.graph.nodes))).args
)[0]
make_argument_spec(old_outputs[i], node)
for i, node in enumerate(
pytree.tree_flatten(next(iter(reversed(gm.graph.nodes))).args)[0]
)
],
)

Expand Down

0 comments on commit 11ac4ac

Please sign in to comment.