forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsym_node.py
1365 lines (1060 loc) · 41.4 KB
/
sym_node.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
This file does three things:
- Contains the definition of SymNode
- Installs all the magic methods into SymBool, SymFloat, SymInt at import time
- Does not depend on sympy at import time
As this file is imported from within torch/__init__.py we do not want it to depend on SymPy
to avoid having to load SymPy at import time, as doing so is *very* slow.
"""
import builtins
import itertools
import logging
import math
import operator
import sys
from functools import lru_cache, update_wrapper
from typing import Optional, Type, TYPE_CHECKING, Union
import torch
# NB: The sym_* functions are used via getattr() and must be imported here.
from torch import ( # noqa: F401
sym_float,
sym_ite,
sym_max,
sym_min,
sym_not,
SymBool,
SymFloat,
SymInt,
)
from torch.fx.experimental._sym_dispatch_mode import (
handle_sym_dispatch,
sym_function_mode,
)
if TYPE_CHECKING:
from torch.fx.experimental.symbolic_shapes import ShapeEnv
# Module logger, plus the "sym_node" artifact logger that binary_magic_impl
# uses for debug-level tracing of each symbolic operation it builds.
log = logging.getLogger(__name__)
sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node")
# Public API; a sym_<math op> entry is appended for each math op further down.
__all__ = ["SymNode", "method_to_operator", "magic_methods"]
# The user-facing symbolic types; each one holds a SymNode in ``.node``.
SymTypes = (SymInt, SymFloat, SymBool)
def _to_symtype(t):
if t is bool:
return SymBool
if t is int:
return SymInt
if t is float:
return SymFloat
return t
# TODO: An incomplete list
# 1. Set variables to be equal when we do equality
# 2. Specialize on 0/1 when we do subtraction
class SymNode:
    """
    This is a type erased SymInt/SymFloat/SymBool which we use to do actual
    operations. End users don't touch this. Magic methods are NOT defined on
    this object. The user-facing SymInt/SymFloat/SymBool wrap a SymNode
    (``SymInt(node)``; ``sym.node`` gets it back). The named methods below
    (``add``, ``eq``, ...) forward to underscore-prefixed implementations
    (``_add``, ``_eq``, ...) that are attached to this class elsewhere in this
    module, not in this class body.

    Attributes:
        _expr: the raw symbolic (sympy) expression. Read it through ``expr``
            so that the ShapeEnv's current replacements are applied.
        shape_env: the ShapeEnv used for hints, guards and runtime asserts.
        pytype: the Python type represented. ``is_int``/``is_float``/
            ``is_bool`` compare it against ``int``/``float``/``bool``.
        _hint: an example value from the current trace, or None (see
            ``__init__``).
        constant: the literal value for nodes built by ``wrap_*``; otherwise
            None.
        fx_node: the FX node, kept only when translation validation is
            enabled.
    """
    def __init__(
        self,
        expr,
        shape_env,
        pytype,
        hint: Optional[Union[int, float, bool]],
        constant=None,
        fx_node=None,
    ):
        """
        Args:
            expr: the symbolic expression this node represents.
            shape_env: the ShapeEnv this node belongs to.
            pytype: the Python type the node stands for.
            hint: a concrete example value, or None. When given, its type
                must be exactly ``pytype`` or the matching Sym type
                (``_to_symtype(pytype)``); this is checked with an assert.
            constant: the literal value when this node wraps a literal.
            fx_node: the FX node for translation validation. It is discarded
                unless ``shape_env._translation_validation_enabled`` is set.
        """
        self._expr = expr
        self.shape_env = shape_env
        self.pytype = pytype
        # What's the difference between hint and constant?
        #
        # - A constant is known to be invariant across invocations of the model;
        # it will always be this value. We only really know this when we
        # encounter an honest-to-goodness literal (when wrapping it into
        # a SymNode, we set constant.) Most of the time, constant is None
        #
        # - A hint is a *particular* value from the particular run we are
        # tracing, but it may vary the next time around. It's useful to
        # keep this around, as if we need a concrete value from a SymNode,
        # we will return the hint and guard on the expression that produced
        # it giving the same hint next time around. The hint is not
        # guaranteed to be set either: if you have an unbacked SymNode,
        # there won't be any hint; it was the result of some tensor-dependent
        # computation, but we don't know what it actually is because we
        # haven't actually run the tensor computation.
        #
        # If _hint is None, we will query maybe_evaluate_static(compute_hint=True)
        # in hopes that we've learned enough about the unbacked symints to
        # discharge the hint; otherwise, you're likely to just error out.
        #
        # (A previous version of this system had some optimizations to only
        # recompute when it was possible we had learned enough about the
        # unbacked symint that a hint was now possible, but as we added more
        # potential refinements to unbacked symints this got harder to keep
        # in sync, so we've deleted it for now.)
        if hint is not None:
            assert type(hint) is pytype or type(hint) is _to_symtype(pytype), (
                "Cannot create SymNode of type "
                f"{pytype} with incompatible hint of type {type(hint)}"
            )
        self._hint = hint
        self.constant: Optional[Union[int, float, bool]] = constant
        # Record the FX node of the current node if we are doing translation
        # validation. They will be used for building the input assertions for
        # the translation validation problem.
        self.fx_node = (
            fx_node if self.shape_env._translation_validation_enabled else None
        )

    def with_shape_env(self, shape_env: "ShapeEnv") -> "SymNode":
        """Return a new SymNode with the same expr/pytype/hint/constant/fx_node,
        bound to ``shape_env``. The new node keeps fx_node only if
        ``shape_env`` has translation validation enabled (see ``__init__``)."""
        return SymNode(
            self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
        )

    @property
    def expr(self):
        """The expression with the ShapeEnv's current replacements applied."""
        return self.shape_env.replace(self._expr)

    # Recompute the hint and see if we've got it now
    # Precondition: self._hint is None
    def _update_hint(self):
        # When static evaluation succeeds, store the result: converted to
        # pytype, or kept as-is if it is already a SymInt/SymFloat/SymBool.
        # Otherwise _hint is left untouched (still None).
        r = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True)
        if r is not None:
            self._hint = self.pytype(r) if not isinstance(r, SymTypes) else r

    @property
    def hint(self):
        """The example value, recomputed lazily if missing. It may still be None."""
        if self._hint is None:
            self._update_hint()
        return self._hint

    def has_hint(self):
        """Whether a hint is available, after trying to recompute a missing one."""
        if self._hint is None:
            self._update_hint()
        return self._hint is not None

    def require_hint(self, fallback=None):
        """
        Return the hint.

        If no hint can be computed, return ``fallback`` when it is not None.
        Otherwise defer to ``shape_env.size_hint``, which is expected to raise.
        """
        if self._hint is None:
            self._update_hint()
        if self._hint is None:
            if fallback is not None:
                return fallback
            # NB: we expect this to raise
            return self.shape_env.size_hint(self.expr)
        return self._hint

    def maybe_as_int(self):
        """Return ``int(expr)`` if expr is a number, else None (no guard)."""
        if self.expr.is_number:
            return int(self.expr)
        else:
            return None

    # NB: This does conversions, not sure if this is good or not
    def maybe_as_float(self):
        """Return a Python float if expr is a sympy.Float literal, else None."""
        import sympy

        if isinstance(self.expr, sympy.Float):
            return float(self.expr)
        else:
            return None

    def maybe_as_bool(self):
        """Return True/False if expr is sympy.true/sympy.false, else None."""
        import sympy

        if self.expr is sympy.true:
            return True
        elif self.expr is sympy.false:
            return False
        else:
            return None

    def is_int(self):
        return self.pytype is int

    def is_float(self):
        return self.pytype is float

    def is_bool(self):
        return self.pytype is bool

    def is_nested_int(self):
        # Unbacked SymInts cannot be nested int today
        # True only when the stored hint is itself a SymInt whose node
        # reports is_nested_int().
        return (
            self._hint is not None
            and isinstance(self._hint, SymInt)
            and self._hint.node.is_nested_int()
        )

    def wrap_int(self, num):
        """Wrap the Python int literal ``num`` as a SymNode in the same
        ShapeEnv; ``num`` is used as the hint, the constant and the fx_node."""
        assert type(num) is int
        import sympy

        return SymNode(
            sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num
        )

    def wrap_float(self, num):
        """Wrap the Python float literal ``num``, like ``wrap_int``."""
        assert type(num) is float
        import sympy

        return SymNode(
            sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num
        )

    def wrap_bool(self, num):
        """Wrap the Python bool literal ``num`` as sympy.true/sympy.false,
        like ``wrap_int``."""
        assert type(num) is bool
        import sympy

        return SymNode(
            sympy.true if num else sympy.false,
            self.shape_env,
            bool,
            num,
            constant=num,
            fx_node=num,
        )

    def clone(self):
        """Return self; no copy is made."""
        return self

    def str(self):
        """String form of the (replaced) expression."""
        return f"{self.expr}"

    def __str__(self):
        return self.str()

    def __repr__(self):
        return self.str()

    # These methods call the metaprogrammed methods, they're hand written
    # here so we get good stack traces
    # Each one forwards to the underscore-prefixed implementation attached to
    # this class elsewhere in this module (hence the attr-defined ignores).
    def abs(self) -> "SymNode":
        return self._abs()  # type: ignore[attr-defined]

    def pos(self) -> "SymNode":
        return self._pos()  # type: ignore[attr-defined]

    def round(self, ndigits=None) -> "SymNode":
        return self._round(ndigits)  # type: ignore[attr-defined]

    def trunc(self) -> "SymNode":
        return self._trunc()  # type: ignore[attr-defined]

    def add(self, other) -> "SymNode":
        return self._add(other)  # type: ignore[attr-defined]

    def sub(self, other) -> "SymNode":
        return self._sub(other)  # type: ignore[attr-defined]

    def mul(self, other) -> "SymNode":
        return self._mul(other)  # type: ignore[attr-defined]

    def mod(self, other) -> "SymNode":
        return self._mod(other)  # type: ignore[attr-defined]

    def pow(self, other) -> "SymNode":
        return self._pow(other)  # type: ignore[attr-defined]

    def and_(self, other) -> "SymNode":
        return self._and_(other)  # type: ignore[attr-defined]

    def or_(self, other) -> "SymNode":
        return self._or_(other)  # type: ignore[attr-defined]

    def truediv(self, other) -> "SymNode":
        return self._truediv(other)  # type: ignore[attr-defined]

    def floordiv(self, other) -> "SymNode":
        return self._floordiv(other)  # type: ignore[attr-defined]

    def lshift(self, other) -> "SymNode":
        return self._lshift(other)  # type: ignore[attr-defined]

    def rshift(self, other) -> "SymNode":
        return self._rshift(other)  # type: ignore[attr-defined]

    def sym_not(self) -> "SymNode":  # noqa: F811
        return self._sym_not()  # type: ignore[attr-defined]

    def eq(self, other) -> "SymNode":
        return self._eq(other)  # type: ignore[attr-defined]

    def ne(self, other) -> "SymNode":
        return self._ne(other)  # type: ignore[attr-defined]

    def gt(self, other) -> "SymNode":
        return self._gt(other)  # type: ignore[attr-defined]

    def lt(self, other) -> "SymNode":
        return self._lt(other)  # type: ignore[attr-defined]

    def le(self, other) -> "SymNode":
        return self._le(other)  # type: ignore[attr-defined]

    def ge(self, other) -> "SymNode":
        return self._ge(other)  # type: ignore[attr-defined]

    def floor(self) -> "SymNode":
        return self._floor()  # type: ignore[attr-defined]

    def is_integer(self) -> "SymNode":
        return self._is_integer()  # type: ignore[attr-defined]

    def sym_float(self) -> "SymNode":  # noqa: F811
        return self._sym_float()  # type: ignore[attr-defined]

    def sym_int(self) -> "SymNode":
        return self._sym_int()  # type: ignore[attr-defined]

    def ceil(self) -> "SymNode":
        return self._ceil()  # type: ignore[attr-defined]

    def neg(self) -> "SymNode":
        return self._neg()  # type: ignore[attr-defined]

    def sym_min(self, other) -> "SymNode":  # noqa: F811
        return self._sym_min(other)  # type: ignore[attr-defined]

    def sym_max(self, other) -> "SymNode":  # noqa: F811
        return self._sym_max(other)  # type: ignore[attr-defined]

    def sym_ite(self, then_val, else_val) -> "SymNode":
        return self._sym_ite(then_val, else_val)  # type: ignore[attr-defined]

    def is_contiguous(self, sizes, strides) -> "SymNode":
        return self._is_contiguous(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_contiguous_2d(self, sizes, strides) -> "SymNode":
        return self._is_channels_last_contiguous_2d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_contiguous_3d(self, sizes, strides) -> "SymNode":
        return self._is_channels_last_contiguous_3d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_strides_2d(self, sizes, strides) -> "SymNode":
        return self._is_channels_last_strides_2d(sizes, strides)  # type: ignore[attr-defined]

    def is_channels_last_strides_3d(self, sizes, strides) -> "SymNode":
        return self._is_channels_last_strides_3d(sizes, strides)  # type: ignore[attr-defined]

    def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> "SymNode":
        return self._is_non_overlapping_and_dense_indicator(sizes, strides)  # type: ignore[attr-defined]

    # Make C++ happy
    # (aliases of or_/and_ under the sym_or/sym_and names)
    def sym_or(self, other):
        return self.or_(other)

    def sym_and(self, other):
        return self.and_(other)

    def is_non_overlapping_and_dense(self, sizes, strides):
        # Boolean node: "the indicator expression equals 1".
        return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1))  # type: ignore[attr-defined]

    def int_(self):
        return self.guard_int("", 0)  # NB: uses Python backtrace

    # You can manually trigger a guard with this function
    def guard_int(self, file, line):
        """Guard on this node via ``shape_env.evaluate_expr`` and return the
        result as a Python int. ``file``/``line`` are currently unused.
        A failed int() conversion is logged as a warning and re-raised."""
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
        try:
            return int(r)
        except Exception:
            log.warning("Failed to convert to int: %s", r)
            raise

    def guard_float(self, file, line):
        """Like ``guard_int``, but passes ``expect_rational=False`` and
        returns a Python float."""
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.shape_env.evaluate_expr(
            self.expr, self.hint, fx_node=self.fx_node, expect_rational=False
        )
        try:
            return float(r)
        except Exception:
            log.warning("Failed to convert to float: %s", r)
            raise

    def guard_bool(self, file, line):
        """Like ``guard_int``, but returns a Python bool."""
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
        try:
            return bool(r)
        except Exception:
            log.warning("Failed to convert to bool: %s", r)
            raise

    def expect_true(self, file, line):
        """
        Assert that this boolean node is true.

        If the node has a hint, has no free unbacked symbols, and the ShapeEnv
        does not prefer deferred runtime asserts, this is a plain
        ``guard_bool``. Otherwise the check goes through
        ``shape_env.defer_runtime_assert``, tagged with "file:line", and that
        call's result is returned.
        """
        from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols

        if (
            self.has_hint()
            and not free_unbacked_symbols(self.expr)
            and not self.shape_env.prefer_deferred_runtime_asserts_over_guards
        ):
            # OK to generate guards
            return self.guard_bool(file, line)
        # Generate a deferred runtime assert (this might actually end up doing
        # a regular guard if we can!)
        # TODO: file/line here is very important, because the assert has been
        # deferred so you can't backtrace easily
        return self.shape_env.defer_runtime_assert(
            self.expr, f"{file}:{line}", fx_node=self.fx_node
        )

    def expect_size(self, file, line):
        """
        Assert (via ``expect_true``) that this node is >= 0.

        If that returns a truthy result and the node has no hint, also mark it
        as size-like with ``_advise_is_size``. Returns the ``expect_true``
        result.
        """
        from torch.fx.experimental.symbolic_shapes import _advise_is_size

        b = self.ge(self.wrap_int(0))
        # Generate a deferred runtime assert
        r = b.expect_true(file, line)
        # Refine compile time range, but only if it's unbacked.
        # If you refine range for hinted variables, you can end up making
        # improper deductions since compile time reasoning may be
        # incompatible with runtime reasoning.
        if r and not self.has_hint():
            _advise_is_size(SymInt(self))
        return r

    def guard_size_oblivious(self, file, line):
        """
        Like guard_bool, but if we encounter unbacked symbols, if those symbols
        are size-like, we will treat them as >= 2 for the purposes of the analysis.
        This CHANGES the runtime semantics, but all size-oblivious sites have been
        audited to ensure that the runtime semantics don't change in a material way.
        Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping
        an unbacked one size, or a tensor reporting as non-contiguous even if it's
        contiguous if it would have been reported contiguous due to being empty.
        """
        # TODO: use the file/line for some useful diagnostic on why a
        # guard occurred
        r = self.shape_env.evaluate_expr(
            self.expr, self.hint, fx_node=self.fx_node, size_oblivious=True
        )
        try:
            return bool(r)
        except Exception:
            log.warning("Failed to convert to bool: %s", r)
            raise

    def bool_(self):
        return self.guard_bool("", 0)

    # Constant queries: a Python SymNode is always symbolic, is never a
    # nested int constant, and never reports itself as constant here.
    def is_symbolic(self):
        return True

    def nested_int(self):
        return None

    def is_constant(self):
        return False
# Maps each method name used in the tables below to the Python-level callable
# that implements it on concrete values. For example, binary_magic_impl calls
# it on the operands' hints to compute the output hint.
# TODO: this probably needs the sizes-strides eval functions
METHOD_TO_OPERATOR = {
    "pos": operator.pos,
    "abs": operator.abs,
    "add": operator.add,
    "and": operator.and_,
    "ceil": math.ceil,
    "eq": operator.eq,
    "floor": math.floor,
    "trunc": math.trunc,
    "floordiv": operator.floordiv,
    "ge": operator.ge,
    "gt": operator.gt,
    "is_integer": lambda x: x.is_integer(),
    "le": operator.le,
    "lshift": operator.lshift,
    "lt": operator.lt,
    "mod": operator.mod,
    "mul": operator.mul,
    "ne": operator.ne,
    "neg": operator.neg,
    "or": operator.or_,
    "pow": operator.pow,
    "round": builtins.round,
    "rshift": operator.rshift,
    "sub": operator.sub,
    "sym_float": sym_float,
    "sym_ite": sym_ite,
    "sym_max": sym_max,
    "sym_min": sym_min,
    "sym_not": sym_not,
    "truediv": operator.truediv,
}

# Magic methods that take only ``self``. The sym_<math op> names are added to
# this set further down.
unary_magic_methods = {
    "abs",
    "sym_float",
    "ceil",
    "floor",
    "neg",
    "sym_not",
    "pos",
    "trunc",
}
# Adding math ops: sqrt, cos, sin, ...
def _get_sym_node_fn(name):
def fn(self):
return getattr(self, f"_sym_{name}")()
return fn
# Math functions exposed as sym_<name>. For each one, the loop below:
#  - installs SymNode.sym_<name>, which forwards to SymNode._sym_<name>;
#  - maps "sym_<name>" to torch._sym_<name> in METHOD_TO_OPERATOR;
#  - registers "sym_<name>" as a unary magic method and exports it in __all__.
# NB: the loop variables leak into module scope and are reused by later
# loops; they (and math_op_names itself) are ``del``-ed further down, so do
# not rename them.
math_op_names = (
    "sqrt",
    "cos",
    "cosh",
    "sin",
    "sinh",
    "tan",
    "tanh",
    "asin",
    "acos",
    "atan",
)
for name in math_op_names:
    sym_name = f"sym_{name}"
    priv_sym_name = f"_{sym_name}"
    setattr(SymNode, sym_name, _get_sym_node_fn(name))
    METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name)
    unary_magic_methods.add(sym_name)
    __all__.append(sym_name)
# Unary methods that are not magic methods
unary_nonmagic_methods = {
    "is_integer",
}

# Every method that takes only ``self``.
unary_methods = unary_magic_methods | unary_nonmagic_methods

# Most methods are only registered on SymInt and SymFloat.
# Some methods are only registered on SymBool:
only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"}
# Methods that implicitly convert SymBool into SymInt
bool_becomes_int_magic_methods = {"add", "sub", "mul"}
# Methods that are also on SymBool, in addition to on SymInt and SymFloat
also_bool_magic_methods = {"eq"}
bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods

# Methods that are only for float
only_float_magic_methods = {"is_integer"}

# "and"/"or" are Python keywords, so the operator functions and SymNode
# methods carry a trailing underscore (operator.and_, SymNode.or_, ...).
magic_methods_on_operator_with_trailing_underscore = {"and", "or"}

# Methods whose result is always a float / int / bool, as the table names
# say (presumably used to choose the output pytype when building results).
always_float_magic_methods = {"truediv", "sym_float", "pow"}

# Every sym_<math op> (sqrt, cos, ...) produces a float.
for name in math_op_names:
    sym_name = f"sym_{name}"
    always_float_magic_methods.add(sym_name)

always_int_magic_methods = {"ceil", "floor", "trunc"}
always_bool_magic_methods = {
    "eq",
    "ne",
    "gt",
    "lt",
    "le",
    "ge",
    "and",
    "or",
    "sym_not",
    "is_non_overlapping_and_dense",
    "is_integer",
}
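# Methods that have a `__foo__` as well as `__rfoo__` are collected in reflectable_magic_methods below; the helpers that follow build their sympy expressions.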
def _sympy_truediv(a, b):
from torch.utils._sympy.functions import TrueDiv
return TrueDiv(a, b)
def _sympy_floordiv(a, b):
from torch.utils._sympy.functions import FloorDiv
return FloorDiv(a, b)
def _sympy_mod(a, b):
from torch.utils._sympy.functions import Mod
return Mod(a, b)
def _sympy_pow(a, b):
from torch.utils._sympy.functions import Pow
return Pow(a, b)
def _sympy_and(a, b):
import sympy
return sympy.And(a, b)
def _sympy_or(a, b):
import sympy
return sympy.Or(a, b)
def _sympy_lshift(a, b):
from torch.utils._sympy.functions import LShift
return LShift(a, b)
def _sympy_rshift(a, b):
from torch.utils._sympy.functions import RShift
return RShift(a, b)
# Binary methods that have both a `__foo__` and a reflected `__rfoo__` form,
# mapped to the function that builds the sympy result from the two operand
# expressions. add/sub/mul use the plain operators on the expressions.
reflectable_magic_methods = {
    "add": operator.add,
    "sub": operator.sub,
    "mul": operator.mul,
    "mod": _sympy_mod,
    "pow": _sympy_pow,
    "and": _sympy_and,
    "or": _sympy_or,
    "truediv": _sympy_truediv,
    "floordiv": _sympy_floordiv,
    "lshift": _sympy_lshift,
    "rshift": _sympy_rshift,
}
def _floor_ceil_helper(a, fn):
import sympy
if isinstance(a, sympy.Mul):
aa = a.args
if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer:
coef = sympy.Integer(aa[0])
if aa[0] == coef: # structural equality test
return coef * aa[1]
if (
isinstance(a, sympy.Float)
and a == sympy.Integer(a)
or isinstance(a, sympy.Integer)
):
return sympy.Integer(a)
return fn(a)
def _sympy_floor(a):
    """Build ``floor(a)``, using ``_floor_ceil_helper`` shortcuts where possible."""
    from sympy import floor as sympy_floor

    return _floor_ceil_helper(a, sympy_floor)
def _sympy_trunc(a):
from torch.utils._sympy.functions import Trunc
return Trunc(a)
def _sympy_ceil(a):
    """Build ``ceiling(a)``, using ``_floor_ceil_helper`` shortcuts where possible."""
    from sympy import ceiling as sympy_ceiling

    return _floor_ceil_helper(a, sympy_ceiling)
def _sympy_eq(a, b):
import sympy
return sympy.Eq(a, b)
def _sympy_ne(a, b):
import sympy
return sympy.Ne(a, b)
def _sympy_gt(a, b):
import sympy
return sympy.Gt(a, b)
def _sympy_lt(a, b):
import sympy
return sympy.Lt(a, b)
def _sympy_le(a, b):
import sympy
return sympy.Le(a, b)
def _sympy_ge(a, b):
import sympy
return sympy.Ge(a, b)
def _sympy_min(a, b):
import sympy
return sympy.Min(a, b)
def _sympy_max(a, b):
import sympy
return sympy.Max(a, b)
def _sympy_ite(a, t, f):
import sympy
return sympy.Piecewise((t, a), (f, True))
# Handle on this module, so that the per-math-op _sympy_<name> builders can be
# installed with setattr and looked up with getattr below. Deleted afterwards.
current_module = sys.modules[__name__]
def _get_sym_math_fn(name):
def fn(a):
import torch.utils._sympy.functions
return getattr(torch.utils._sympy.functions, f"OpaqueUnaryFn_{name}")(a)
return fn
# Install a module-level _sympy_<name> builder for each math op. Each builder
# is renamed to _sympy_<name>, presumably so that reprs and stack traces show
# the real name instead of the generic closure name "fn".
for name in math_op_names:
    priv_sympy_name = f"_sympy_{name}"
    fn = _get_sym_math_fn(name)
    fn.__qualname__ = fn.__name__ = priv_sympy_name
    setattr(current_module, priv_sympy_name, fn)
# Drop the loop helpers leaked into module scope.
del fn, name, priv_sympy_name  # type: ignore[possibly-undefined]
def _sympy_abs(a):
import sympy
return sympy.Abs(a)
def _sympy_round(number, ndigits=None):
from torch.utils._sympy.functions import Round, RoundDecimal
if ndigits is None:
return Round(number)
else:
return RoundDecimal(number, ndigits)
def _sympy_sym_float(a):
# Cannot use sympy.Float(a) here, coz it expects python literals
# Multiply by 1.0 to cast to float. This is needed when the input
# is a SymInt which has the assumption that it is integer and
# SymPy will otherwise assume that return value cannot be a float.
return a * 1.0
def _sympy_is_integer(a):
import sympy
return sympy.Eq(sympy.floor(a), a)
# The full magic-method table: method name -> function building the sympy
# result expression (the reflectable binary methods, plus unary, comparison
# and min/max/ite builders).
magic_methods = {
    **reflectable_magic_methods,
    "sym_not": operator.invert,
    "pos": operator.pos,
    "eq": _sympy_eq,
    "ne": _sympy_ne,
    "gt": _sympy_gt,
    "lt": _sympy_lt,
    "le": _sympy_le,
    "ge": _sympy_ge,
    "floor": _sympy_floor,
    "trunc": _sympy_trunc,
    "sym_float": _sympy_sym_float,
    "ceil": _sympy_ceil,
    "neg": operator.neg,
    "sym_min": _sympy_min,
    "sym_max": _sympy_max,
    "sym_ite": _sympy_ite,
    "abs": _sympy_abs,
    "round": _sympy_round,
    "is_integer": _sympy_is_integer,
}
# Add the per-math-op _sympy_<name> builders installed on this module above.
for name in math_op_names:
    sym_name = f"sym_{name}"
    magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}")
# Drop the helper names leaked into (or created only for) module setup.
del name, sym_name, math_op_names, current_module  # type: ignore[possibly-undefined]
def sympy_is_contiguous(sizes, strides):
    """Row-major contiguity: the last dimension is innermost, the first outermost."""
    innermost_first = list(reversed(range(len(sizes))))
    return sympy_is_contiguous_generic(sizes, strides, innermost_first)
def sympy_is_contiguous_generic(sizes, strides, dim_order):
    """
    Build a sympy boolean saying whether a tensor with these ``sizes`` and
    ``strides`` is densely packed in ``dim_order`` (innermost dimension first).

    Walking ``dim_order``, every dimension must either have size 1 or a stride
    equal to the product of the sizes already visited. A tensor with any
    zero-sized dimension is always reported contiguous. Returns
    ``sympy.false`` if ``dim_order`` does not list exactly one entry per
    dimension.
    """
    import sympy

    ndim = len(sizes)
    if ndim != len(dim_order):
        return sympy.false

    one = sympy.Integer(1)
    result = sympy.true
    expected_stride = sympy.Integer(1)
    # Contiguous if the strides make sense (or the dim is size 1)
    for dim_idx in dim_order:
        result &= sympy.Eq(sizes[dim_idx], one) | sympy.Eq(
            strides[dim_idx], expected_stride
        )
        expected_stride *= sizes[dim_idx]

    # OR if any size is zero
    for dim_idx in range(ndim):
        result |= sympy.Eq(sizes[dim_idx], sympy.Integer(0))
    return result
# NB: There is a TODO in C++ to allow omitting the batch dim. If that
# happens you will need to refactor this
def sympy_is_channels_last_contiguous_2d(sizes, strides):
    """4-D (NCHW-indexed) channels-last contiguity: C innermost, then W, H, N."""
    nhwc_order = [1, 3, 2, 0]
    return sympy_is_contiguous_generic(sizes, strides, nhwc_order)
def sympy_is_channels_last_contiguous_3d(sizes, strides):
    """5-D (NCDHW-indexed) channels-last contiguity: C innermost, then W, H, D, N."""
    ndhwc_order = [1, 4, 3, 2, 0]
    return sympy_is_contiguous_generic(sizes, strides, ndhwc_order)
def sympy_is_channels_last_strides_generic(sizes, strides, dim_order):
    """
    Build a sympy boolean saying whether ``strides`` look channels-last for
    the dimension order ``dim_order`` (innermost first).

    Every dimension must be non-empty with non-decreasing strides along
    ``dim_order``. A zero channel stride, and the ambiguous layouts described
    inline, fall back to "not channels-last" (i.e. NCHW). Returns
    ``sympy.false`` if ``dim_order`` does not have one entry per dimension.
    """
    import sympy

    ndim = len(sizes)
    if len(dim_order) != ndim:
        return sympy.false

    min_stride = sympy.Integer(0)
    # special case for trivial C dimension. default to NCHW
    result = sympy.true & sympy.Ne(strides[1], 0)

    for dim_idx in dim_order:
        result &= sympy.Ne(sizes[dim_idx], 0) & (strides[dim_idx] >= min_stride)
        # Fallback to NCHW as default layout for ambiguous cases
        # This is the flaw of implicit memory_format from strides.
        # N111 tensor with identical strides for size 1 dimension;
        # Two cases could lead us here:
        # a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1])
        # b. N11W contiguous Tensor sliced on the W-dimension.
        # ([N,1,1,1]@[W,W,W,W])
        if dim_idx == 0:
            result &= sympy.Ne(min_stride, strides[1])
        # This is necessary to:
        # 1. distinguish the memory_format of N1H1;
        # [H, 1, 1, 1] channels_last stride
        # [H, H, 1, 1] contiguous stride
        # 2. permutation of 1C1W:
        # [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3)
        # [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as
        # channels_last
        min_stride = strides[dim_idx] * sympy.Max(sizes[dim_idx], 1)

    return result
def sympy_is_channels_last_strides_2d(sizes, strides):
    """Channels-last stride check for 4-D (NCHW-indexed) tensors."""
    nhwc_order = [1, 3, 2, 0]
    return sympy_is_channels_last_strides_generic(sizes, strides, nhwc_order)
def sympy_is_channels_last_strides_3d(sizes, strides):
    """Channels-last stride check for 5-D (NCDHW-indexed) tensors."""
    ndhwc_order = [1, 4, 3, 2, 0]
    return sympy_is_channels_last_strides_generic(sizes, strides, ndhwc_order)
def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides):
from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator
return IsNonOverlappingAndDenseIndicator(*sizes, *strides)
# Methods that take (sizes, strides) sequences and build a sympy expression
# describing the memory layout of a tensor with that geometry.
sizes_strides_methods = {
    # TODO: These could also be done with indicators, maybe it is better
    # for reasoning to do it that way
    "is_contiguous": sympy_is_contiguous,
    "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d,
    "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d,
    "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d,
    "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d,
    "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator,
}
# Methods with an alternate implementation used when the output hint is known.
# In that case binary_magic_impl returns the result of calling these on the
# wrapped operands, instead of building a sympy expression.
alternate_impl_if_hinted_methods = {
    "sym_min": builtins.min,
    "sym_max": builtins.max,
}
def to_node(self, num):
    """
    Convert ``num`` into a SymNode that can be combined with ``self``.

    - SymInt / SymFloat / SymBool: return the SymNode they wrap.
    - Values whose exact type is bool / int / float (subclasses excluded):
      wrapped as constants via ``self.wrap_bool`` / ``wrap_int`` /
      ``wrap_float``.
    - Anything else: ``NotImplemented``.
    """
    if isinstance(num, SymTypes):
        return num.node

    num_type = type(num)
    # Exact-type identity checks, so a bool is never treated as an int.
    for py_type, wrapper_name in (
        (bool, "wrap_bool"),
        (int, "wrap_int"),
        (float, "wrap_float"),
    ):
        if num_type is py_type:
            return getattr(self, wrapper_name)(num)

    # NotImplemented is important so that Python tries the
    # other magic method
    return NotImplemented
def wrap_node(x):
    """
    Wrap the SymNode ``x`` in its user-facing type.

    A SymNode that holds a literal constant is unwrapped to that plain
    constant. Otherwise ``x`` is wrapped as a SymInt, SymFloat or SymBool
    according to its pytype. Any other pytype raises AssertionError.
    """
    # TODO: let C++ also take advantage of this
    if isinstance(x, SymNode) and x.constant is not None:
        return x.constant
    if x.is_int():
        return SymInt(x)
    if x.is_float():
        return SymFloat(x)
    if x.is_bool():
        return SymBool(x)
    raise AssertionError(f"unrecognized return type {x}")
def method_to_operator(method):
    """
    Return the Python-level callable for the magic-method name ``method``
    (e.g. ``"add"`` -> ``operator.add``). Unknown names raise KeyError.
    """
    operator_fn = METHOD_TO_OPERATOR[method]
    return operator_fn
def _make_node_magic(method, func):
func = lru_cache(256)(func)
if method in magic_methods_on_operator_with_trailing_underscore:
method_attr = f"{method}_"
else:
method_attr = method
def binary_magic_impl(self, other):
from torch.fx.experimental.symbolic_shapes import safe_expand
op = method_to_operator(method)
out_hint = None
if self.hint is not None and other.hint is not None:
out_hint = op(self.hint, other.hint)
alternate_impl = alternate_impl_if_hinted_methods.get(method)
if alternate_impl and out_hint is not None:
return to_node(self, alternate_impl(wrap_node(self), wrap_node(other)))
if sym_function_mode():
return to_node(
self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
)
assert isinstance(other, SymNode)
# TODO: consider constant prop here
try:
out = func(self.expr, other.expr)
except Exception:
log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr)
raise
out = safe_expand(out)
sym_node_log.debug("%s %s %s -> %s", func, self.expr, other.expr, out)
pytype: Type