forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathufunc.py
545 lines (468 loc) · 17.4 KB
/
ufunc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple, Union
import torchgen.api.ufunc as ufunc
from torchgen.api.translate import translate
from torchgen.api.types import (
BaseCType,
Binding,
CType,
Expr,
NamedCType,
opmath_t,
scalar_t,
StructuredImplSignature,
VectorizedCType,
)
from torchgen.api.ufunc import UfunctorBindings
from torchgen.context import with_native_function
from torchgen.model import (
Argument,
BaseTy,
BaseType,
DispatchKey,
NativeFunctionsGroup,
ScalarType,
UfuncKey,
)
from torchgen.utils import OrderedSet
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# CUDA STUFF
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
# NB: not bothering to generate dispatch stub forward declaration in header,
# we can just paste it wherever necessary
# TODO: use BackendIndex
# dispatch_key: DispatchKey # only CPU/CUDA right now
# Represents functors for implementing CUDA ufuncs.
# Functors are templated by scalar_t because when USERS instantiate functors
# they are templated. A functor looks something like this:
#
# template <typename scalar_t>
# struct CUDAFunctorOnSelf_add {
# using opmath_t = at::opmath_type<scalar_t>;
# opmath_t other_;
# opmath_t alpha_;
# CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
# : other_(other), alpha_(alpha) {}
# __device__ scalar_t operator()(scalar_t self) {
# return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
# }
# };
#
@dataclass(frozen=True)
class UfunctorSignature:
    """Describes one C++ functor struct used to implement a CUDA ufunc.

    Holds the native function group, the ``scalar_tensor_idx`` that is
    forwarded to ``ufunc.ufunctor_arguments``, and the C++ struct name, and
    renders the individual pieces of the struct definition as strings.
    """

    g: NativeFunctionsGroup
    scalar_tensor_idx: Optional[int]
    name: str

    def arguments(self) -> UfunctorBindings:
        """Return the constructor/apply bindings computed by ``api.ufunc``."""
        return ufunc.ufunctor_arguments(
            self.g,
            scalar_tensor_idx=self.scalar_tensor_idx,
            scalar_t=scalar_t,
        )

    def fields(self) -> List[Binding]:
        """Return the constructor bindings renamed as struct members.

        Members conventionally carry a trailing underscore.
        """
        ctor_bindings = self.arguments().ctor
        return [binding.rename(f"{binding.name}_") for binding in ctor_bindings]

    def returns_type(self) -> CType:
        """Return the C++ type produced by ``operator()``."""
        # TODO: don't hardcode; return type will be inferred based on tags on
        # the native function
        return BaseCType(scalar_t)

    def decl_fields(self) -> str:
        """Render one ``<type> <name>;`` member declaration per field."""
        declarations = [f"{field.type} {field.name};" for field in self.fields()]
        return "\n".join(declarations)

    def inline_defn_ctor(self) -> str:
        """Render the inline constructor that copies each argument into its field."""
        ctor_bindings = self.arguments().ctor
        params = ", ".join(binding.decl() for binding in ctor_bindings)
        # NB: translate could produce these initializers too, but the mapping
        # from argument ``x`` to member ``x_`` is completely regular.
        initializers = ", ".join(
            f"{binding.name}_({binding.name})" for binding in ctor_bindings
        )
        return f"{self.name}({params}) : {initializers} {{}}"

    def decl_apply(self) -> str:
        """Render the declaration of the const ``operator()``."""
        params = ", ".join(binding.decl() for binding in self.arguments().apply)
        return_type = self.returns_type().cpp_type()
        return f"{return_type} operator()({params}) const"
@dataclass(frozen=True)
class UfuncSignature:
    """Describes a call to a (templated) ufunc function by name.

    ``compute_t`` is the C++ type passed to ``ufunc.ufunc_arguments`` when the
    argument bindings are computed.
    """

    g: NativeFunctionsGroup
    name: str
    compute_t: CType

    def arguments(self) -> List[Binding]:
        """Return the bindings the ufunc expects, as computed by ``api.ufunc``."""
        return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)

    def call(self, ctx: Sequence[Union[Binding, Expr]]) -> str:
        """Render a call expression whose arguments are translated from ``ctx``."""
        translated = translate(ctx, self.arguments())
        arg_list = ", ".join(arg.expr for arg in translated)
        return f"{self.name}({arg_list})"
