forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathet_cpp.py
368 lines (330 loc) · 12.7 KB
/
et_cpp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
from typing import List, Optional, Sequence, Set, Union
from torchgen import local
from torchgen.api.types import (
ArgName,
ArrayCType,
BaseCType,
Binding,
ConstRefCType,
CType,
MutRefCType,
NamedCType,
SpecialArgName,
TupleCType,
VectorCType,
voidT,
)
from torchgen.model import (
Argument,
Arguments,
BaseTy,
BaseType,
ListType,
NativeFunction,
OptionalType,
Return,
SelfArgument,
TensorOptionsArguments,
Type,
)
from torchgen.utils import assert_never
from .types import (
ArrayRefCType,
BaseTypeToCppMapping,
OptionalCType,
scalarT,
tensorListT,
tensorT,
)
"""
This file describes the translation of JIT schema to the public C++ API, which is what people use when they call
functions like at::add. It also serves as a native function API, which is the signature of kernels,
since in Executorch CppSignature is the same as NativeSignature.
Difference between this file and torchgen.api.cpp.py:
- Executorch doesn't support TensorOptions, however in this file we still keep the logic here to be compatible with
torchgen.api.cpp, so that we can do stuff like ATen mode (running ATen kernels in Executorch).
- Executorch doesn't support Dimname.
- Executorch runtime doesn't support SymInt, will treat it as int.
"""
# Translation of "value types" in JIT schema to C++ API type. Value
# types look the same no matter if they are argument types or return
# types. Returns None if the type in question is not a value type.
def valuetype_type(
    t: Type,
    *,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
) -> Optional[NamedCType]:
    """Translate a JIT schema "value type" into its named C++ type.

    Args:
        t: The JIT schema type to translate.
        binds: Name attached to the resulting ``NamedCType``.
        remove_non_owning_ref_types: When set, a ``str`` base type is rejected
            because the reference-to-value conversion is not implemented.
            The flag is not forwarded when recursing into the element of an
            optional type.

    Returns:
        The translated ``NamedCType``, or ``None`` when ``t`` is not a value
        type (Tensor, Scalar, an optional wrapping a non-value type, or any
        list other than a list of ``bool``).

    Raises:
        AssertionError: For ``str`` when ``remove_non_owning_ref_types`` is
            set, for a ``bool`` list without a fixed size, or when ``t`` is
            not a ``BaseType``, ``OptionalType`` or ``ListType``.
    """
    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor or t.name == BaseTy.Scalar:
            return None
        # For SymInt we simply treat it as int.
        if str(t) == "SymInt":
            return NamedCType(binds, BaseCType(BaseTypeToCppMapping[BaseTy.int]))
        if remove_non_owning_ref_types and t.name == BaseTy.str:
            raise AssertionError("string ref->value conversion: not implemented yet")
        # All other BaseType currently map directly to BaseCppTypes.
        return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))

    if isinstance(t, OptionalType):
        elem = valuetype_type(t.elem, binds=binds)
        if elem is None:
            return None
        return NamedCType(binds, OptionalCType(elem.type))

    if isinstance(t, ListType):
        if str(t.elem) != "bool":
            return None
        # Explicit raise instead of `assert`: the check must survive `python -O`,
        # otherwise an ArrayCType with size None would be produced silently.
        if t.size is None:
            raise AssertionError(f"bool list must have a fixed size: {t}")
        return NamedCType(
            binds, ArrayCType(BaseCType(BaseTypeToCppMapping[BaseTy.bool]), t.size)
        )

    raise AssertionError(f"unrecognized type {repr(t)}")
