forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathunboxing.py
213 lines (196 loc) · 7.58 KB
/
unboxing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
from dataclasses import dataclass
from typing import Callable, List, Sequence, Tuple
from torchgen.api.types import Binding, CType, NamedCType
from torchgen.model import (
Argument,
BaseTy,
BaseType,
ListType,
NativeFunction,
OptionalType,
Type,
)
# Separator used to join the per-element unboxing statements when they are
# spliced into the body of the generated C++ loop (see `_gen_code_list_type`).
connector = "\n\t"
def name(f: NativeFunction) -> str:
    """Return the unboxing function name for a NativeFunction.

    The name is the unambiguous form of the operator name stored on the
    function's schema (``f.func.name``).
    """
    operator_name = f.func.name
    return operator_name.unambiguous_name()
@dataclass(frozen=True)
class Unboxing:
    """
    Takes a sequence of Bindings and unboxes EValues to these Bindings. Returns generated code that performs correct unboxing.
    A sample generated code (the `EValue& ...` and `..._base = ...` statements are the part emitted by this class):
    // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
    void mul_out(EValue** stack) {
        EValue& self = *stack[0];
        EValue& other = *stack[1];
        EValue& out = *stack[2];
        const torch::executor::Tensor & self_base = self.to<torch::executor::Tensor>();
        const torch::executor::Tensor & other_base = other.to<torch::executor::Tensor>();
        torch::executor::Tensor & out_base = out.to<torch::executor::Tensor>();
        EXECUTORCH_SCOPE_PROF("native_call_mul.out");
        torch::executor::mul_outf(self_base, other_base, out_base);
    }
    """

    # this is a callable that converts a JIT argument, into its C++ type.
    # Translates (type, mutability, binds) to NamedCType. E.g., torchgen.api.cpp.argumenttype_type.
    argument_type_gen: Callable[
        ...,
        NamedCType,
    ]

    def convert_arguments(
        self, args: Sequence[Binding]
    ) -> Tuple[List[Binding], List[str]]:
        """
        Convert all the arguments in a NativeFunction to C++ code.

        :param args: bindings of the function's arguments, in stack order
        :return: a tuple of
            (1) the input bindings, each renamed to its unboxed variable
            (2) the generated C++ statements, one string per line
        :raises Exception: if a binding does not wrap an `Argument`
        """
        # Bind every stack slot to a named reference up front; slot i holds the i-th argument.
        code_list = [
            f"EValue& {binding.name} = *stack[{index}];"
            for index, binding in enumerate(args)
        ]
        binding_list = []
        for arg in args:
            # expecting only Argument
            if not isinstance(arg.argument, Argument):
                raise Exception(
                    f"Unexpected argument type, expecting `Argument` but got {arg}"
                )
            argument: Argument = arg.argument
            unboxed_name, _, code, decl = self.argumenttype_evalue_convert(
                argument.type, argument.name, mutable=argument.is_write
            )
            # Declarations are emitted ahead of the statements that use them.
            code_list.extend(decl)
            code_list.extend(code)
            binding_list.append(arg.with_name(unboxed_name))
        return binding_list, code_list

    def argumenttype_evalue_convert(
        self, t: Type, arg_name: str, *, mutable: bool = False
    ) -> Tuple[str, CType, List[str], List[str]]:
        """
        Takes in the type, name and mutability corresponding to an argument, and generates a tuple of:
        (1) the name of the newly created unboxed variable
        (2) the CType of that variable, as produced by `argument_type_gen`
        (3) the C++ code necessary to unbox the argument
        (4) C++ declarations that the caller emits before the code in (3)
        :param t: a `Type` of an argument
        :param arg_name: argument name
        :param mutable: boolean for whether this argument type is mutable
        :return: unboxed result
        :raises Exception: if `t` is not a BaseType, OptionalType or ListType
        """
        ctype = self.argument_type_gen(t, mutable=mutable, binds=arg_name).type

        if isinstance(t, BaseType):
            out_name = f"{arg_name}_base"
            code, decl = self._gen_code_base_type(
                arg_name=arg_name, out_name=out_name, ctype=ctype
            )
        elif isinstance(t, OptionalType):
            out_name = f"{arg_name}_opt_out"
            code, decl = self._gen_code_optional_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        elif isinstance(t, ListType):
            out_name = f"{arg_name}_list_out"
            code, decl = self._gen_code_list_type(
                arg_name=arg_name, out_name=out_name, t=t, ctype=ctype
            )
        else:
            raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}")
        return out_name, ctype, code, decl

    def _gen_code_base_type(
        self, arg_name: str, out_name: str, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """
        Unbox a base-typed argument with a single `to<T>()` call.

        :param arg_name: name of the EValue variable to read from
        :param out_name: name of the unboxed variable to declare
        :param ctype: C++ type of the unboxed variable
        :return: (code, decl); decl is always empty for base types
        """
        return [
            f"{ctype.cpp_type()} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
        ], []

    def _gen_code_optional_type(
        self, arg_name: str, out_name: str, t: OptionalType, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """
        Unbox an optional argument with a single `toOptional<T>()` call.

        :param arg_name: name of the EValue variable to read from
        :param out_name: name of the unboxed variable to declare
        :param t: the optional type; `t.elem` is the wrapped type
        :param ctype: C++ type of the unboxed (optional) variable
        :return: (code, decl); code is the template below split on newlines,
            decl is whatever converting the wrapped type produced
        """
        in_name = f"{arg_name}_opt_in"
        # Only the wrapped type's C++ type and declarations are needed here; the
        # conversion itself is emitted as one toOptional<T>() call below.
        _, base_type, _, decl = self.argumenttype_evalue_convert(t.elem, in_name)
        # Template text is kept flush-left: every character inside the triple
        # quotes, including leading whitespace, becomes part of the emitted lines.
        snippet = f"""
{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>();
"""
        return snippet.split("\n"), decl

    def _gen_code_list_type(
        self, arg_name: str, out_name: str, t: ListType, ctype: CType
    ) -> Tuple[List[str], List[str]]:
        """
        Unbox a list argument.

        :param arg_name: name of the EValue variable to read from
        :param out_name: name of the unboxed variable to declare
        :param t: the list type; `t.elem` is the element type
        :param ctype: C++ type of the unboxed list variable
        :return: (code, decl); code is the chosen template split on newlines,
            decl holds declarations the caller emits before code
        """
        in_name = f"{arg_name}_list_in"
        elem_name = f"{arg_name}_elem"
        res_name, res_ctype, res_code, decl = self.argumenttype_evalue_convert(
            t.elem, elem_name
        )
        elem = t.elem

        # Element base types whose lists are unboxed with one dedicated accessor call.
        typed_list_accessors = {
            BaseTy.Tensor: "toTensorList",
            BaseTy.int: "toIntList",
            BaseTy.SymInt: "toIntList",
            BaseTy.float: "toDoubleList",
            # handle list type with size, e.g., bool[4]
            BaseTy.bool: "toBoolList",
        }
        accessor = (
            typed_list_accessors.get(elem.name) if isinstance(elem, BaseType) else None
        )

        # Template text is kept flush-left: every character inside the triple
        # quotes, including leading whitespace, becomes part of the emitted lines.
        if accessor is not None:
            snippet = f"""
{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.{accessor}();
"""
        # pytorch codegen:
        # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<c10::optional<at::Tensor>>
        elif (
            isinstance(elem, OptionalType)
            and isinstance(elem.elem, BaseType)
            and elem.elem.name == BaseTy.Tensor
        ):
            snippet = f"""
#ifdef USE_ATEN_LIB
at::ArrayRef<c10::optional<at::Tensor>> {in_name} = {arg_name}.toListOptionalTensor();
c10::List<c10::optional<at::Tensor>> {out_name};
for (auto {elem_name}: {in_name}) {{
{out_name}.push_back({elem_name});
}}
#else
torch::executor::ArrayRef<torch::executor::optional<torch::executor::Tensor>> {out_name} = {arg_name}.toListOptionalTensor();
#endif
"""
        else:
            # use ArrayRef as default.
            vec_name = arg_name + "_vec"
            # need to bring vector instantiation out of scope so that ArrayRef has valid data
            decl.append(
                f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};"
            )
            # NOTE(review): no statement emitted on this path declares `{in_name}`;
            # its only declaration is in the USE_ATEN_LIB half of the optional-Tensor
            # branch above. The loop below therefore appears to iterate over an
            # undeclared variable -- confirm whether this path is reachable and
            # which EValue accessor should populate `{in_name}`.
            snippet = f"""
for (EValue {elem_name}: {in_name}) {{
{connector.join(res_code)}
{vec_name}.push_back({res_name});
}}
{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name});
"""
        return snippet.split("\n"), decl