Skip to content

Commit

Permalink
[NNAPI] Update support for Linear (pytorch#54695)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#54695

Previously, torch.nn.Linear was calling aten::addmm internally.  Now
it's calling aten::linear, so add support for that.

Test Plan: Unit test

Reviewed By: axitkhurana

Differential Revision: D27536795

Pulled By: dreiss

fbshipit-source-id: 42c8d2a80b20ac12ed9bba599c5e0e874256bb13
  • Loading branch information
dreiss authored and facebook-github-bot committed Apr 6, 2021
1 parent 8d960f7 commit 8fcf9ca
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion torch/backends/_nnapi/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,8 @@ def serialize_ints(ints):
self.add_prelu_op(node),
"aten::addmm": lambda self, node:
self.add_addmm(node),
"aten::linear": lambda self, node:
self.add_linear(node),
"aten::_convolution": lambda self, node:
self.add_conv_underscore(node),
"aten::conv2d": lambda self, node:
Expand Down Expand Up @@ -1017,6 +1019,16 @@ def add_addmm(self, node):
if scale_value != 1:
raise Exception("NNAPI Fully-Connected does not support alpha and beta.")

self.add_addmm_or_linear(node, True, jit_input, jit_weight, jit_bias)

def add_linear(self, node):
    """Serialize an ``aten::linear`` JIT node.

    ``aten::linear`` carries three inputs (input, weight, bias) and
    produces one output.  The weight already arrives in the layout the
    shared lowering expects, so it is forwarded with
    ``transpose_weight=False`` (unlike ``aten::addmm``, which needs a
    transpose).
    """
    assert node.inputsSize() == 3
    assert node.outputsSize() == 1
    jit_args = list(node.inputs())
    # Delegate to the common addmm/linear lowering without transposing.
    self.add_addmm_or_linear(node, False, *jit_args)

def add_addmm_or_linear(self, node, transpose_weight, jit_input, jit_weight, jit_bias):
input_id, input_oper = self.get_tensor_operand_by_jitval(jit_input)
bias_id, bias_oper = self.get_tensor_operand_for_weight(jit_bias)

Expand All @@ -1026,7 +1038,10 @@ def add_addmm(self, node):
# TODO: Transform at load time to share weights with CPU model.
_, weight_tensor = self.get_constant_value(jit_weight, "TensorType")
assert len(weight_tensor.shape) == 2
nnapi_weight_tensor = weight_tensor.t().contiguous()
if transpose_weight:
nnapi_weight_tensor = weight_tensor.t().contiguous()
else:
nnapi_weight_tensor = weight_tensor.contiguous()
weight_id = self.add_tensor_operand_for_weight(nnapi_weight_tensor)
weight_oper = self.operands[weight_id]

Expand Down

0 comments on commit 8fcf9ca

Please sign in to comment.