[export] handle constant aliasing for export (pytorch#125509)
Summary: Currently export will [error out](https://github.com/pytorch/pytorch/blob/2b5ae2611e22d992565f202df9267fe66469efaa/torch/export/_trace.py#L477) if a constant is aliased. This PR adds support for aliased constants by modifying ConstantAttrMap to map each constant to a list of FQNs instead of a single FQN, and by populating the ExportedProgram constants dict with one entry per FQN, all pointing to the same underlying constant.
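
For illustration, a minimal sketch of the aliasing pattern this enables (module and attribute names are illustrative and mirror the new test case; the expected output assumes this PR is applied):

```python
import torch
from torch.export import export

class Child(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A plain tensor attribute (not a parameter or buffer) gets lifted as a constant.
        self.foo = torch.ones(3, 3)

    def forward(self, x):
        return x + self.foo

class Parent(torch.nn.Module):
    def __init__(self, child):
        super().__init__()
        self.child = child
        self.foo = child.foo  # alias: the same tensor reachable under a second FQN

    def forward(self, x):
        return x + self.foo + self.child(x)

child = Child()
ep = export(Parent(child), (torch.ones(3, 3),), strict=False)
# Both FQNs appear in the constants table, backed by a single lifted graph input.
print(sorted(ep.constants))  # expected: ['child.foo', 'foo']
```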

Test Plan: added a test case in test_export.py

Differential Revision: D56955654

Pull Request resolved: pytorch#125509
Approved by: https://github.com/angelayi, https://github.com/ydwu4
pianpwk authored and pytorchmergebot committed May 10, 2024
1 parent fd816bf commit c9a258e
Showing 7 changed files with 184 additions and 48 deletions.
34 changes: 34 additions & 0 deletions test/export/test_export.py
@@ -4844,6 +4844,40 @@ def forward(self, w, x, y, z):
_disable_forced_specializations=True,
)

def test_constant_aliasing(self):
class M1(torch.nn.Module):
def __init__(self, m2, foo):
super().__init__()
self.m2 = m2
self.foo = foo

def forward(self, x):
return x + self.foo + self.m2(x)

class M2(torch.nn.Module):
def __init__(self):
super().__init__()
self.foo = torch.ones(3, 3)

def forward(self, x):
return x + self.foo

m2 = M2()
m1 = M1(m2, m2.foo)
inps = (torch.ones(3, 3),)
ep = torch.export.export(m1, inps, strict=False)
# check both constants appear in list
self.assertEqual(sorted(list(ep.constants)), ["foo", "m2.foo"])
# check only one input spec exists
num_constant_inputs = [
spec.kind == InputKind.CONSTANT_TENSOR
for spec in ep.graph_signature.input_specs
].count(True)
self.assertEqual(num_constant_inputs, 1)
# unflatten
unflattened = unflatten(ep)
self.assertTrue(torch.allclose(m1(*inps), unflattened(*inps)))


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestOneOffModelExportResult(TestCase):
14 changes: 8 additions & 6 deletions test/export/test_lift_unlift.py
@@ -396,17 +396,19 @@ def test_dict_api(self):
constant_attr_map = ConstantAttrMap()
const_obj = torch.classes._TorchScriptTesting._Foo(10, 20)
const_tensor = torch.ones(2, 3)
constant_attr_map[const_obj] = "foo.bar"
constant_attr_map[const_tensor] = "foo.bar.baz"
constant_attr_map.add(const_obj, "foo.bar")
constant_attr_map.add(const_tensor, "foo.bar.baz")
self.assertEqual(len(constant_attr_map), 2)
self.assertEqual(list(constant_attr_map), [const_obj, const_tensor])
self.assertEqual(list(constant_attr_map.keys()), [const_obj, const_tensor])
self.assertEqual(list(constant_attr_map.values()), ["foo.bar", "foo.bar.baz"])
self.assertEqual(constant_attr_map[const_obj], "foo.bar")
self.assertEqual(constant_attr_map[const_tensor], "foo.bar.baz")
self.assertEqual(
list(constant_attr_map.values()), [["foo.bar"], ["foo.bar.baz"]]
)
self.assertEqual(constant_attr_map[const_obj], ["foo.bar"])
self.assertEqual(constant_attr_map[const_tensor], ["foo.bar.baz"])
self.assertTrue(const_obj in constant_attr_map)
with self.assertRaises(TypeError):
constant_attr_map[1] = "foo.bar"
constant_attr_map.add(1, "foo.bar")

del constant_attr_map[const_obj]
self.assertEqual(len(constant_attr_map), 1)
60 changes: 60 additions & 0 deletions test/export/test_torchbind.py
@@ -1,6 +1,8 @@
# Owner(s): ["oncall: export"]


import unittest

import torch
import torch.utils._pytree as pytree
from torch._dynamo.testing import EagerAndRecordGraphs
@@ -361,6 +363,64 @@ def forward(self, token, x, cc):
return (getitem, add)""", # noqa: B950
)

@parametrize("pre_dispatch", [True, False])
@parametrize("fakify_script_obj", [True, False])
def test_torchbind_alias(self, pre_dispatch, fakify_script_obj):
class F2(torch.nn.Module):
def __init__(self, foo):
super().__init__()
self.foo = foo

def forward(self, x):
return x + torch.ops._TorchScriptTesting.takes_foo(self.foo, x)

class F1(torch.nn.Module):
def __init__(self):
super().__init__()
self.alpha = torch.classes._TorchScriptTesting._Foo(10, 20)
if not fakify_script_obj:
qual_name = self.alpha._type().qualified_name()
if torch._library.fake_class_registry.has_fake_class(qual_name):
torch._library.fake_class_registry.deregister_fake_class(
"_TorchScriptTesting::_Foo"
)
self.beta = self.alpha
self.gamma = self.alpha
self.foo = F2(self.gamma)

def forward(self, x):
return (
x
+ torch.ops._TorchScriptTesting.takes_foo(self.gamma, x)
+ self.foo(x)
)

self._test_export_same_as_eager(
F1(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch
)

# TODO(pianpwk): look into this
@unittest.expectedFailure
@parametrize("pre_dispatch", [True, False])
@parametrize("fakify_script_obj", [True, False])
def test_torchbind_input_and_alias(self, pre_dispatch, fakify_script_obj):
# alias as model attribute
class F3(torch.nn.Module):
def forward(self, x, foo):
self.foo = foo
return x + self.foo.add_tensor(x)

foo = torch.classes._TorchScriptTesting._Foo(10, 20)
if not fakify_script_obj:
qual_name = foo._type().qualified_name() # type: ignore[attr-defined]
if torch._library.fake_class_registry.has_fake_class(qual_name):
torch._library.fake_class_registry.deregister_fake_class(
"_TorchScriptTesting::_Foo"
)
self._test_export_same_as_eager(
F3(), (torch.ones(2, 3), foo), strict=False, pre_dispatch=pre_dispatch
)

@parametrize("pre_dispatch", [True, False])
def test_unlift_custom_obj(self, pre_dispatch):
class MyModule(torch.nn.Module):
23 changes: 10 additions & 13 deletions torch/_export/non_strict_utils.py
@@ -369,12 +369,7 @@ def inner(m: torch.nn.Module, prefix_atoms: List[str], constants):
continue

fqn = ".".join(prefix_atoms + [k])
if v in constants:
raise ValueError(
f"Duplicate reference to constant attribute found: '{constants[v]}' and '{fqn}'."
)

constants[v] = fqn
constants.add(v, fqn)
for k, v in m.named_children():
inner(v, prefix_atoms + [k], constants)

@@ -431,16 +426,18 @@ def _leaf_mod_and_attr(
return cur_mod, last_attr

try:
for obj, fqn in constant_attrs.items():
for obj, fqns in constant_attrs.items():
if isinstance(obj, torch.ScriptObject):
cur_mod, attr = _leaf_mod_and_attr(mod, fqn)
assert obj is getattr(cur_mod, attr)
fake_script_obj = _maybe_fakify_obj(obj)
setattr(cur_mod, attr, fake_script_obj)
fake_constant_attrs[fake_script_obj] = fqn
patched_attr[fqn] = obj
for fqn in fqns:
cur_mod, attr = _leaf_mod_and_attr(mod, fqn)
assert obj is getattr(cur_mod, attr)
setattr(cur_mod, attr, fake_script_obj)
fake_constant_attrs.add(fake_script_obj, fqn)
patched_attr[fqn] = obj
else:
fake_constant_attrs[obj] = fqn
for fqn in fqns:
fake_constant_attrs.add(obj, fqn)

fake_args, fake_kwargs = pytree.tree_map_only(
torch.ScriptObject, _maybe_fakify_obj, (args, kwargs)
45 changes: 35 additions & 10 deletions torch/_export/passes/lift_constants_pass.py
@@ -1,5 +1,5 @@
import collections
from typing import Any, Dict, Union
from typing import Any, Dict, List, Union

import torch
from torch._export.verifier import SpecViolationError
@@ -26,7 +26,9 @@ class ConstantAttrMap(collections.abc.MutableMapping):

def __init__(self):
# Underlying dict that we use to implement this mapping.
self._constant_attrs: Dict[Union[int, torch.Tensor, FakeScriptObject], Any] = {}
self._constant_attrs: Dict[
Union[int, torch.Tensor, FakeScriptObject], List[Any]
] = {}
# Map from the hash(ScriptObject) to the ScriptObject itself. Used for
# APIs like `__iter__` that should look like they're returning the
# original ScriptObjects.
@@ -39,14 +41,25 @@ def __getitem__(
assert isinstance(real_key, (int, torch.Tensor, FakeScriptObject))
return self._constant_attrs[real_key]

def __setitem__(
def __setitem__(self, key: Union[torch.Tensor, torch.ScriptObject], value):
# This should not be called directly; use add() instead so that constant aliasing (multiple FQNs per constant) is handled.
raise NotImplementedError(
"""Directly setting values for ConstantAttrMap is not supported, please use add(key, value) instead.
The same key can be mapped to multiple values, for handling constant aliasing."""
)

def add(
self, key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject], value: Any
) -> None:
if isinstance(key, torch.ScriptObject):
self._constant_attrs[hash(key)] = value
if hash(key) not in self._constant_attrs:
self._constant_attrs[hash(key)] = []
self._constant_attrs[hash(key)].append(value)
self._script_object_map[hash(key)] = key
elif isinstance(key, (torch.Tensor, FakeScriptObject)):
self._constant_attrs[key] = value
if key not in self._constant_attrs:
self._constant_attrs[key] = []
self._constant_attrs[key].append(value)
else:
raise TypeError(
f"Expected key to be a tensor or ScriptObject, got {type(key)}"
@@ -83,6 +96,14 @@ def get_constant_fqn(node: torch.fx.Node, constant_name: str) -> str:
return constant_name


def _get_first_fqn(
const_attrs: ConstantAttrMap,
key: Union[torch.Tensor, torch.ScriptObject, FakeScriptObject],
) -> Any:
fqns = const_attrs.get(key)
return fqns[0] if fqns else None


def lift_constants_pass(
gm: torch.fx.GraphModule,
graph_signature: ExportGraphSignature,
@@ -136,7 +157,7 @@ def lift_constants_pass(
# We already lifted this constant elsewhere. Just rewrite uses
# of this get_attr to point to the already-existing placeholder
# node.
const_placeholder_node = lifted_objs[constant_val]
const_placeholder_node = _get_first_fqn(lifted_objs, constant_val)
node.replace_all_uses_with(const_placeholder_node)
gm.graph.erase_node(node)
continue
@@ -152,7 +173,7 @@ def lift_constants_pass(
# some name and attach it to the module in which it was used.
if isinstance(constant_val, (torch.ScriptObject, FakeScriptObject)):
constant_kind = InputKind.CUSTOM_OBJ
constant_fqn = constant_attrs.get(constant_val)
constant_fqn = _get_first_fqn(constant_attrs, constant_val)
if constant_fqn is not None:
constant_name = constant_fqn.replace(".", "_")
else:
@@ -161,7 +182,7 @@ def lift_constants_pass(
num_custom_obj += 1
elif isinstance(constant_val, torch.Tensor):
constant_kind = InputKind.CONSTANT_TENSOR
constant_fqn = constant_attrs.get(constant_val)
constant_fqn = _get_first_fqn(constant_attrs, constant_val)
if constant_fqn is not None:
constant_name = constant_fqn.replace(".", "_")
else:
@@ -222,7 +243,7 @@ def lift_constants_pass(
f"tried to lift unsupported type {type(constant_val)} from node {node.format_node()}"
)

lifted_objs[constant_val] = const_placeholder_node
lifted_objs.add(constant_val, const_placeholder_node)
node.replace_all_uses_with(const_placeholder_node)
gm.graph.erase_node(node)

@@ -235,7 +256,11 @@ def lift_constants_pass(
target=constant_fqn,
),
)
all_constants[constant_fqn] = constant_val
if constant_val in constant_attrs:
for fqn in constant_attrs[constant_val]:
all_constants[fqn] = constant_val
else:
all_constants[constant_fqn] = constant_val
first_user_input_loc += 1

return all_constants
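
A minimal sketch of the new ConstantAttrMap semantics introduced above (the import path matches where the class is defined; the exact behavior assumes this PR is applied):

```python
import torch
from torch._export.passes.lift_constants_pass import ConstantAttrMap

attrs = ConstantAttrMap()
t = torch.ones(2, 3)
attrs.add(t, "foo")
attrs.add(t, "bar.baz")  # the same constant registered under a second FQN

# Lookups return the full list of FQNs recorded for a constant.
assert attrs[t] == ["foo", "bar.baz"]

# Direct item assignment is intentionally unsupported; add() must be used instead.
try:
    attrs[t] = "qux"
except NotImplementedError:
    pass
```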
8 changes: 5 additions & 3 deletions torch/export/_trace.py
@@ -270,7 +270,7 @@ def _remap_constants(
constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]],
) -> None:
"""Rewrite the graph signature and constants table to use the FQN from the original module."""
remap_table: Dict[str, str] = {}
remap_table: Dict[str, List[str]] = {}
for name, value in constants.items():
if value in orig_constant_attrs:
remap_table[name] = orig_constant_attrs[value]
@@ -282,11 +282,13 @@
):
orig_target = spec.target
assert orig_target is not None
spec.target = remap_table.get(orig_target, orig_target)
targets = remap_table.get(orig_target, [orig_target])
spec.target = targets[0]

constant = constants[orig_target]
del constants[orig_target]
constants[spec.target] = constant
for target in targets:
constants[target] = constant


def _rename_constants_nodes(
Expand Down