-
Notifications
You must be signed in to change notification settings - Fork 508
/
Copy pathutils.py
260 lines (197 loc) · 7.82 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from enum import IntEnum
from typing import Optional, Set, Tuple
import torch
from executorch.backends.vulkan.serialization.vulkan_graph_schema import (
VkMemoryLayout,
VkStorageType,
)
from executorch.exir.tensor import TensorSpec
from torch._export.utils import is_buffer, is_param
from torch._subclasses.fake_tensor import FakeTensor
from torch.export import ExportedProgram
##
## Node type determination
##
def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Returns true if ``node`` is an FX node whose opcode is ``get_attr``.
    Anything that is not a ``torch.fx.Node`` yields False instead of raising.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    return node.op == "get_attr"
def is_constant(program: ExportedProgram, node: torch.fx.Node) -> bool:
return node.name in program.graph_signature.inputs_to_lifted_tensor_constants
def is_param_node(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Check whether ``node`` refers to parameter-like data of the exported program:
    a ``get_attr`` node, a parameter, a buffer, or a lifted tensor constant.
    """
    if is_get_attr_node(node):
        return True
    # any() consumes the generator lazily, so the checks run in this order and
    # stop at the first one that matches.
    return any(
        check(program, node) for check in (is_param, is_buffer, is_constant)
    )
def is_symint_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the node's ``meta["val"]`` is a ``torch.SymInt``. Nodes without
    a ``val`` entry are reported as not producing a SymInt.
    """
    return "val" in node.meta and isinstance(node.meta["val"], torch.SymInt)
def is_tensor_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the given node produces a tensor value, or a collection of tensor values.

    A node counts as a tensor node if it carries a ``spec`` meta entry, if its
    ``meta["val"]`` is a FakeTensor, or if ``meta["val"]`` is a list/tuple whose
    items are all FakeTensors (an empty collection therefore qualifies).
    """
    # All nodes with tensor values are tagged by the SpecPropPass transform
    if "spec" in node.meta:
        return True

    # A missing "val" behaves like a non-tensor value: both isinstance checks fail.
    val = node.meta.get("val")
    if isinstance(val, FakeTensor):
        return True
    if isinstance(val, (list, tuple)):
        return all(isinstance(item, FakeTensor) for item in val)
    return False
def tensor_node_is_bool(node: torch.fx.Node) -> bool:
    """
    Returns true if the node's ``meta["val"]`` is a bool FakeTensor, or is a
    list/tuple containing at least one bool FakeTensor. Non-FakeTensor items in a
    collection are ignored. Raises KeyError if the node has no ``val`` meta entry.
    """
    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return val.dtype == torch.bool
    if isinstance(val, (list, tuple)):
        return any(
            isinstance(item, FakeTensor) and item.dtype == torch.bool for item in val
        )
    return False
##
## Memory Layout, Storage Type Determination
##
# Three integer image extents. required_image_extents produces them in the order
# (width, height, channels * batch).
ImageExtents = Tuple[int, int, int]

# Fallback per-axis texture limits, in the same order as ImageExtents.
# NOTE(review): presumably used when real device limits are unavailable — confirm.
DEFAULT_TEXTURE_LIMITS = (16384, 16384, 2048)

# Fallback buffer limit: 128 * 2**20. within_buffer_limit compares element counts
# against its limit, so this is presumably an element count — confirm at call sites.
DEFAULT_BUFFER_LIMIT = 128 * 1024 * 1024
class PackedDim(IntEnum):
    """
    Identifies a tensor dimension by name.

    NOTE(review): the integer values appear to match each dim's offset from the end
    of the tensor sizes as read by required_image_extents (width = sizes[-1],
    height = sizes[-2], channels = sizes[-3]) — confirm before relying on it.
    """

    WIDTH = 0  # innermost dim
    HEIGHT = 1
    CHANNELS = 2
# Every member of PackedDim (WIDTH, HEIGHT, CHANNELS).
all_packed_dims: Set[PackedDim] = set(PackedDim)

# Storage types enumerated by this module: plain buffers and 3D textures.
all_storage_types: Set[VkStorageType] = set(
    (VkStorageType.BUFFER, VkStorageType.TEXTURE_3D)
)

# Packed memory layouts; valid_texture_memory_layouts checks each of these.
all_memory_layouts: Set[VkMemoryLayout] = set(
    (
        VkMemoryLayout.TENSOR_WIDTH_PACKED,
        VkMemoryLayout.TENSOR_HEIGHT_PACKED,
        VkMemoryLayout.TENSOR_CHANNELS_PACKED,
    )
)
def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> bool:
    """
    Checks whether the tensors produced by the given node can fit within the device's
    GPU buffer limit, which represents the maximum number of elements that can be stored
    in a GPU buffer.

    Args:
        node: a tensor-producing node whose ``meta["val"]`` is a FakeTensor or a
            list/tuple of FakeTensors.
        buffer_limit: element-count limit to compare against.

    Returns:
        True if every produced tensor has strictly fewer than ``buffer_limit`` elements.

    Raises:
        RuntimeError: if ``meta["val"]`` is neither a FakeTensor nor a list/tuple.
    """
    assert is_tensor_node(node)
    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return val.numel() < buffer_limit
    if isinstance(val, (list, tuple)):
        return all(tensor.numel() < buffer_limit for tensor in val)
    raise RuntimeError(f"Cannot get numel for val of type {type(node.meta['val'])}")
def tensor_node_is_high_dim(node: torch.fx.Node) -> bool:
    """
    Returns true if the node's ``meta["val"]`` is a FakeTensor with more than 4
    dimensions, or a list/tuple containing at least one such FakeTensor.
    Non-FakeTensor items in a collection are ignored.
    """
    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return len(val.shape) > 4
    if isinstance(val, (list, tuple)):
        return any(
            isinstance(item, FakeTensor) and len(item.shape) > 4 for item in val
        )
    return False
def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents:
    """
    Calculate the image extents that will be used to represent a tensor with the given sizes
    and memory layout in the Vulkan Delegate.

    Width, height and channels are read from sizes[-1], sizes[-2] and sizes[-3], and
    batch from sizes[0] when there are at least 4 dims. Any missing dim counts as 1.
    The packed dim selected by ``layout`` is divided by 4, rounding up.

    Returns:
        (width, height, channels * batch)

    Raises:
        RuntimeError: if ``layout`` is not width-, height- or channels-packed.
    """

    def _ceil_div4(extent):
        return (extent + 3) // 4

    ndim = len(sizes)
    w, h, c = (sizes[-k] if ndim >= k else 1 for k in (1, 2, 3))
    n = sizes[0] if ndim >= 4 else 1

    if layout == VkMemoryLayout.TENSOR_WIDTH_PACKED:
        w = _ceil_div4(w)
    elif layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED:
        h = _ceil_div4(h)
    elif layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED:
        c = _ceil_div4(c)
    else:
        raise RuntimeError(f"Unsupported memory layout {layout}")

    # Batch is folded into the third extent.
    return w, h, c * n
def extents_are_valid(extents: ImageExtents, limits: ImageExtents) -> bool:
    """
    Returns true if each entry of ``extents`` is less than or equal to the entry of
    ``limits`` at the same position.
    """
    return all(extent <= limits[axis] for axis, extent in enumerate(extents))
def valid_texture_memory_layouts(
    tensor_sizes: torch.Size, texture_limits: ImageExtents
) -> Set[VkMemoryLayout]:
    """
    Given tensor sizes, determine the set of memory layouts which will produce a texture
    that can fit within the specified device limits.

    Each layout in ``all_memory_layouts`` is kept when the extents from
    ``required_image_extents`` are within ``texture_limits`` on every axis.
    """
    # Iterate the set directly; wrapping it in list() only built a throwaway copy.
    return {
        layout
        for layout in all_memory_layouts
        if extents_are_valid(
            required_image_extents(tensor_sizes, layout), texture_limits
        )
    }
def possible_node_memory_layouts(
    node: torch.fx.Node, texture_limits: ImageExtents
) -> Set[VkMemoryLayout]:
    """
    Given a node, determine the set of memory layouts which can be used to represent all
    tensors involved in the computation.

    For a single FakeTensor output, this is the set of layouts whose texture fits within
    ``texture_limits``. For a list/tuple of FakeTensor outputs, a layout is included only
    if it is valid for every tensor in the collection. ``set_node_spec_attr`` assigns one
    value to all of a node's specs, so one layout must fit every output. An empty
    collection, or a ``val`` of any other type, yields an empty set.
    """
    assert is_tensor_node(node)
    val = node.meta["val"]
    if isinstance(val, FakeTensor):
        return valid_texture_memory_layouts(val.shape, texture_limits)

    if isinstance(val, (list, tuple)):
        per_tensor = [
            valid_texture_memory_layouts(fake_tensor.shape, texture_limits)
            for fake_tensor in val
        ]
        if per_tensor:
            # Intersection, not union: a layout that suits only some outputs would
            # still be applied to all of them.
            return set.intersection(*per_tensor)

    return set()
##
## TensorSpec Utils
##
def set_node_spec_attr(node: torch.fx.Node, attr: str, value):
    """
    Set ``attr`` to ``value`` on the TensorSpec(s) in ``node.meta["spec"]``.

    A single TensorSpec is updated directly. For a list/tuple, every element is asserted
    to be a TensorSpec and updated in order, so earlier elements may already have been
    updated when a later element fails the assertion.

    Raises:
        RuntimeError: if the spec is neither a TensorSpec nor a list/tuple.
    """
    assert "spec" in node.meta
    spec = node.meta["spec"]

    if isinstance(spec, TensorSpec):
        targets = (spec,)
    elif isinstance(spec, (list, tuple)):
        targets = spec
    else:
        raise RuntimeError(f"Cannot set attr for spec of type {type(spec)}")

    for target in targets:
        assert isinstance(target, TensorSpec)
        setattr(target, attr, value)
def get_node_spec_attr(node: torch.fx.Node, attr: str, return_first: bool = True):
    """
    Read ``attr`` from the TensorSpec(s) stored in ``node.meta["spec"]``.

    Args:
        node: node whose meta contains a ``spec`` entry.
        attr: attribute name to read.
        return_first: for a list/tuple of specs, return only the first spec's value
            (default); otherwise return a list with one value per spec.

    Returns:
        The attribute value, or None where the attribute is absent. With
        ``return_first`` and an empty spec collection, None is returned.

    Raises:
        RuntimeError: if the spec is neither a TensorSpec nor a list/tuple.
    """
    assert "spec" in node.meta
    spec = node.meta["spec"]
    if isinstance(spec, TensorSpec):
        return getattr(spec, attr, None)
    if isinstance(spec, (list, tuple)):
        if return_first:
            # Guard against an empty collection instead of raising IndexError.
            return getattr(spec[0], attr, None) if spec else None
        return [getattr(s, attr, None) for s in spec]
    raise RuntimeError(f"Cannot get attr for spec of type {type(spec)}")
def get_node_storage_type(node: torch.fx.Node) -> Optional[VkStorageType]:
    """
    Return the ``vk_storage_type`` attribute of the node's spec (the first spec when
    there are several), or None if it is not set.
    """
    storage_type = get_node_spec_attr(node, attr="vk_storage_type")
    return storage_type
def get_node_memory_layout(node: torch.fx.Node) -> Optional[VkMemoryLayout]:
    """
    Return the ``vk_memory_layout`` attribute of the node's spec (the first spec when
    there are several), or None if it is not set.
    """
    memory_layout = get_node_spec_attr(node, attr="vk_memory_layout")
    return memory_layout