Skip to content

Commit

Permalink
linear constraints (pytorch#82614)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#82614
Approved by: https://github.com/jansel
  • Loading branch information
migeed-z authored and pytorchmergebot committed Aug 3, 2022
1 parent b858abd commit 1f29a5f
Showing 1 changed file with 18 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -895,13 +895,27 @@ def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter)
assert isinstance(input, TVar)
return [BinConstraintT(input, output, op_eq)], counter


@register_inference_rule(torch.nn.Linear)
def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter):
"""
Input and output sizes should be the same except for the last dimension
If the input is Dyn, then so should the output
"""
assert isinstance(n.args[0], Node)
return linear_constraints(n, module_instance.in_features, module_instance.out_features, symbols, counter)


@register_inference_rule(torch._C._nn.linear) # type: ignore[attr-defined]
def torch_linear_inference_rule(n: Node, symbols, constraints, counter):
assert isinstance(n.args[0], Node)
weight_dims, counter = gen_tensor_dims(2, counter)
equality_constraint = BinConstraintT(n.args[1], TensorType(weight_dims), op_eq)
constraints, counter = linear_constraints(n, weight_dims[0], weight_dims[1], symbols, counter)
return [equality_constraint] + constraints, counter


def linear_constraints(n: Node, in_features, out_features, symbols, counter):
linear_output, counter = gen_tvar(counter)
symbols[n] = linear_output
linear_input = symbols[n.args[0]]
Expand All @@ -920,11 +934,9 @@ def linear_inference_rule(n: Node, module_instance, symbols, constraints, counte

c_tensor_i = Conj([BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq),
BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq)] +
add_linear_constraints(new_dims_rhs_1, new_dims_rhs_2, module_instance) +
add_linear_constraints(new_dims_rhs_1, new_dims_rhs_2, in_features, out_features) +
nat_constraints)
c2.append(c_tensor_i)


return [Disj([c1, Disj(c2)])], counter

def add_layer_norm_constraints(input_dim, normalized_dim):
Expand All @@ -948,13 +960,13 @@ def add_layer_norm_constraints(input_dim, normalized_dim):
return constraints


def add_linear_constraints(dims1, dims2, module_instance):
def add_linear_constraints(dims1, dims2, in_features, out_features):
assert len(dims1) == len(dims2)
constraints = []
for i in range(len(dims1)):
if i == len(dims1) - 1:
constraints.append(BinConstraintD(dims1[i], module_instance.in_features, op_consistency))
constraints.append(BinConstraintD(dims2[i], module_instance.out_features, op_eq))
constraints.append(BinConstraintD(dims1[i], in_features, op_consistency))
constraints.append(BinConstraintD(dims2[i], out_features, op_eq))
else:
constraints.append(BinConstraintD(dims1[i], dims2[i], op_eq))

Expand Down

0 comments on commit 1f29a5f

Please sign in to comment.