# steps:
# 1. take the functional signature
# 2. use api.ufunc to convert it to template signature. this establishes
# the type of the template function
# 3. use api.ufunc (II) to generate a split struct / operator() signature.
# this establishes the context in which we call the template signature
#
# StructuredImplSignature context
# ~> functor constructor sig
#
# Functor constructor context
# ~> functor fields sig
#
# Functor apply context (functor fields + functor apply sig)
# ~> template sig
#
def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
    """Return True when the functional variant has exactly two tensor-like inputs.

    Only non-out arguments are counted.
    """
    tensor_like_args = [
        arg
        for arg in g.functional.func.arguments.flat_non_out
        if arg.type.is_tensor_like()
    ]
    return len(tensor_like_args) == 2
def compute_ufunc_cuda_functors(
    g: NativeFunctionsGroup,
) -> Tuple[Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]], str]:
    """Build the CUDA functor signatures and generated struct source for ``g``.

    Returns a pair:

    * a mapping ``dtype -> {functor key -> UfunctorSignature}``;
    * the C++ source of every functor struct that had to be generated, joined
      by newlines.  Keys that appear directly in the inner-loop table only
      contribute a signature; no struct is emitted for them.
    """
    loops = g.out.ufunc_inner_loop
    sigs_by_dtype: Dict[ScalarType, Dict[UfuncKey, UfunctorSignature]] = {}
    generated_structs: List[str] = []

    # The scalar_tensor_idx handed to UfunctorSignature for each functor key.
    scalar_tensor_idx_lookup = {
        UfuncKey.CUDAFunctorOnSelf: 1,
        UfuncKey.CUDAFunctorOnOther: 0,
        UfuncKey.CUDAFunctor: None,
    }

    if eligible_for_binary_scalar_specialization(g):
        keys = [
            UfuncKey.CUDAFunctorOnSelf,
            UfuncKey.CUDAFunctorOnOther,
            UfuncKey.CUDAFunctor,
        ]
    else:
        # The scalar-specialized keys are rejected for non-binary functions.
        for scalar_key in (UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther):
            assert (
                scalar_key not in loops
            ), f"cannot use {scalar_key} on non-binary function"
        keys = [UfuncKey.CUDAFunctor]

    for key in keys:
        scalar_tensor_idx = scalar_tensor_idx_lookup[key]

        # A key that was defined directly needs no functor codegen; we assume
        # the user has already written the functor for us.
        if key in loops:
            direct_loop = loops[key]
            direct_sig = UfunctorSignature(
                g, scalar_tensor_idx=scalar_tensor_idx, name=direct_loop.name
            )
            for dtype in direct_loop.supported_dtypes:
                sigs_by_dtype.setdefault(dtype, {})[key] = direct_sig
            continue

        # Note [ScalarOnly and Generic must match names for CUDA]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Otherwise, look in ANY of the generic entries.  For simplicity of
        # codegen, if both ScalarOnly and Generic are defined, the ufunc name
        # must match (if they didn't match, we'd have to generate distinct
        # functors per dtype, which is awful, so we're not going to do it
        # unless someone really forces us to).
        ufunc_name = None
        supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
        for generic_key in (UfuncKey.ScalarOnly, UfuncKey.Generic):
            if generic_key not in loops:
                continue
            generic_loop = loops[generic_key]
            if ufunc_name is None:
                ufunc_name = generic_loop.name
            else:
                # See Note [ScalarOnly and Generic must match names for CUDA]
                assert (
                    ufunc_name == generic_loop.name
                ), "ScalarOnly and Generic must have same ufunc name"
            supported_dtypes |= generic_loop.supported_dtypes
        assert ufunc_name is not None

        functor_sig = UfunctorSignature(
            g, scalar_tensor_idx=scalar_tensor_idx, name=f"{key}_{ufunc_name}"
        )
        for dtype in supported_dtypes:
            sigs_by_dtype.setdefault(dtype, {})[key] = functor_sig

        inner_call_sig = UfuncSignature(
            g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
        )
        # operator() can see both the struct members and its own parameters.
        apply_ctx = functor_sig.fields() + functor_sig.arguments().apply
        generated_structs.append(
            f"""
template <typename scalar_t>
struct {functor_sig.name} {{
using opmath_t = at::opmath_type<scalar_t>;
{functor_sig.decl_fields()}
{functor_sig.inline_defn_ctor()}
__device__ {functor_sig.decl_apply()} {{
return {inner_call_sig.call(apply_ctx)};
}}
}};
"""
        )

    return sigs_by_dtype, "\n".join(generated_structs)
@dataclass(frozen=True)
class BinaryScalarSpecializationConfig:
    """One row of the table of binary-op scalar specializations."""

    # Index of the scalar operand; consumers below add 1 before handing it to
    # the iterator.
    scalar_idx: int
    # Name under which the extracted scalar value is bound for the functor
    # constructor.
    ctor_tensor: str
    # Functor key used when this specialization applies.
    ufunc_key: UfuncKey


