[export] preserve metadata during nonstrict tracing (pytorch#118607)
Previously, non-strict tracing would wipe the metadata of GraphModules, because the wrapper class we use was not detected as a GraphModule, so metadata preservation was never turned on.

Differential Revision: [D53139354](https://our.internmc.facebook.com/intern/diff/D53139354/)
Pull Request resolved: pytorch#118607
Approved by: https://github.com/zhxchen17
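
For context on the mechanism: the fix routes a re-exported GraphModule through torch.fx.Interpreter under torch.fx.traceback.preserve_node_meta(), so each original node's metadata stays visible while the retrace records new nodes. Below is a minimal sketch of that pattern outside of export; the toy module M and the input shape are illustrative, not part of this PR.

import torch
import torch.fx
import torch.fx.traceback

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(M())

# Re-execute the graph node by node instead of calling the GraphModule's
# forward directly; with preserve_node_meta() active, the meta of the node
# currently being run (e.g. stack_trace) is exposed to any tracer recording
# this execution, which is how a retrace can inherit it.
with torch.fx.traceback.preserve_node_meta():
    out = torch.fx.Interpreter(gm).run(torch.randn(4))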
suo authored and pytorchmergebot committed Jan 30, 2024
1 parent 644f64f commit 6511811
Showing 2 changed files with 27 additions and 5 deletions.
test/export/test_export.py (19 additions, 0 deletions)
@@ -2870,6 +2870,25 @@ def forward(self, x):
         # this doesn't work today
         gm_unflat_strict = unflatten(ep)
 
+    def test_nonstrict_retrace_preserves_metadata(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(4, 4)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        inp = torch.randn(4, 4)
+        m = MyModule()
+        ep = torch.export.export(m, (inp,), {}, strict=False)
+        # retrace
+        ep2 = torch.export.export(ep.module(), (inp,), {}, strict=False)
+
+        for n1, n2 in zip(list(ep.graph.nodes), list(ep2.graph.nodes)):
+            self.assertEqual(n1.meta.get("stack_trace"), n2.meta.get("stack_trace"))
+
+
 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
 class TestOneOffModelExportResult(TestCase):
     def test_scaled_dot_product_attention_cpu(self):
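
As a side note on what "metadata" means concretely here: every FX node carries a meta dict, and keys such as "stack_trace" and "nn_module_stack" are the ones this commit keeps intact across a non-strict retrace. A small, hypothetical snippet (not part of the PR) to inspect them:

import torch

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x) + 1

ep = torch.export.export(Tiny(), (torch.randn(3),), strict=False)
for node in ep.graph.nodes:
    # Print which metadata keys export recorded on each node.
    print(node.op, node.target, sorted(node.meta.keys()))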
torch/export/_trace.py (8 additions, 5 deletions)
@@ -551,9 +551,14 @@ def __init__(self, mod):
 
         def forward(self, *args, **kwargs):
             nonlocal out_spec
-            flat_outs, out_spec = pytree.tree_flatten(
-                self._export_root(*args, **kwargs)
-            )
+            if isinstance(self._export_root, torch.fx.GraphModule):
+                with torch.fx.traceback.preserve_node_meta():
+                    tree_out = torch.fx.Interpreter(self._export_root).run(
+                        *args, **kwargs
+                    )
+            else:
+                tree_out = self._export_root(*args, **kwargs)
+            flat_outs, out_spec = pytree.tree_flatten(tree_out)
             return tuple(flat_outs)
 
     wrapped_mod = Wrapper(mod)
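
The isinstance branch above is what catches the retrace case: torch.export.export(ep.module(), ...) hands the Wrapper a torch.fx.GraphModule as its root, whereas a plain user module still takes the eager call path. A quick, hypothetical check of that assumption (the toy Tiny module is not from the PR):

import torch

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x) + 1

ep = torch.export.export(Tiny(), (torch.randn(3),), strict=False)
# ep.module() reconstructs a runnable module from the exported graph; because
# it is a torch.fx.GraphModule, the new branch applies when it is exported
# again, and node meta is preserved via the Interpreter run.
print(isinstance(ep.module(), torch.fx.GraphModule))  # expected: True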
@@ -581,8 +586,6 @@ def forward(self, *args, **kwargs):
     for node in gm.graph.nodes:
         if "nn_module_stack" in node.meta:
             nn_module_stack = node.meta["nn_module_stack"]
-            # Delete the wrapper module reference
-            del nn_module_stack[""]
             node.meta["nn_module_stack"] = {
                 fixup_key(key): val
                 for key, val in pytree.tree_map(