# Translation of types occurring in JIT arguments to a C++ argument type.
# If remove_non_owning_ref_types is set, we'll guarantee that the outputted CType is not a non-owning reference type.
# For example, we'll return std::vector<int> instead of IntArrayRef.
# See Note [translation from C++ reference to value types]
def argumenttype_type(
    t: Type,
    *,
    mutable: bool,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
) -> NamedCType:
    """Translate the type of a JIT argument into its named C++ argument type.

    Value types are delegated to ``valuetype_type``; the remaining cases
    (Tensor, Scalar, and optionals/lists built from them) are handled here.

    Args:
        t: The JIT schema type of the argument.
        mutable: Whether the argument is written to; selects between mutable
            and const references for tensors.
        binds: Name attached to the resulting ``NamedCType``.
        remove_non_owning_ref_types: Forwarded to ``valuetype_type`` for ``t``
            itself; recursive calls on element types do not forward it.

    Raises:
        AssertionError: If ``t`` is a base type that is neither a value type,
            Tensor nor Scalar, or is not a recognized ``Type`` subclass.
        NotImplementedError: For lists of ``Dimname``.
    """
    # If it's a value type, the value type translation is the answer.
    value_ctype = valuetype_type(
        t,
        binds=binds,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
    )
    if value_ctype is not None:
        return value_ctype

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            # `local` is only consulted for mutable arguments.
            if not mutable or local.use_const_ref_for_mutable_tensors():
                return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
            return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
        if t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        raise AssertionError(f"base type should have been value type {t}")

    if isinstance(t, OptionalType):
        optional_elem = str(t.elem)
        if optional_elem == "Tensor":
            if not mutable or local.use_const_ref_for_mutable_tensors():
                return NamedCType(
                    binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
                )
            # TODO: fix this discrepancy
            return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
        if optional_elem == "Scalar":
            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(inner.type))

    if isinstance(t, ListType):
        # TODO: keeping these special cases for Tensor[] and Tensor?[] so that we can hookup with ATen kernels.
        list_elem = str(t.elem)
        if list_elem == "Tensor":
            return NamedCType(binds, BaseCType(tensorListT))
        if list_elem == "Dimname":
            raise NotImplementedError("Executorch doesn't support Dimname")
        if list_elem == "Tensor?":
            return NamedCType(binds, ArrayRefCType(OptionalCType(BaseCType(tensorT))))
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(inner.type))

    raise AssertionError(f"unrecognized type {repr(t)}")
# Translate a JIT argument into its C++ type
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
    """Return the named C++ type of JIT argument ``a``.

    The argument's ``is_write`` flag decides whether it is treated as mutable.
    """
    is_mutable = a.is_write
    return argumenttype_type(a.type, mutable=is_mutable, binds=binds)
# Translation of a (non-multi) return type from JIT to C++
# N.B: returntype_type returns a CType, not a NamedCType.
# This is mostly because of the mismatch between return types and return names.
# e.g. a function with a return type of 'void' has 0 return names,
# and a function with a return type of 'std::tuple' has >1 return name.
def returntype_type(t: Type, *, mutable: bool) -> CType:
    """Translate a single (non-multi) JIT return type into a C++ type.

    Args:
        t: The JIT schema type being returned.
        mutable: Whether the return aliases a written-to argument; a mutable
            Tensor is returned by reference, otherwise by value.

    Returns:
        A bare ``CType`` (not a ``NamedCType``).

    Raises:
        AssertionError: For a mutable list return, a fixed-size list return,
            or any type that has no return translation.
    """
    # The binding name is a placeholder; only the CType is kept.
    value_ctype = valuetype_type(t, binds="__placeholder__")
    if value_ctype is not None:
        return value_ctype.type

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if not mutable:
                # Note [Tensor Copy Returns]
                # Currently, we use "Argument.is_write" to determine
                # whether or not Tensor return types should be copies or references.
                # If that ever changes, take a look at other locations of this note!
                return BaseCType(tensorT)
            if local.use_const_ref_for_mutable_tensors():
                return ConstRefCType(BaseCType(tensorT))
            return MutRefCType(BaseCType(tensorT))
        if t.name == BaseTy.Scalar:
            return BaseCType(scalarT)
    elif isinstance(t, ListType):
        # Explicit raises instead of `assert` so the checks survive `python -O`.
        if mutable:
            raise AssertionError(
                "Native functions should never return a mutable tensor list. They should return void."
            )
        elem = returntype_type(t.elem, mutable=False)
        if t.size is not None:
            raise AssertionError(f"fixed size list returns not supported: {t}")
        return VectorCType(elem)

    raise AssertionError(f"unrecognized return type {t}")