# One entry per operand of a binary op that may be passed as a scalar.
BinaryScalarSpecializationConfigs = [
    BinaryScalarSpecializationConfig(
        scalar_idx=idx,
        ctor_tensor=tensor_name,
        ufunc_key=functor_key,
    )
    for idx, tensor_name, functor_key in (
        (0, "self", UfuncKey.CUDAFunctorOnOther),
        (1, "other", UfuncKey.CUDAFunctorOnSelf),
    )
]
def compute_ufunc_cuda_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: Dict[UfuncKey, UfunctorSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    """Render the C++ body of one dtype case of the CUDA kernel.

    The body is an ``if``/``else if`` chain: one branch per scalar
    specialization present in ``inner_loops``, followed by a final ``else``
    that launches the plain ``CUDAFunctor``.  ``parent_ctx`` supplies the
    bindings from which the functor constructor arguments are translated.
    """
    pieces: List[str] = [
        "using opmath_t = at::opmath_type<scalar_t>;",
        # Always-false opener so that every real branch can be emitted
        # uniformly as "else if"/"else".
        "if (false) {}\n",
    ]

    for spec in BinaryScalarSpecializationConfigs:
        if spec.ufunc_key not in inner_loops:
            continue
        specialized_sig = inner_loops[spec.ufunc_key]
        # NOTE(review): presumably the +1 skips the iterator's output operand
        # -- confirm against TensorIterator's operand ordering.
        scalar_idx = spec.scalar_idx + 1

        # Build a fresh, wider-typed context rather than mutating the
        # caller's sequence; the extra entry exposes the CPU scalar's value
        # under the name of the tensor it replaces.
        ctx: List[Union[Expr, Binding]] = [
            *parent_ctx,
            Expr(
                expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
                type=NamedCType(spec.ctor_tensor, BaseCType(opmath_t)),
            ),
        ]
        ctor_args = ", ".join(
            arg.expr for arg in translate(ctx, specialized_sig.arguments().ctor)
        )

        # NB: ufunctor must be allocated before iter.remove_operand is called,
        # as it relies on iter
        pieces.append(
            f"""\
else if (iter.is_cpu_scalar({scalar_idx})) {{
{specialized_sig.name}<scalar_t> ufunctor({ctor_args});
iter.remove_operand({scalar_idx});
gpu_kernel(iter, ufunctor);
}}"""
        )

    generic_sig = inner_loops[UfuncKey.CUDAFunctor]
    generic_ctor_args = ", ".join(
        arg.expr for arg in translate(parent_ctx, generic_sig.arguments().ctor)
    )
    pieces.append(
        f"""
else {{
gpu_kernel(iter, {generic_sig.name}<scalar_t>({generic_ctor_args}));
}}
"""
    )
    return "".join(pieces)
@with_native_function
def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
    """Render the full CUDA implementation of the ufunc group ``g``.

    The output contains the generated functor structs, the stub type and
    dispatch declaration, the kernel (a dtype switch over the iterator's
    common dtype), its dispatch registration, and the structured impl
    function, which calls the kernel directly.
    """
    # Functor signatures indexed by dtype, plus the generated struct source.
    ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g)

    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))
    stub_sig = StubSignature(g)

    # One AT_DISPATCH_CASE per supported dtype.
    dtype_cases_str = "\n".join(
        f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
[&]() {{
{compute_ufunc_cuda_dtype_body(g, dtype, dtype_functor_sigs, sig.arguments())}
}}
)
"""
        for dtype, dtype_functor_sigs in ufunctor_sigs.items()
    )

    return f"""
{ufunctors}
{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
{stub_sig.kernel_defn()} {{
AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}",
{dtype_cases_str}
);
}}
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
{sig.defn()} {{
{stub_sig.direct_call(sig.arguments())};
}}
"""
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# CPU STUFF
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
@dataclass(frozen=True)
class StubSignature:
    """Names and C++ snippets for the dispatch stub of a ufunc group.

    All names are derived from the functional variant's operator name by
    appending ``_stub``, ``_kernel`` or ``_fn``.  The extra parameters of the
    stub (beyond the leading ``TensorIteratorBase&``) come from
    ``ufunc.stub_arguments``.

    Every rendered parameter/argument list is built by joining the fixed
    leading entries with the stub arguments, so a group with no stub
    arguments yields ``(TensorIteratorBase&)`` rather than a list with a
    dangling comma.
    """

    g: NativeFunctionsGroup

    @property
    def name(self) -> str:
        """Name of the dispatch stub object."""
        return f"{str(self.g.functional.func.name.name)}_stub"

    @property
    def kernel_name(self) -> str:
        """Name of the kernel function registered with the stub."""
        return f"{str(self.g.functional.func.name.name)}_kernel"

    @property
    def type_name(self) -> str:
        """Name of the alias for the kernel's function-pointer type."""
        return f"{str(self.g.functional.func.name.name)}_fn"

    def arguments(self) -> List[Binding]:
        """Return the bindings for the stub's extra parameters."""
        return ufunc.stub_arguments(self.g)

    @staticmethod
    def _join_cpp_args(leading: List[str], rest: List[str]) -> str:
        """Comma-join ``leading`` and ``rest`` without a dangling separator."""
        return ", ".join(leading + rest)

    def type(self) -> str:
        """Render the kernel's function-pointer type."""
        arg_types = [a.type for a in self.arguments()]
        params = self._join_cpp_args(["TensorIteratorBase&"], arg_types)
        return f"void(*)({params})"

    def dispatch_decl(self) -> str:
        """Render the ``DECLARE_DISPATCH`` line for the stub."""
        return f"DECLARE_DISPATCH({self.type_name}, {self.name})"

    def dispatch_defn(self) -> str:
        """Render the ``DEFINE_DISPATCH`` line for the stub."""
        return f"DEFINE_DISPATCH({self.name})"

    def kernel_defn(self) -> str:
        """Render the kernel function's signature (without a body)."""
        arg_defns = [a.defn() for a in self.arguments()]
        params = self._join_cpp_args(["TensorIteratorBase& iter"], arg_defns)
        return f"void {self.kernel_name}({params})"

    def type_defn(self) -> str:
        """Render the ``using`` alias for the kernel's function-pointer type."""
        return f"using {self.type_name} = {self.type()}"

    # must be called from context where this is TensorIteratorBase*
    def call(self, ctx: Sequence[Binding]) -> str:
        """Render a call through the dispatch stub, translating args from ``ctx``."""
        arg_exprs = [a.expr for a in translate(ctx, self.arguments())]
        args = self._join_cpp_args(["device_type()", "*this"], arg_exprs)
        return f"{self.name}({args})"

    # used in CUDA to skip the unnecessary dynamic dispatch
    def direct_call(self, ctx: Sequence[Binding]) -> str:
        """Render a direct call to the kernel, translating args from ``ctx``."""
        arg_exprs = [a.expr for a in translate(ctx, self.arguments())]
        args = self._join_cpp_args(["*this"], arg_exprs)
        return f"{self.kernel_name}({args})"
