From 6d7b7615b1f9d9b7853f691260a399d61c4d7ee7 Mon Sep 17 00:00:00 2001 From: migeedz Date: Wed, 3 Aug 2022 11:41:20 -0700 Subject: [PATCH] store parameter values as static shapes during constraint generation (#82742) Parameter values are static tensor shapes. With this assumption, we will store those values as static tensor shapes when generating constraints for placeholders. Pull Request resolved: https://github.com/pytorch/pytorch/pull/82742 Approved by: https://github.com/jansel --- .../migrate_gradual_types/constraint_generator.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py index b2be15261fe21..547e3381afd72 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py @@ -1123,7 +1123,16 @@ def generate_constraints_node(self, n: Node, counter): if n.op == 'placeholder': x, counter = gen_tvar(counter) self.symbol_dict[n] = x - n.type = Dyn if not (isinstance(n.type, TensorType) or n.type == Dyn) else n.type + + if n.type != Dyn and (not isinstance(n.type, TensorType)): + + if n.type == torch.nn.parameter.Parameter: + # since we have a parameter, the shape must be static + assert 'example_value' in n.meta + n.type = TensorType(n.meta['example_value'].size()) + else: + n.type = Dyn + c1 = BinConstraintT(n.type, x, op_precision) c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq) return [c1, c2], counter