Forward fix failures for torch.export switch to predispatch (pytorch#126081)

Summary:
Fixes:
- executorch test
- torchrec test

Test Plan: CI

Differential Revision: D57282304

Pull Request resolved: pytorch#126081
Approved by: https://github.com/angelayi
tugsbayasgalan authored and pytorchmergebot committed May 15, 2024
1 parent 0d49c5c commit 26f6f98
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions torch/_export/passes/replace_set_grad_with_hop_pass.py
@@ -60,6 +60,12 @@ def _replace_with_hop(node: torch.fx.Node):
                 set_grad_node.meta.get("nn_module_stack", {})
             )
             output_node = next(iter(reversed(sub_gm.graph.nodes)), None)
+            # The split_module pass intentionally doesn't add an output node
+            # if the graph doesn't return anything.
+            # TODO (tmanlaibaatar) Figure out if this is the right behaviour
+            # for split_module
+            if isinstance(output_node, torch.fx.Node) and output_node.op != "output":
+                output_node = None
             if output_node is not None:
                 assert len(output_node.args) == 1
                 output_args = output_node.args[0]
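
For context, split_module can produce a submodule with no output node when a partition computes values that nothing downstream consumes, which is exactly the case the new guard handles. A minimal sketch (illustrative module and callback names, not from this PR) that can be used to observe this; depending on the torch version, partition 0 may come back without an output node:

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.split_module import split_module

class M(torch.nn.Module):
    def forward(self, x):
        a = x + 1
        a.add_(1)  # in-place; the result is never used downstream
        return x * 2

gm = symbolic_trace(M())

# Route the add/add_ nodes into partition 0 and everything else into
# partition 1, so partition 0 produces nothing the rest of the graph uses.
def split_callback(node):
    return 0 if "add" in node.name else 1

split = split_module(gm, M(), split_callback)
for name, submod in split.named_children():
    ops = [n.op for n in submod.graph.nodes]
    print(name, "has output node:", "output" in ops)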
@@ -106,9 +112,7 @@ def _replace_with_hop(node: torch.fx.Node):
                     f"repalce_set_grad_with_hop_pass doesnt' support output type {type(output_args)}"
                 )
             else:
-                raise NotImplementedError(
-                    "Cannot replace a call_module with a hop if it has no output. This module will gets DCEed."
-                )
+                node.graph.erase_node(node)
             sub_graph.erase_node(set_grad_node)
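
Instead of raising, the pass now erases the no-output call_module node directly. torch.fx's Graph.erase_node only succeeds when the node has no remaining users, so this is safe precisely for the dead-submodule case. A minimal sketch (illustrative module, not from this PR) of erasing a dead node:

import torch
from torch.fx import symbolic_trace

class M(torch.nn.Module):
    def forward(self, x):
        a = x + 1  # dead value: never used in the output
        return x * 2

gm = symbolic_trace(M())

# erase_node raises RuntimeError if the node still has users,
# so only fully dead nodes can be removed this way.
for node in list(gm.graph.nodes):
    if node.op == "call_function" and len(node.users) == 0:
        gm.graph.erase_node(node)

gm.graph.lint()
gm.recompile()
print(gm.code)  # the add no longer appears in the generated forward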


@@ -164,6 +168,7 @@ def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
                 else node
             ),
         )
+        new_gm.recompile()
         return new_gm
 
     return gm
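
The added recompile() matters because torch.fx caches the generated Python forward: after the graph is mutated, the GraphModule keeps running stale code until recompile() regenerates it from the current graph. A minimal sketch (illustrative, not from this PR):

import operator
import torch
from torch.fx import symbolic_trace

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

gm = symbolic_trace(M())

# Mutate the graph in place: turn the add into a multiply.
for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is operator.add:
        node.target = operator.mul

# Until recompile() runs, gm's forward is still the code generated for
# the old graph; recompile() regenerates it from the mutated graph.
gm.recompile()
print(gm(torch.tensor(3.0)))  # tensor(3.) == 3.0 * 1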