@with_native_function
def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
    """Render the CPU structured impl function for the ufunc group ``g``.

    Emits the stub type alias, the stub's declaration and definition, and the
    impl function, whose body is a single call through the dispatch stub.
    """
    dispatch_stub = StubSignature(g)
    impl_sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))
    # The empty first and last entries reproduce the leading and trailing
    # newline of the rendered snippet.
    lines = [
        "",
        f"{dispatch_stub.type_defn()};",
        f"{dispatch_stub.dispatch_decl()};",
        f"{dispatch_stub.dispatch_defn()};",
        f"{impl_sig.defn()} {{",
        f"{dispatch_stub.call(impl_sig.arguments())};",
        "}",
        "",
    ]
    return "\n".join(lines)
def compute_ufunc_cpu_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: Dict[UfuncKey, UfuncSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    """Render the C++ body of one dtype case of the CPU kernel.

    ``inner_loops`` must contain ``CPUScalar`` and may contain ``CPUVector``;
    no other key is accepted.  With a vector loop the body calls
    ``cpu_kernel_vec`` with a scalar and a vectorized lambda, otherwise
    ``cpu_kernel`` with the scalar lambda only.
    """
    assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
    assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
    scalar_loop = inner_loops[UfuncKey.CPUScalar]
    vec_loop = inner_loops.get(UfuncKey.CPUVector)

    # NB: We DON'T use translate for the unpacking below: translate cannot
    # CSE the scalar accesses when they are also needed by Vectorized.  The
    # unpacking is simple and only affects Scalar; everything else is
    # implicitly captured by the lambda.

    def is_unpacked(binding: Binding) -> bool:
        # Bindings backed by an Argument are unpacked only when that argument
        # is a Scalar; any other binding is unpacked unconditionally.
        skipped = isinstance(
            binding.argument, Argument
        ) and binding.argument.type != BaseType(BaseTy.Scalar)
        return not skipped

    unpacked = [b for b in parent_ctx if is_unpacked(b)]

    # Bring each scalar into scope converted to scalar_t ...
    setup_stmts = [f"auto _s_{b.name} = {b.name}.to<scalar_t>();" for b in unpacked]
    captured: List[Expr] = [
        Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t)))
        for b in unpacked
    ]
    # ... and, when there is a vector loop, also as a Vectorized value.
    if vec_loop is not None:
        setup_stmts.extend(
            f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"
            for b in unpacked
        )
        captured.extend(
            Expr(
                f"_v_{b.name}",
                NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))),
            )
            for b in unpacked
        )

    # Lambda parameters: one per tensor input.
    # NB: simplified version of ufunctor_arguments
    tensor_args = []
    for arg in g.functional.func.arguments.flat_non_out:
        if arg.type.is_tensor_like():
            assert arg.type == BaseType(BaseTy.Tensor)
            tensor_args.append(arg)

    scalar_bindings = [
        Binding(
            name=arg.name,
            nctype=NamedCType(arg.name, BaseCType(scalar_t)),
            argument=arg,
        )
        for arg in tensor_args
    ]
    vec_bindings = []
    if vec_loop is not None:
        vec_bindings = [
            Binding(
                name=arg.name,
                nctype=NamedCType(arg.name, VectorizedCType(BaseCType(scalar_t))),
                argument=arg,
            )
            for arg in tensor_args
        ]

    def with_ctx(b: Sequence[Binding]) -> List[Union[Expr, Binding]]:
        # Captured scalars first, then the lambda's own parameters.
        return [*captured, *b]

    body_str = "\n".join(setup_stmts)
    scalar_params = ", ".join(b.decl() for b in scalar_bindings)
    scalar_call = scalar_loop.call(with_ctx(scalar_bindings))

    if vec_loop is None:
        return f"""
{body_str}
cpu_kernel(iter,
[=]({scalar_params}) {{ return {scalar_call}; }}
);
"""

    vec_params = ", ".join(b.decl() for b in vec_bindings)
    vec_call = vec_loop.call(with_ctx(vec_bindings))
    return f"""
{body_str}
cpu_kernel_vec(iter,
[=]({scalar_params}) {{ return {scalar_call}; }},
[=]({vec_params}) {{ return {vec_call}; }}
);
"""
@with_native_function
def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
    """Render the CPU kernel for ``g`` together with its dispatch registration.

    The inner-loop table is first re-indexed by dtype.  For each dtype and
    each of ``CPUScalar``/``CPUVector``, the first applicable entry wins, in
    this order: the key itself, ``ScalarOnly`` (scalar loop only), then
    ``Generic``.
    """
    stub_sig = StubSignature(g)
    loops = g.out.ufunc_inner_loop

    ufunc_sigs: Dict[ScalarType, Dict[UfuncKey, UfuncSignature]] = {}
    for key in (UfuncKey.CPUScalar, UfuncKey.CPUVector):
        compute_t: CType
        if key is UfuncKey.CPUScalar:
            compute_t = BaseCType(scalar_t)
        elif key is UfuncKey.CPUVector:
            compute_t = VectorizedCType(BaseCType(scalar_t))
        else:
            raise AssertionError()

        # ORDER MATTERS: this specifies overriding precedence.  An entry for
        # the key itself should be rare.
        candidates = (
            (key, True),
            (UfuncKey.ScalarOnly, key is UfuncKey.CPUScalar),
            (UfuncKey.Generic, True),
        )
        lookup_keys = [
            lookup_key
            for lookup_key, applicable in candidates
            if applicable and lookup_key in loops
        ]

        # TODO: don't hardcode ufunc:: namespace here, should be centralized smh
        for lookup_key in lookup_keys:
            loop = loops[lookup_key]
            for dtype in loop.supported_dtypes:
                per_dtype = ufunc_sigs.setdefault(dtype, {})
                # An earlier (higher-precedence) entry is never overwritten.
                if key in per_dtype:
                    continue
                per_dtype[key] = UfuncSignature(
                    g, name=f"ufunc::{loop.name}", compute_t=compute_t
                )

    # One AT_DISPATCH_CASE per dtype that has at least one loop.
    dtype_cases_str = "\n".join(
        f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
[&]() {{
{compute_ufunc_cpu_dtype_body(g, dtype, dtype_sigs, stub_sig.arguments())}
}}
)
"""
        for dtype, dtype_sigs in ufunc_sigs.items()
    )

    return f"""
namespace {{
{stub_sig.kernel_defn()} {{
AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}",
{dtype_cases_str}
);
}}
}} // anonymous namespace
{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
"""