vulkan_preprocess.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict

from functools import partial
from typing import Any, Dict, final, List

import executorch.backends.vulkan.utils as utils

from executorch.backends.transforms.addmm_mm_to_linear import AddmmToLinearTransform
from executorch.backends.transforms.fuse_batch_norm_with_conv import (
    FuseBatchNormWithConvPass,
)
from executorch.backends.transforms.fuse_conv_with_clamp import FuseClampPass
from executorch.backends.transforms.fuse_dequant_linear import FuseDequantLinearPass
from executorch.backends.transforms.fuse_view_copy import FuseViewCopyTransform
from executorch.backends.transforms.view_copy_to_squeeze_unsqueeze import (
    ViewCopyToSqueezeUnsqueezePass,
)
from executorch.backends.vulkan._passes import (
    insert_prepack_nodes,
    RemoveLocalScalarDenseOpsTransform,
    RemoveRedundantOpsTransform,
    SqueezeUnsqueezeInputs,
    TagMemoryMetaPass,
)
from executorch.backends.vulkan.serialization.vulkan_graph_builder import VkGraphBuilder
from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
    VkMemoryLayout,
    VkStorageType,
)
from executorch.backends.vulkan.serialization.vulkan_graph_serialize import (
    serialize_vulkan_graph,
)
from executorch.exir.backend.backend_details import (
    BackendDetails,
    CompileSpec,
    ExportedProgram,
    PreprocessResult,
)
from executorch.exir.backend.utils import DelegateMappingBuilder
from executorch.exir.memory_planning import greedy, memory_planning_algorithm_suite
from executorch.exir.pass_base import ExportPass, PassBase
from executorch.exir.passes import MemoryPlanningPass, SpecPropPass
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.exir.program._program import _copy_module
from torch.export._remove_auto_functionalized_pass import (
    unsafe_remove_auto_functionalized_pass,
)

DEFAULT_DEBUG_HANDLE = 65535


# pyre-ignore
def apply_passes(program: ExportedProgram, passes) -> ExportedProgram:
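    """Apply a sequence of passes to `program`.

    Each entry may be an ExportPass/PassBase instance, which is run on the
    program's graph module, or a plain callable that takes and returns an
    ExportedProgram (e.g. `insert_prepack_nodes` as used in `preprocess` below).
    """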
    for p in passes:
        if issubclass(type(p), ExportPass) or issubclass(type(p), PassBase):
            new_gm = program.graph_module
            # This is a workaround to allow the memory planning pass to work without
            # having to first apply ToOutVarPass(). See the `greedy()` function in
            # `exir.memory_planning`; if this attribute isn't set, assertions in
            # `collect_spec_from_nodes()` will fail.
            if isinstance(p, MemoryPlanningPass):
                new_gm.encounter_to_out_var_failure = True

            new_gm_res = p(new_gm)
            assert new_gm_res is not None
            new_gm = new_gm_res.graph_module

            # See the application of this function in exir/program/_program.py for more
            # details on why this step is necessary.
            if isinstance(p, SpecPropPass):
                p.update_placeholder_tensor_specs(program, new_gm)

            _copy_module(program.graph_module, new_gm)
        else:
            program = p(program)

    return program

def parse_compile_spec(compile_specs: List[CompileSpec]) -> Dict[str, Any]:
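    """Decode recognized CompileSpec entries into an options dict.

    Values are little-endian byte strings; unrecognized keys are ignored.
    """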
    options = {}
    for spec in compile_specs:
        if spec.key == "storage_type_override":
            options[spec.key] = VkStorageType(
                int.from_bytes(spec.value, byteorder="little")
            )
        if spec.key == "memory_layout_override":
            options[spec.key] = VkMemoryLayout(
                int.from_bytes(spec.value, byteorder="little")
            )
        if spec.key in {"texture_limits_x", "texture_limits_y", "texture_limits_z"}:
            options[spec.key] = int.from_bytes(spec.value, byteorder="little")
        if spec.key == "skip_tag_memory_metadata":
            options[spec.key] = bool.from_bytes(spec.value, byteorder="little")

    # Unhandled options are ignored
    return options
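
# Illustrative sketch (not part of the module): spec values handed to
# `parse_compile_spec` are little-endian byte strings, mirroring the
# int.from_bytes(..., byteorder="little") decoding above. For example:
#
#   specs = [
#       CompileSpec(
#           "storage_type_override",
#           int(VkStorageType.TEXTURE_3D).to_bytes(4, byteorder="little"),
#       ),
#       CompileSpec("texture_limits_x", (2048).to_bytes(4, byteorder="little")),
#   ]
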
@final
class VulkanBackend(BackendDetails):
    @classmethod
    # pyre-ignore
    def preprocess(  # noqa: C901
        cls,
        program: ExportedProgram,
        module_compile_spec: List[CompileSpec],
    ) -> PreprocessResult:
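        """Lower an ExportedProgram to a serialized Vulkan delegate payload.

        Applies graph transforms, optional memory metadata tagging, and memory
        planning, then serializes the result via `serialize_vulkan_graph`.
        """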
        compile_options = parse_compile_spec(module_compile_spec)
        limits_x = compile_options.get(
            "texture_limits_x", utils.DEFAULT_TEXTURE_LIMITS[0]
        )
        limits_y = compile_options.get(
            "texture_limits_y", utils.DEFAULT_TEXTURE_LIMITS[1]
        )
        limits_z = compile_options.get(
            "texture_limits_z", utils.DEFAULT_TEXTURE_LIMITS[2]
        )
        texture_limits = (limits_x, limits_y, limits_z)

        default_storage_type = compile_options.get(
            "storage_type_override", VkStorageType.TEXTURE_3D
        )
        default_memory_layout = compile_options.get(
            "memory_layout_override", VkMemoryLayout.TENSOR_WIDTH_PACKED
        )
        program = unsafe_remove_auto_functionalized_pass(program)

        # First, apply passes that fuse/remove operators to consolidate the graph
        # structure but still preserve an "ATen-compliant" graph structure (i.e. all
        # arguments to ATen operators must match the ATen function schema).
        program = apply_passes(
            program,
            [
                RemoveRedundantOpsTransform(),
                AddmmToLinearTransform(),
                FuseDequantLinearPass(),
                SqueezeUnsqueezeInputs(),
                FuseViewCopyTransform(),
                ViewCopyToSqueezeUnsqueezePass(),
                FuseBatchNormWithConvPass(program),
                FuseClampPass(),
            ],
        )
        # Next, annotate tensor nodes with TensorSpec structs, which are needed for
        # dynamic shapes and memory planning. Until this point, the graph must be
        # ATen compliant because SpecPropPass will be calling the underlying ATen
        # operators during its execution.
        program = apply_passes(program, [SpecPropPass()])
        # Apply graph transforms which either require `TensorSpec`s to have been
        # created or would create a non-ATen-compliant graph structure.
        program = apply_passes(
            program,
            [
                # Since this pass may replace a scalar argument with a tensor
                # argument, it may result in a non-ATen-compliant graph structure.
                RemoveLocalScalarDenseOpsTransform(),
                insert_prepack_nodes,
            ],
        )
        # Optionally apply the memory metadata tagging pass, which inserts storage
        # type and memory layout transition nodes to ensure that all tensor arguments
        # to an operator are in a supported or optimal configuration. If this pass is
        # not applied, there is a risk that some operators receive arguments with
        # memory settings that are not supported by the implementation.
        if not compile_options.get("skip_tag_memory_metadata", False):
            program = apply_passes(
                program,
                [
                    TagMemoryMetaPass(
                        texture_limits,
                        default_storage_type=default_storage_type,
                        default_memory_layout=default_memory_layout,
                    ),
                ],
            )
        # Finally, apply the dynamic shape passes and the memory planning pass. These
        # passes must be applied only when the graph structure is finalized.
        greedy_memory_planning = partial(greedy, allow_overlapping_allocations=False)
        mem_planning_suite = partial(
            memory_planning_algorithm_suite, algo_list=[greedy_memory_planning]
        )
        program = apply_passes(
            program,
            [
                ConstraintBasedSymShapeEvalPass(),
                MemoryPlanningPass(memory_planning_algo=mem_planning_suite),
            ],
        )
        graph_builder = VkGraphBuilder(
            program, DelegateMappingBuilder(generated_identifiers=True)
        )
        vk_graph = graph_builder.build_graph()

        return PreprocessResult(
            processed_bytes=serialize_vulkan_graph(
                vk_graph, graph_builder.const_tensors, []
            ),
            debug_handle_map=graph_builder.delegate_mapping_builder.get_delegate_mapping(),
        )
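
# Illustrative usage sketch (not part of the module; assumes the standard
# ExecuTorch lowering flow via `executorch.exir.backend.backend_api.to_backend`,
# whose exact entry points may vary by version). `preprocess` is invoked
# indirectly when an edge program is lowered to this backend:
#
#   from executorch.exir import to_edge
#   from executorch.exir.backend.backend_api import to_backend
#   from torch.export import export
#
#   edge = to_edge(export(model, example_inputs))
#   lowered = to_backend(
#       "VulkanBackend",
#       edge.exported_program(),
#       [CompileSpec("texture_limits_x", (2048).to_bytes(4, byteorder="little"))],
#   )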