[NNAPI] Add support for unsqueeze, cat, and mean (pytorch#48811)
Summary: Pull Request resolved: pytorch#48811

Test Plan: Unit tests.

Reviewed By: axitkhurana

Differential Revision: D25317936

Pulled By: dreiss

fbshipit-source-id: 9b3a0a75b8157ae35ac13d52293a67800bad0ded
dreiss authored and facebook-github-bot committed Apr 6, 2021
1 parent 3802edd commit b057d27
Showing 2 changed files with 213 additions and 10 deletions.
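For orientation, here is a minimal sketch of how the three newly supported ops flow through the NNAPI converter. It assumes the prototype entry point torch.backends._nnapi.prepare.convert_model_to_nnapi, the same path the unit tests below exercise via self.check; treat it as an illustration, not a supported recipe.

import torch
from torch.backends._nnapi.prepare import convert_model_to_nnapi

class UnsqueezeCatMean(torch.nn.Module):
    def forward(self, t1, t2):
        b1 = t1.unsqueeze(0)             # (2, 3, 3) -> (1, 2, 3, 3)
        b2 = t2.unsqueeze(0)             # (4, 3, 3) -> (1, 4, 3, 3)
        joined = torch.cat([b1, b2], 1)  # concatenate on channels -> (1, 6, 3, 3)
        return torch.mean(joined, dim=[2, 3], keepdim=True)  # -> (1, 6, 1, 1)

t1, t2 = torch.randn(2, 3, 3), torch.randn(4, 3, 3)
with torch.no_grad():
    traced = torch.jit.trace(UnsqueezeCatMean(), (t1, t2))
# Serialization runs on the host; executing the returned module requires
# a device with an NNAPI runtime.
nnapi_model = convert_model_to_nnapi(traced, [t1, t2])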
62 changes: 62 additions & 0 deletions test/test_nnapi.py
@@ -94,6 +94,21 @@ def test_dequantize(self):
            torch.nn.quantized.DeQuantize(),
            nhwc(qpt([[[[1.0]], [[2.0]]]], 0.25, 2)))

    def test_unsqueeze(self):
        class UnsqueezeModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, arg):
                return arg.unsqueeze(self.dim)

        self.check(UnsqueezeModule(-2), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(-1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(0), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(1), torch.randn(4, 2, 2))
        self.check(UnsqueezeModule(2), torch.randn(4, 2, 2))

    def test_reshape(self):
        class ReshapeModule(torch.nn.Module):
            def __init__(self, shape):
@@ -116,6 +131,36 @@ def forward(self, arg):
            ReshapeModule((2, 4)),
            nhwc(torch.randn(4, 2, 1, 1)))

    def test_cat(self):
        class CatModule(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, t1, t2):
                return torch.cat([t1, t2], self.dim)

        self.check(
            CatModule(0),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(2, 2, 3, 3),
            ])

        self.check(
            CatModule(1),
            [
                torch.randn(1, 2, 3, 3),
                torch.randn(1, 4, 3, 3),
            ])

        self.check(
            CatModule(1),
            [
                nhwc(torch.randn(1, 2, 3, 3)),
                nhwc(torch.randn(1, 4, 3, 3)),
            ])

    def test_pointwise_unary(self):
        for op in ["relu", "sigmoid"]:
            with self.subTest(op):
@@ -170,6 +215,23 @@ def test_hardtanh(self):
        with self.assertRaisesRegex(Exception, "hardtanh with args"):
            self.check(torch.nn.Hardtanh(0.0, 5.0), inp)

    def test_mean(self):
        class MeanModule(torch.nn.Module):
            def __init__(self, dim, keep=False):
                super().__init__()
                self.dim = dim
                self.keep = keep

            def forward(self, t):
                return torch.mean(t, dim=self.dim, keepdim=self.keep)

        self.check(MeanModule(0), torch.randn(2, 3))
        self.check(MeanModule(1), torch.randn(2, 3))
        self.check(MeanModule([2, 3]), torch.randn(2, 3, 6, 6))
        self.check(MeanModule([2, 3]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2]), nhwc(torch.randn(2, 3, 6, 6)))
        self.check(MeanModule([-1, -2], keep=True), nhwc(torch.randn(2, 3, 6, 6)))

    def test_max_pool2d(self):
        for (name, inp) in self.float_and_quant_and_nhwc(torch.randn(2, 3, 12, 16), 0.3, 128):
            with self.subTest(name):
161 changes: 151 additions & 10 deletions torch/backends/_nnapi/serializer.py
@@ -169,6 +169,12 @@ def tensor_size(op_type, dims):
    return size


def change_element(tup, index, value):
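    """Return a copy of tuple `tup` with the element at `index` replaced by `value`."""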
    ls = list(tup)
    ls[index] = value
    return tuple(ls)


class ConvPoolArgs2d(NamedTuple):
"""Configuration arguments for a convolution."""
kernel_h: int
@@ -302,7 +308,7 @@ def __init__(self, config):

        self.modules = {}
        self.constants = {}
-        self.tensor_tuples = {}
+        self.tensor_sequences = {}
        self.jitval_operand_map = {}
        self.cached_immediates = {}
        self.used_weights = []
@@ -428,9 +434,9 @@ def add_operation(self, opcode, inputs, outputs):
        self.operations.append((opcode, len(inputs), len(outputs)))
        self.operation_args.extend(inputs + outputs)

-    def add_tensor_tuple(self, jitval, values):
-        assert jitval not in self.tensor_tuples
-        self.tensor_tuples[jitval] = values
+    def add_tensor_sequence(self, jitval, values):
+        assert jitval not in self.tensor_sequences
+        self.tensor_sequences[jitval] = values

    def add_constant_value(self, jitval, ctype, value):
        assert jitval not in self.constants
@@ -521,7 +527,7 @@ def serialize_model(self, model, inputs):
            self.outputs.append(op_id)
            out_dim_orders.append(self.operands[op_id].dim_order.value)
        elif retn_input.type().kind() == "TupleType":
-            for v in self.tensor_tuples[retn_input]:
+            for v in self.tensor_sequences[retn_input]:
                op_id = self.jitval_operand_map[v]
                self.outputs.append(op_id)
                out_dim_orders.append(self.operands[op_id].dim_order.value)
@@ -583,10 +589,16 @@ def serialize_ints(ints):
            self.add_list_construct(node),
        "prim::TupleConstruct": lambda self, node:
            self.add_tuple_construct(node),
        "aten::unsqueeze": lambda self, node:
            self.add_unsqueeze(node),
        "aten::reshape": lambda self, node:
            self.add_reshape(node),
        "aten::size": lambda self, node:
            self.add_size(node),
        "aten::cat": lambda self, node:
            self.add_cat(node),
        "aten::mean": lambda self, node:
            self.add_mean(node),
        "aten::quantize_per_tensor": lambda self, node:
            self.add_quantize(node),
        "aten::dequantize": lambda self, node:
@@ -660,19 +672,60 @@ def add_list_construct(self, node):
        assert node.outputsSize() == 1
        output = node.outputsAt(0)
        ctype = output.type()
-        values = []
+        const_vals = []
+        tensors = []
        for inp in node.inputs():
-            _, val = self.get_constant_value(inp)
-            values.append(val)
-        self.add_constant_value(output, ctype, values)
+            if const_vals is not None and inp in self.constants:
+                _, val = self.get_constant_value(inp)
+                const_vals.append(val)
+            else:
+                const_vals = None
+            if tensors is not None and inp.type().kind() == "TensorType":
+                tensors.append(inp)
+            else:
+                tensors = None
+        if const_vals is not None:
+            # NOTE: Now that TorchScript supports list constants,
+            # this code path might not be used anymore.
+            self.add_constant_value(output, ctype, const_vals)
+        if tensors is not None:
+            self.add_tensor_sequence(output, tensors)
+        if const_vals is None and tensors is None:
+            raise Exception(
+                "Unable to handle ListConstruct node."
+                " Neither all constants nor all tensors. %r" % node)

    def add_tuple_construct(self, node):
        assert node.outputsSize() == 1
        output = node.outputsAt(0)
        values = []
        for inp in node.inputs():
            values.append(inp)
-        self.add_tensor_tuple(output, values)
+        self.add_tensor_sequence(output, values)

    def add_unsqueeze(self, node):
        assert node.inputsSize() == 2
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))

        _, dim = self.get_constant_value(node.inputsAt(1), "IntType")
        assert in_oper.dim_order == DimOrder.PRESUMED_CONTIGUOUS

        real_dim = dim if dim >= 0 else dim + len(in_oper.shape) + 1
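        # A negative dim is normalized against the output rank (input rank + 1),
        # matching unsqueeze semantics: e.g. shape (4, 2, 2) with dim=-2 gives
        # real_dim 2 and an output shape of (4, 2, 1, 2).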
        out_shape_list = list(in_oper.shape)
        out_shape_list.insert(real_dim, 1)
        out_shape = tuple(out_shape_list)
        out_oper = in_oper._replace(shape=out_shape)

        inputs = [None] * 2
        inputs[0] = in_id
        inputs[1] = self.add_immediate_int_scalar(dim)

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)

        self.add_operation(NNAPI_OperationCode.EXPAND_DIMS, inputs, outputs)

    def add_reshape(self, node):
        assert node.inputsSize() == 2
@@ -712,6 +765,94 @@ def add_size(self, node):
        output = node.outputsAt(0)
        self.add_constant_value(output, output.type(), res)

    def add_cat(self, node):
        assert node.inputsSize() == 2
        assert node.outputsSize() == 1

        tensors = self.tensor_sequences[node.inputsAt(0)]
        _, dim = self.get_constant_value(node.inputsAt(1), "IntType")

        assert len(tensors) > 0
        in_ids = []
        out_oper = None
        out_dim_size = 0
        for inp in tensors:
            in_id, in_oper = self.get_tensor_operand_by_jitval(inp)
            if out_oper is None:
                out_shape = change_element(in_oper.shape, dim, -1)
                out_oper = in_oper._replace(shape=out_shape)
            assert in_oper.op_type == out_oper.op_type
            assert in_oper.dim_order == out_oper.dim_order
            assert change_element(in_oper.shape, dim, -1) == change_element(out_oper.shape, dim, -1)
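            # Setting the cat dim to -1 on both sides lets the comparison
            # ignore the concatenated dim while requiring all others to match.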
            # TODO: Possibly check scale and zero point.
            in_ids.append(in_id)
            # TODO: Possibly support variable-sized inputs.
            out_dim_size += in_oper.shape[dim]

        out_oper = out_oper._replace(shape=change_element(out_oper.shape, dim, out_dim_size))

        if in_oper.dim_order == DimOrder.CHANNELS_LAST:
            assert len(out_oper.shape) == 4
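            # Operands are NHWC on-device, so translate the PyTorch NCHW axis
            # to its NHWC position (e.g. dim=1, channels, becomes axis 3).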
            nnapi_dim = [0, 3, 1, 2][dim]
        else:
            nnapi_dim = dim

        inputs = in_ids + [self.add_immediate_int_scalar(nnapi_dim)]

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)

        self.add_operation(NNAPI_OperationCode.CONCATENATION, inputs, outputs)

    def add_mean(self, node):
        assert node.inputsSize() == 4
        assert node.outputsSize() == 1

        in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
        dim_ctype, dim = self.get_constant_value(node.inputsAt(1))
        assert dim_ctype.kind() == "ListType"
        assert dim_ctype.getElementType().kind() == "IntType"
        _, keep_dim = self.get_constant_value(node.inputsAt(2), "BoolType")
        # Expect None for dtype
        self.get_constant_value(node.inputsAt(3), "NoneType")

        if in_oper.dim_order == DimOrder.CHANNELS_LAST:
            assert len(in_oper.shape) == 4
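            # Same NCHW -> NHWC axis translation as in add_cat.  A negative d
            # indexes the permutation list from the end, which also maps
            # correctly (e.g. d=-1, the W axis, becomes 2).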
            nnapi_dim = [[0, 3, 1, 2][d] for d in dim]
        else:
            nnapi_dim = dim

        collapsed_dims = set()
        for d in dim:
            if d < 0:
                d += len(in_oper.shape)
            collapsed_dims.add(d)

        if in_oper.dim_order == DimOrder.CHANNELS_LAST and not keep_dim:
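            # Averaging away both H and W (dims 2 and 3) of a channels-last
            # tensor leaves an N x C result, which is laid out identically to
            # a contiguous tensor.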
            assert collapsed_dims.issuperset({2, 3})
            out_dim_order = DimOrder.PRESUMED_CONTIGUOUS
        else:
            out_dim_order = in_oper.dim_order

        out_shape = []
        for i, s in enumerate(in_oper.shape):
            if i not in collapsed_dims:
                out_shape.append(s)
            elif keep_dim:
                out_shape.append(1)

        out_oper = in_oper._replace(shape=out_shape, dim_order=out_dim_order)

        inputs = [None] * 3
        inputs[0] = in_id
        inputs[1] = self.add_immediate_int_vector(nnapi_dim)
        inputs[2] = self.add_immediate_int_scalar(keep_dim)

        outputs = [None] * 1
        outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)

        self.add_operation(NNAPI_OperationCode.MEAN, inputs, outputs)

    def add_quantize(self, node):
        assert node.inputsSize() == 4
        assert node.outputsSize() == 1