# Translation of a single return to its C++ type
def return_type(r: Return) -> CType:
    """Return the C++ type of the single return ``r``.

    The return's ``is_write`` flag decides whether it is treated as mutable.
    """
    is_mutable = r.is_write
    return returntype_type(r.type, mutable=is_mutable)
# Translation of a full (possibly multi) return from JIT to its C++ type
def returns_type(rs: Sequence[Return]) -> CType:
    """Return the C++ type for a full (possibly multi) JIT return list.

    No returns map to ``void``, one return maps to its own type, and several
    returns map to a tuple of their individual types.
    """
    count = len(rs)
    if count == 0:
        return BaseCType(voidT)
    if count == 1:
        return return_type(rs[0])
    element_types = [return_type(single) for single in rs]
    return TupleCType(element_types)
def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
    """Compute the C++ names of the returns of native function ``f``.

    Args:
        f: The native function whose ``func`` schema is inspected.
        fallback_name: Base name for returns without an explicit name. A
            single return uses it verbatim; multiple returns get a zero-based
            index suffix (``result0``, ``result1``, ...).

    Returns:
        One name per return, in schema order.

    Raises:
        AssertionError: If an inplace function has more than one return.
    """
    func = f.func
    returns: List[str] = []
    for i, r in enumerate(func.returns):
        # If we have an inplace function, the return argument is
        # implicitly named self.
        # TODO: Consider incorporating this into the data model
        if func.name.name.inplace:
            # Explicit raise instead of `assert` so the check survives `python -O`.
            if i != 0:
                raise AssertionError("illegal inplace function with multiple returns")
            name = "self"
        # If we are out function, the name is the name of the
        # corresponding output function (r.name will get recorded
        # in field_name later.)
        elif func.is_out_fn():
            name = func.arguments.out[i].name
        # If the return argument is explicitly named...
        elif r.name:
            # Out functions were handled by the branch above, so a clash with
            # an argument name is all that needs checking here.
            name_conflict = any(
                r.name == a.name for a in func.schema_order_arguments()
            )
            name = f"{r.name}_return" if name_conflict else r.name
        # If there is no explicit name, we just name the output after
        # fallback_name, unless it's a multi-return, in which case it's
        # result0, result1, etc (zero-indexed)
        else:
            name = fallback_name if len(func.returns) == 1 else f"{fallback_name}{i}"
        returns.append(name)
    return returns
# Maps the textual form of a JIT schema default value to the C++ expression
# emitted for it. default_expr() consults this table last, via
# JIT_TO_CPP_DEFAULT.get(d, d), so defaults that are not listed here pass
# through unchanged.
# NOTE(review): the C++ names below use the "torch::executorch::" namespace,
# while default_expr() returns "torch::executor::nullopt" for optional types —
# confirm which spelling is intended.
JIT_TO_CPP_DEFAULT = {
    "False": "false",
    "True": "true",
    "None": "torch::executorch::nullopt",  # UGH this one is type directed
    "[]": "{}",
    "contiguous_format": "torch::executorch::MemoryFormat::Contiguous",
    "long": "torch::executorch::kLong",
}
# Convert a JIT default into C++ expression representing the default
def default_expr(d: str, t: Type) -> str:
    """Convert a JIT schema default ``d`` of type ``t`` into a C++ expression.

    Handled in order: a ``None`` default for ``Tensor?``; single-quoted string
    literals (re-quoted with double quotes); optional types; list types; and
    finally the ``JIT_TO_CPP_DEFAULT`` table, falling back to ``d`` itself.

    Raises:
        ValueError: If ``t`` is an unsized list and ``d`` is not of the form
            ``[...]``.
    """
    if d == "None" and str(t) == "Tensor?":
        return "{}"

    is_str_type = isinstance(t, BaseType) and t.name is BaseTy.str
    # Schema allows single quotes but C++ needs double
    if is_str_type and len(d) >= 2 and d[0] == "'" and d[-1] == "'":
        pieces: List[str] = []
        last = len(d) - 1
        pos = 1
        while pos < last:
            ch = d[pos]
            if ch == "\\":
                # An escaped single quote loses its backslash; any other
                # escape sequence is copied through as-is.
                pieces.append("'" if d[pos + 1] == "'" else d[pos : pos + 2])
                pos += 2
            else:
                # Bare double quotes must be escaped inside a C++ literal.
                pieces.append('\\"' if ch == '"' else ch)
                pos += 1
        body = "".join(pieces)
        return f'"{body}"'

    if isinstance(t, OptionalType):
        if d == "None":
            return "torch::executor::nullopt"
        return default_expr(d, t.elem)

    if isinstance(t, ListType):
        if d.startswith("[") and d.endswith("]"):
            return f"{{{d[1:-1]}}}"
        # NOTE: Sized lists can have scalar defaults
        if t.size is None:
            raise ValueError(f"Expected a list default '[...]' but found: '{d}'")

    return JIT_TO_CPP_DEFAULT.get(d, d)
