forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpass_manager.py
78 lines (65 loc) · 2.75 KB
/
pass_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Callable, List, Optional, Union
import torch
import torch.fx.passes.infra.pass_manager as fx
import torch.utils._pytree as pytree
from executorch.exir.error import ExportError, ExportErrorType
from torch.fx.passes.infra.pass_base import PassResult
from typing_extensions import TypeAlias
# Signature of a single pass accepted by PassManager: a callable that takes a
# torch.fx.GraphModule and returns either a PassResult or None.
PassType: TypeAlias = Callable[[torch.fx.GraphModule], Optional[PassResult]]
class PassManager(fx.PassManager):
    """
    Class to run multiple passes on a given graph module. The PassManager is
    callable so to run it, we can just call the PassManager instance.

    This subclass of ``torch.fx.passes.infra.pass_manager.PassManager``
    accepts either a flat list of passes or a list of lists of passes. Nested
    lists are flattened, preserving order, into a single list. Each pass is
    wrapped with ``fx.pass_result_wrapper`` before being handed to the base
    class. The ``run_checks_after_each_pass`` and ``suppress_check_failures``
    flags are forwarded unchanged to the base-class constructor.
    """

    def __init__(
        self,
        passes: Optional[Union[List[PassType], List[List[PassType]]]] = None,
        run_checks_after_each_pass: bool = False,
        suppress_check_failures: bool = False,
    ) -> None:
        r"""
        Args:
            passes: A list of passes, or a list of lists of passes. Nested
                lists are flattened in order. ``None`` or an empty list
                results in a PassManager with no passes.
            run_checks_after_each_pass: whether to run checks and linting
                after each pass.
            suppress_check_failures: forwarded unchanged to the base
                ``torch.fx`` PassManager constructor.
        """
        # Treat None / empty input as "no passes".
        passes = passes if passes else []
        # pytree flattening turns a (possibly nested) list of passes into one
        # ordered list of leaves, i.e. the individual pass callables.
        # NOTE(review): pass_result_wrapper presumably normalizes pass return
        # values into a PassResult; confirm against the torch.fx docs.
        flattened_passes = [
            fx.pass_result_wrapper(fn) for fn in pytree.tree_flatten(passes)[0]
        ]
        super().__init__(
            flattened_passes,
            run_checks_after_each_pass=run_checks_after_each_pass,
            suppress_check_failures=suppress_check_failures,
        )

    def check(self, module: torch.nn.Module) -> None:
        """
        Runs checks on the given module to make sure it contains the needed
        data for passes.

        Checks currently performed:
            - ``module`` must be a ``torch.fx.GraphModule``.
            - The module is recompiled and its graph is linted.
            - The graph must not contain any ``call_method`` node.

        Not yet implemented (see the TODO below): verifying that each
        operator node's output type matches the node's ``spec`` field
        (e.g. a tuple-returning op must have a tuple spec).

        Raises:
            AssertionError: if ``module`` is not a ``torch.fx.GraphModule``.
            ExportError: with ``ExportErrorType.NOT_SUPPORTED`` if the graph
                contains a ``call_method`` node.
        """
        # An explicit raise (instead of ``assert``) keeps this check active
        # under ``python -O``. AssertionError is kept so callers that relied
        # on the previous assert see the same exception type.
        if not isinstance(module, fx.GraphModule):
            raise AssertionError(
                "PassManager.check expects a torch.fx.GraphModule, got "
                f"{type(module).__name__}"
            )
        module.recompile()
        module.graph.lint()
        # TODO(qihan): use verifier.check_is_exir
        for node in module.graph.nodes:
            if node.op == "call_method":
                raise ExportError(
                    ExportErrorType.NOT_SUPPORTED,
                    f"call_method `{node}` is not supported except for backend delegate.",
                )