diff --git a/test/export/test_export.py b/test/export/test_export.py
index ba8ed0299c975a..d3800c3ff9ffb1 100644
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -2870,6 +2870,25 @@ def forward(self, x):
         # this doesn't work today
         gm_unflat_strict = unflatten(ep)
 
+    def test_nonstrict_retrace_preserves_metadata(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(4, 4)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        inp = torch.randn(4, 4)
+        m = MyModule()
+        ep = torch.export.export(m, (inp,), {}, strict=False)
+        # retrace
+        ep2 = torch.export.export(ep.module(), (inp,), {}, strict=False)
+
+        for n1, n2 in zip(list(ep.graph.nodes), list(ep2.graph.nodes)):
+            self.assertEqual(n1.meta.get("stack_trace"), n2.meta.get("stack_trace"))
+
+
 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
 class TestOneOffModelExportResult(TestCase):
     def test_scaled_dot_product_attention_cpu(self):
diff --git a/torch/export/_trace.py b/torch/export/_trace.py
index 27977b7bd76bc8..7c035b118aa895 100644
--- a/torch/export/_trace.py
+++ b/torch/export/_trace.py
@@ -551,9 +551,14 @@ def __init__(self, mod):
 
                 def forward(self, *args, **kwargs):
                     nonlocal out_spec
-                    flat_outs, out_spec = pytree.tree_flatten(
-                        self._export_root(*args, **kwargs)
-                    )
+                    if isinstance(self._export_root, torch.fx.GraphModule):
+                        with torch.fx.traceback.preserve_node_meta():
+                            tree_out = torch.fx.Interpreter(self._export_root).run(
+                                *args, **kwargs
+                            )
+                    else:
+                        tree_out = self._export_root(*args, **kwargs)
+                    flat_outs, out_spec = pytree.tree_flatten(tree_out)
                     return tuple(flat_outs)
 
             wrapped_mod = Wrapper(mod)
@@ -581,8 +586,6 @@ def forward(self, *args, **kwargs):
             for node in gm.graph.nodes:
                 if "nn_module_stack" in node.meta:
                     nn_module_stack = node.meta["nn_module_stack"]
-                    # Delete the wrapper module reference
-                    del nn_module_stack[""]
                     node.meta["nn_module_stack"] = {
                         fixup_key(key): val
                         for key, val in pytree.tree_map(
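
Context for the change above (not part of the patch): the new branch in Wrapper.forward re-runs an already-traced GraphModule through torch.fx.Interpreter under torch.fx.traceback.preserve_node_meta(), so that when export retraces the graph, the freshly created nodes pick up metadata such as "stack_trace" from the original nodes. Below is a minimal sketch of that pattern outside of export; it assumes make_fx as the retracing tracer and uses hand-written stack_trace values purely for illustration.

# Illustrative sketch only -- not part of the patch.
import torch
from torch.fx.experimental.proxy_tensor import make_fx


class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


# First trace produces a GraphModule whose nodes we tag with metadata.
gm = make_fx(M())(torch.randn(3))
for node in gm.graph.nodes:
    if node.op == "call_function":
        node.meta["stack_trace"] = f"original trace for {node.name}"


def reinterpret(x):
    # Same pattern as the new Wrapper.forward branch: Interpreter.run exposes
    # each original node's meta as the "current" meta, and preserve_node_meta()
    # lets the tracer copy it onto the nodes it creates while replaying.
    with torch.fx.traceback.preserve_node_meta():
        return torch.fx.Interpreter(gm).run(x)


# Retrace the GraphModule; the new nodes should inherit the original stack_trace.
gm2 = make_fx(reinterpret)(torch.randn(3))
for node in gm2.graph.nodes:
    if node.op == "call_function":
        print(node.name, node.meta.get("stack_trace"))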