#!/usr/bin/env python3
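"""Generate torch/csrc/jit/runtime/serialized_shape_function_registry.cpp.

Serializes the TorchScript shape functions defined in
torch/jit/_shape_functions.py, together with their operator-schema mappings,
into C++ string literals. Run this script from the root of the PyTorch git
repo; the output path is derived from the script's own location.
"""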
import importlib.util
import os
import sys
from itertools import chain
from pathlib import Path

# Manually import the shape function module based on the current directory
# instead of using torch imports, to avoid needing to recompile PyTorch
# before running the script.
file_path = Path.cwd() / "torch" / "jit" / "_shape_functions.py"
module_name = "torch.jit._shape_functions"

err_msg = """Could not find shape functions file, please make sure
you are in the root directory of the PyTorch git repo"""
if not file_path.exists():
    raise Exception(err_msg)
spec = importlib.util.spec_from_file_location(module_name, file_path)
assert spec is not None
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
assert spec.loader is not None
assert module is not None
spec.loader.exec_module(module)
bounded_compute_graph_mapping = module.bounded_compute_graph_mapping
shape_compute_graph_mapping = module.shape_compute_graph_mapping
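
# Each value in shape_compute_graph_mapping is a scripted function exposing
# `.name` and `.code`; bounded_compute_graph_mapping maps a schema to a
# (lower_bound, upper_bound) pair of such functions. The generators below
# rely only on these attributes.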
SHAPE_HEADER = r"""
/**
 * @generated
 * This is an auto-generated file. Please do not modify it by hand.
 * To re-generate, please run:
 * cd ~/pytorch && python
 * torchgen/shape_functions/gen_jit_shape_functions.py
 */
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/serialized_shape_function_registry.h>

// clang-format off

namespace torch {
namespace jit {

std::string shape_funcs = ""
"""

DECOMP_CENTER = r"""
const std::string& GetSerializedShapeFunctions() {
  return shape_funcs;
}

"""

DECOMP_END = r"""
// clang-format on

} // namespace jit
} // namespace torch
"""

SERIALIZED_SHAPE_UTIL_FILE_NAME = "serialized_shape_function_registry.cpp"

def gen_serialized_decompisitions() -> str:
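    """Serialize every unique scripted shape function into a sequence of C++
    string-literal chunks, keeping each chunk under MAX_MSFT_STR_LEN characters."""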
    already_serialized_names = set()
    unique_funcs = []
    all_funcs = chain(
        shape_compute_graph_mapping.values(), *bounded_compute_graph_mapping.values()
    )
    for scripted_func in all_funcs:
        if scripted_func.name in already_serialized_names:
            continue
        already_serialized_names.add(scripted_func.name)
        unique_funcs.append(scripted_func)

    output_strs = []
    curr_str = ""
    for scripted_func in unique_funcs:
        serialized_code = scripted_func.code
        # Technically the limit is higher, but leave a buffer because there are
        # weird rules around some characters.
        # TODO: this was the limit found by googling, but it seems way too short?
        MAX_MSFT_STR_LEN = 2000
        if len(curr_str) + len(serialized_code) <= MAX_MSFT_STR_LEN:
            curr_str += "\n" + serialized_code
        else:
            output_strs.append(curr_str)
            curr_str = scripted_func.code
    output_strs.append(curr_str)

    final_output = ""
    # The Windows compiler doesn't correctly handle adjacent string literals,
    # so wrap each chunk in an explicit std::string() constructor.
    for output_str in output_strs:
        start = '+ std::string(R"=====('
        end = '\n)=====")\n'
        final_output += start + output_str + end
    final_output += ";"
    return final_output
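
# For reference, the serialized block emitted into the generated .cpp looks
# roughly like this (function bodies are illustrative):
#
#   std::string shape_funcs = ""
#   + std::string(R"=====(
#   def unary(self: List[int]) -> List[int]:
#       ...
#   )=====")
#   + std::string(R"=====(
#   ...
#   )=====")
#   ;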

SHAPE_SCHEMA_START = r"""
const OperatorMap<std::string>& GetShapeFunctionMappings() {
  static const OperatorMap<std::string> shape_mappings {
"""

SHAPE_SCHEMA_END = r"""
  };

  return shape_mappings;
}
"""


def gen_shape_mappings() -> str:
    shape_mappings = []
    for schema, scripted_func in shape_compute_graph_mapping.items():
        shape_mappings.append(' {"' + schema + '", "' + scripted_func.name + '"},')
    return SHAPE_SCHEMA_START + "\n".join(shape_mappings) + SHAPE_SCHEMA_END
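
# Each emitted mapping entry pairs an operator schema with the name of its
# serialized shape function, e.g. (entry is illustrative):
#   {"aten::mul.Tensor(Tensor self, Tensor other) -> Tensor", "broadcast"},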

BOUNDED_SCHEMA_START = r"""
const OperatorMap<std::pair<std::string, std::string>>& GetBoundedShapeMappings() {
  static const OperatorMap<std::pair<std::string, std::string>> shape_mappings {
"""


def gen_bounded_mappings() -> str:
    bounded_mappings = []
    for schema, (lower_func, upper_func) in bounded_compute_graph_mapping.items():
        map_str = (
            ' {"'
            + schema
            + '", {"'
            + lower_func.name
            + '", "'
            + upper_func.name
            + '"}},'
        )
        bounded_mappings.append(map_str)
    return BOUNDED_SCHEMA_START + "\n".join(bounded_mappings) + SHAPE_SCHEMA_END

def write_decomposition_util_file(path: str) -> None:
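    """Assemble the generated C++ source and write it to
    <path>/serialized_shape_function_registry.cpp."""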
    decomposition_str = gen_serialized_decompisitions()
    shape_mappings = gen_shape_mappings()
    bounded_mappings = gen_bounded_mappings()
    file_components = [
        SHAPE_HEADER,
        decomposition_str,
        DECOMP_CENTER,
        shape_mappings,
        bounded_mappings,
        DECOMP_END,
    ]
    print("Writing file to:", os.path.join(path, SERIALIZED_SHAPE_UTIL_FILE_NAME))
    with open(os.path.join(path, SERIALIZED_SHAPE_UTIL_FILE_NAME), "wb") as out_file:
        final_output = "".join(file_components)
        out_file.write(final_output.encode("utf-8"))


def main() -> None:
    pytorch_dir = Path(__file__).resolve().parents[2]
    upgrader_path = pytorch_dir / "torch" / "csrc" / "jit" / "runtime"
    write_decomposition_util_file(str(upgrader_path))


if __name__ == "__main__":
    main()