Embedding rule for TorchDynamo (pytorch#82163)
- Embedding rule used in XGLM tracing with Dynamo
- Modify constraint generator to replace unknown annotations with Dyn
- Tests
Pull Request resolved: pytorch#82163
Approved by: https://github.com/jansel
migeed-z authored and pytorchmergebot committed Aug 3, 2022
1 parent 50a1124 commit f6c2a75
Showing 2 changed files with 68 additions and 1 deletion.
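For orientation, the shape behavior the new embedding rule encodes can be checked directly in eager mode: torch.nn.functional.embedding(indices, weight) keeps the shape of indices and appends weight's second dimension (the embedding dimension) as a new last dimension. A minimal sketch, with an arbitrary 10-row vocabulary standing in for the 256008 rows used in the new test:

    import torch

    indices = torch.ones([2, 4], dtype=torch.long)  # same index shape as the new test
    weight = torch.rand(10, 1024)                   # 10-word vocabulary, embedding_dim 1024
    out = torch.nn.functional.embedding(indices, weight)
    assert out.size() == torch.Size([2, 4, 1024])   # indices shape + embedding_dim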
50 changes: 50 additions & 0 deletions test/fx/test_z3_gradual_types.py
@@ -29,6 +29,34 @@
HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

class TorchDynamoUseCases(unittest.TestCase):


def test_reshape(self):
"""
In this example, we prove that some nodes must
always have a fixed shape regardless of the input.
"""

class BasicBlock(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: Dyn):
y = x.view(100)
tmp = y.size()[0]
return tmp

symbolic_traced: torch.fx.GraphModule = symbolic_trace(BasicBlock())
transformed = transform_all_constraints(symbolic_traced, counter=0)

s = z3.Solver()
s.add(transformed)
self.assertEqual(s.check(), z3.sat)
dim = z3.Int(4)
self.assertEqual(s.model()[dim], 100)


class HFOperations(unittest.TestCase):

@@ -719,6 +747,28 @@ def forward(self, x: TensorType([2, 4])):
self.assertEqual(s.check(), z3.sat)


def test_embedding_2(self):
class BasicBlock(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x: TensorType([2, 4]), y: TensorType([Dyn, 1024])):
return torch.nn.functional.embedding(x, y)

B = BasicBlock().forward(torch.ones([2, 4], dtype=torch.long), torch.rand(256008, 1024)).size()
ast_rewriter = RewritingTracer()
graph = ast_rewriter.trace(BasicBlock())
traced = GraphModule(ast_rewriter.root, graph, "gm")
transformed = transform_all_constraints(traced, counter=0)
s = z3.Solver()
s.add(transformed)
self.assertEqual(s.check(), z3.sat)
embedding_result = z3.Const(5, tensor_type)

assert s.model()[embedding_result].arg(0).arg(1) == B[0]
assert s.model()[embedding_result].arg(1).arg(1) == B[1]
assert s.model()[embedding_result].arg(2).arg(1) == B[2]

def test_size_two_args(self):
class BasicBlock(torch.nn.Module):
def __init__(self):
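Both new tests follow the same solver pattern: transform_all_constraints produces a z3 formula, the solver checks that it is satisfiable, and concrete shapes are then read out of the model through the numbered variables (z3.Int(4), z3.Const(5, tensor_type)) that the constraint generator allocates from its counter. A standalone sketch of that pattern, independent of the PyTorch constraint encoding:

    import z3

    s = z3.Solver()
    d = z3.Int(4)                # a dimension variable, named by its counter value
    s.add(d == 100)              # stands in for the constraints view(100) generates
    assert s.check() == z3.sat   # the constraints admit a solution
    assert s.model()[d].as_long() == 100  # read the concrete dimension from the model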
19 changes: 18 additions & 1 deletion torch/fx/experimental/migrate_gradual_types/constraint_generator.py
@@ -256,14 +256,30 @@ def masked_fill_inference_rule(n: Node, symbols, constraints, counter):
raise NotImplementedError('Not yet implemented')


@register_inference_rule(torch.nn.functional.embedding)
def embedding_inference_rule_functional(n: Node, symbols, constraints, counter):
assert isinstance(n.args[0], Node)

embedding_dim_weights = symbols[n.args[1]]

# Treat the weight as having a static rank-2 shape, so no matching is needed.
weight_dims, counter = gen_tensor_dims(2, counter)
equality_constraint = BinConstraintT(embedding_dim_weights, TensorType(weight_dims), op_eq)
embedding_dim = weight_dims[1]
constraints, counter = gen_embedding_rules(n, symbols, embedding_dim, counter)
return [equality_constraint] + constraints, counter


@register_inference_rule(torch.nn.modules.sparse.Embedding)
def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter):
"""
The output shape is the input (indices) shape with embedding_dim appended as a new last dimension
"""
assert isinstance(n.args[0], Node)
return gen_embedding_rules(n, symbols, module_instance.embedding_dim, counter)


def gen_embedding_rules(n: Node, symbols, embedding_dim, counter):

embedding_output, counter = gen_tvar(counter)
symbols[n] = embedding_output
@@ -1077,6 +1093,7 @@ def generate_constraints_node(self, n: Node, counter):
if n.op == 'placeholder':
x, counter = gen_tvar(counter)
self.symbol_dict[n] = x
n.type = Dyn if not (isinstance(n.type, TensorType) or n.type == Dyn) else n.type
c1 = BinConstraintT(n.type, x, op_precision)
c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq)
return [c1, c2], counter
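The one-line placeholder change above is what the commit message calls replacing unknown annotations with Dyn: any placeholder whose annotation is neither a TensorType nor Dyn is coerced to Dyn before the precision constraint c1 is built. A sketch of the case it handles, assuming only public torch.fx behavior (an unannotated argument leaves the placeholder node's type as None):

    import torch
    from torch.fx import symbolic_trace

    class M(torch.nn.Module):
        def forward(self, x):   # no annotation, unlike the tests' x: Dyn
            return x.view(100)

    gm = symbolic_trace(M())
    placeholder = next(n for n in gm.graph.nodes if n.op == 'placeholder')
    print(placeholder.type)     # None -- the generator now treats this as Dyn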
