-
Notifications
You must be signed in to change notification settings - Fork 511
/
Copy pathop_max_pool2d.py
82 lines (68 loc) · 2.19 KB
/
op_max_pool2d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import List
import serializer.tosa_serializer as ts # type: ignore
import torch
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
get_output_qparams,
)
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from serializer.tosa_serializer import TosaOp
@register_node_visitor
class MaxPool2dVisitor(NodeVisitor):
    """Lower ``aten.max_pool2d.default`` to a TOSA ``MAX_POOL2D`` operator."""

    target = "aten.max_pool2d.default"

    def __init__(self, *args):
        super().__init__(*args)

    @staticmethod
    def _to_hw_pair(values, name: str, default=None) -> List[int]:
        """Expand an ATen ``int[2]`` pooling argument to an explicit ``[h, w]`` list.

        ATen's max_pool2d accepts kernel_size/stride/padding either as two
        values or as a single value shared by both spatial dimensions.

        Args:
            values: The argument as found on the node (a sequence of ints),
                or ``None``/empty when it was omitted.
            name: Argument name, used only in error messages.
            default: Value returned when ``values`` is ``None`` or empty. When
                ``default`` is ``None``, an empty argument is an error.

        Returns:
            A two-element list ``[h, w]``.

        Raises:
            ValueError: If ``values`` is empty and no default is given, or if
                it has more than two elements.
        """
        items = [] if values is None else list(values)
        if not items:
            if default is None:
                raise ValueError(f"max_pool2d: {name} must not be empty")
            return list(default)
        if len(items) == 1:
            # A single value applies to both height and width.
            return [items[0], items[0]]
        if len(items) == 2:
            return items
        raise ValueError(
            f"max_pool2d: {name} must have one or two elements, got {items}"
        )

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Emit a MAX_POOL2D operator for ``node`` into ``tosa_graph``.

        Args:
            node: The FX node being lowered.
            tosa_graph: Serializer that receives the new operator.
            inputs: Lowered node arguments, read positionally as
                (input, kernel_size, stride, padding). Trailing arguments
                may be absent when they take their default values.
            output: The lowered output tensor.
        """
        input_tensor = inputs[0]

        kernel_size = self._to_hw_pair(inputs[1].special, "kernel_size")

        # An omitted or empty stride means "same as kernel_size" in ATen.
        stride_arg = inputs[2].special if len(inputs) > 2 else []
        stride = self._to_hw_pair(stride_arg, "stride", default=kernel_size)

        # Padding may be omitted from the node arguments (default 0).
        padding_arg = inputs[3].special if len(inputs) > 3 else []
        pad_h, pad_w = self._to_hw_pair(padding_arg, "padding", default=[0, 0])
        # ATen padding is symmetric per axis; expand it to the four-element
        # [top, bottom, left, right] layout used for the TOSA pool attribute.
        pad_size_list = [pad_h, pad_h, pad_w, pad_w]

        # NOTE(review): dilation and ceil_mode (inputs[4:], per the ATen
        # schema) are not read here. Presumably unsupported values are
        # rejected before this visitor runs -- confirm.

        accumulator_type = output.dtype

        # Initialize zero points to zero; only INT8 tensors carry qparams.
        input_zp = 0
        if input_tensor.dtype == ts.DType.INT8:
            input_qparams = get_input_qparams(node)
            input_zp = input_qparams[0].zp

        output_zp = 0
        if output.dtype == ts.DType.INT8:
            output_qparams = get_output_qparams(node)
            output_zp = output_qparams[0].zp

        attr = ts.TosaSerializerAttribute()
        attr.PoolAttribute(
            kernel=kernel_size,
            stride=stride,
            pad=pad_size_list,
            input_zp=input_zp,
            output_zp=output_zp,
            accum_dtype=accumulator_type,
        )

        tosa_graph.addOperator(
            TosaOp.Op().MAX_POOL2D,
            [input_tensor.name],
            [output.name],
            attr,
        )