# partitioner.py
# Copyright (c) 2024 MediaTek Inc.
#
# Licensed under the BSD License (the "License"); you may not use this file
# except in compliance with the License. See the license file in the root
# directory of this source tree for more details.

from typing import Callable, final, List, Optional, Tuple

import torch

from executorch.backends.mediatek.preprocess import NeuropilotBackend
from executorch.exir.backend.backend_details import CompileSpec
from executorch.exir.backend.partitioner import (
    DelegationSpec,
    Partitioner,
    PartitionResult,
)
from executorch.exir.backend.utils import tag_constant_data
from mtk_converter.python.converters.pytorch import importer_v2
from torch.export.exported_program import ExportedProgram
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupportBase

class NeuropilotOperatorsSupport(OperatorSupportBase):
    """Decides, node by node, whether an operator can be delegated to Neuropilot."""

    def __init__(
        self,
        op_types_to_skip: Optional[set] = None,
        op_names_to_skip: Optional[set] = None,
    ) -> None:
        if op_types_to_skip is None:
            op_types_to_skip = set()
        if op_names_to_skip is None:
            op_names_to_skip = set()

        self._op_types_to_skip = op_types_to_skip
        self._op_names_to_skip = op_names_to_skip

    def is_node_supported(self, _, node: torch.fx.Node) -> bool:
        # Handle 'call_function' only, because 'placeholder' and 'output' nodes
        # cannot be tagged.
        # Ref: https://github.com/pytorch/executorch/pull/1398
        if node.op != "call_function":
            return False

        op_type = node.target.__name__

        # Skip until we can handle the dimension order representation.
        if op_type == "aten._to_copy.default":
            return False

        if op_type in self._op_types_to_skip or node.name in self._op_names_to_skip:
            print(
                f"[Neuropilot Backend] The {op_type} operator with name '{node.name}' is skipped."
            )
            return False

        # Defer to the MediaTek converter for per-operator support checks.
        return importer_v2.is_fx_node_supported(node)
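
# Example (a minimal sketch; the operator type and node name below are
# hypothetical, not taken from this module): constructing the support checker
# so that every `aten.convolution.default` call and one specifically named
# node are kept out of the delegated partitions.
#
#   support = NeuropilotOperatorsSupport(
#       op_types_to_skip={"aten.convolution.default"},
#       op_names_to_skip={"aten_convolution_default_3"},
#   )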

@final
class NeuropilotPartitioner(Partitioner):
    """Tags Neuropilot-supported subgraphs for delegation to NeuropilotBackend."""

    def __init__(
        self,
        compile_spec: List[CompileSpec],
        op_types_to_skip: Optional[set] = None,
        op_names_to_skip: Optional[set] = None,
    ) -> None:
        self.delegation_spec = DelegationSpec(NeuropilotBackend.__name__, compile_spec)
        self._op_types_to_skip = op_types_to_skip
        self._op_names_to_skip = op_names_to_skip

    def ops_to_not_decompose(
        self,
        ep: ExportedProgram,
    ) -> Tuple[List[torch._ops.OpOverload], Optional[Callable[[torch.fx.Node], bool]]]:
        # Keep these ops whole so the backend can consume them directly rather
        # than receiving their decompositions.
        ops_not_decompose = [
            torch.ops.aten.pixel_shuffle.default,
            torch.ops.aten.upsample_bilinear2d.default,
            torch.ops.aten.upsample_bilinear2d.vec,
            torch.ops.aten.upsample_nearest2d.default,
            torch.ops.aten.upsample_nearest2d.vec,
        ]
        return (ops_not_decompose, None)
    def partition(self, exported_program: ExportedProgram) -> PartitionResult:
        capability_partitioner = CapabilityBasedPartitioner(
            exported_program.graph_module,
            NeuropilotOperatorsSupport(self._op_types_to_skip, self._op_names_to_skip),
            allows_single_node_partition=True,
        )
        partition_list = capability_partitioner.propose_partitions()

        # Tag every node in each proposed partition so the lowering flow can
        # hand the matching subgraphs to NeuropilotBackend.
        partition_tags = {}
        for partition in partition_list:
            for node in partition.nodes:
                tag = f"tag{partition.id}"
                node.meta["delegation_tag"] = tag
                partition_tags[tag] = self.delegation_spec

        # Constant data (weights, buffers) referenced by tagged nodes must be
        # tagged as well so it travels with the delegated payload.
        tag_constant_data(exported_program)

        return PartitionResult(
            tagged_exported_program=exported_program, partition_tags=partition_tags
        )
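
# Usage sketch (illustrative only): wiring the partitioner into an ExecuTorch
# lowering flow. The model, the "platform-config" compile-spec key, and the
# b"mt6989" value are assumptions for demonstration and may not match the
# specs the MediaTek backend actually expects.
#
#   import torch
#   from executorch.exir import to_edge_transform_and_lower
#
#   model = MyModel().eval()                          # hypothetical nn.Module
#   example_inputs = (torch.randn(1, 3, 224, 224),)
#   exported = torch.export.export(model, example_inputs)
#
#   partitioner = NeuropilotPartitioner(
#       [CompileSpec("platform-config", b"mt6989")],  # assumed key/value
#   )
#   edge = to_edge_transform_and_lower(exported, partitioner=[partitioner])
#   executorch_program = edge.to_executorch()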