-
Notifications
You must be signed in to change notification settings - Fork 512
/
Copy pathop_any.py
53 lines (43 loc) · 1.71 KB
/
op_any.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import cast, List
import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import ( # type: ignore
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg # type: ignore
from serializer.tosa_serializer import TosaOp
from torch.fx import Node
@register_node_visitor
class AnyVisitor(NodeVisitor):
    """Node visitor for ``aten.any.dim`` that emits a TOSA ``REDUCE_ANY`` operator."""

    target = "aten.any.dim"

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Add a ``REDUCE_ANY`` operator for this node to ``tosa_graph``.

        Args:
            node: The FX node being lowered (not read by this method).
            tosa_graph: Serializer that receives the new operator.
            inputs: ``inputs[0]`` is the tensor to reduce, ``inputs[1]``
                carries the dimension to reduce over (may be negative) and
                the optional ``inputs[2]`` carries the ``keep_dim`` flag.
            output: The tensor that receives the reduction result.

        Raises:
            ValueError: If the input and output dtypes differ, if the input
                is not BOOL, or if ``keep_dim`` is false or absent.
        """
        if inputs[0].dtype != output.dtype:
            raise ValueError(
                "All inputs and outputs need same dtype. "
                f"Got {ts.DTypeNames[inputs[0].dtype]=}, {ts.DTypeNames[output.dtype]=}."
            )
        if inputs[0].dtype != ts.DType.BOOL:
            raise ValueError("All inputs need to be BOOL. " f"Got {inputs[0].dtype=}")

        input_shape = list(inputs[0].shape)
        # The modulo maps a negative dimension index onto its non-negative
        # equivalent for the input's rank.
        dim = cast(int, inputs[1].number) % len(input_shape)

        # A missing third argument is treated as keep_dim=False.
        keep_dim = cast(bool, inputs[2].number if len(inputs) > 2 else False)
        if not keep_dim:
            # Per the message, this variant is expected to have been rewritten
            # by an earlier pass before reaching this visitor.
            raise ValueError("This case should be handled by ConvertAnyDimDimsPass")

        attr = ts.TosaSerializerAttribute()
        # The axis handed to TOSA is the position of `dim` within the input's
        # dim_order, not `dim` itself.
        # NOTE(review): presumably this accounts for a permuted memory
        # layout - confirm against TosaArg.dim_order.
        attr.AxisAttribute(inputs[0].dim_order.index(dim))

        tosa_graph.addOperator(
            TosaOp.Op().REDUCE_ANY, [inputs[0].name], [output.name], attr
        )