# Convert an argument into its C++ API form
def argument(
    a: Union[Argument, TensorOptionsArguments, SelfArgument],
    *,
    cpp_no_default_args: Set[str],
    method: bool,
    faithful: bool,
    has_tensor_options: bool,
) -> List[Binding]:
    """Convert one schema argument into its C++ API bindings.

    Args:
        a: The argument to convert.
        cpp_no_default_args: Names of arguments whose default is dropped.
        method: When true, a ``SelfArgument`` yields no binding.
        faithful: Passed through unchanged on the recursive call.
        has_tensor_options: When true, an argument named ``memory_format`` is
            bound as ``SpecialArgName.possibly_redundant_memory_format``.

    Returns:
        A list holding zero or one ``Binding``.

    Raises:
        NotImplementedError: For ``TensorOptionsArguments``.
    """
    if isinstance(a, Argument):
        binds: ArgName = (
            SpecialArgName.possibly_redundant_memory_format
            if a.name == "memory_format" and has_tensor_options
            else a.name
        )
        wants_default = a.name not in cpp_no_default_args and a.default is not None
        default: Optional[str] = (
            default_expr(a.default, a.type) if wants_default else None
        )
        binding = Binding(
            nctype=argument_type(a, binds=binds),
            name=a.name,
            default=default,
            argument=a,
        )
        return [binding]

    if isinstance(a, TensorOptionsArguments):
        raise NotImplementedError("Need to implement type resolution for TensorOptions")

    if isinstance(a, SelfArgument):
        if method:
            # Caller is responsible for installing implicit this in context!
            return []
        # Unwrap `self` and convert the underlying argument with the same options.
        return argument(
            a.argument,
            cpp_no_default_args=cpp_no_default_args,
            method=method,
            faithful=faithful,
            has_tensor_options=has_tensor_options,
        )

    assert_never(a)
def arguments(
    arguments: Arguments,
    *,
    faithful: bool,
    method: bool,
    cpp_no_default_args: Set[str],
) -> List[Binding]:
    """Convert all schema arguments into an ordered list of C++ bindings.

    In faithful mode the non-out arguments come first, followed by the out
    arguments, and every binding has its default removed. Otherwise the out
    arguments come first and defaults are kept.
    """
    ordered: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
    first, second = (
        (arguments.non_out, arguments.out)
        if faithful
        else (arguments.out, arguments.non_out)
    )
    ordered.extend(first)
    ordered.extend(second)

    bindings: List[Binding] = []
    for schema_arg in ordered:
        produced = argument(
            schema_arg,
            faithful=faithful,
            method=method,
            has_tensor_options=arguments.tensor_options is not None,
            cpp_no_default_args=cpp_no_default_args,
        )
        for binding in produced:
            bindings.append(binding.no_default() if faithful else binding)
    return bindings