forked from pytorch/pytorch.github.io
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathguards-overview.html
1270 lines (1073 loc) · 86.4 KB
/
guards-overview.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta name="generator" content="Docutils 0.18.1: http://docutils.sourceforge.net/" />
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Guards Overview — PyTorch master documentation</title>
<link rel="canonical" href="https://pytorch.org/docs/stable/dynamo/guards-overview.html"/>
<link rel="stylesheet" href="../_static/pygments.css" type="text/css" />
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
<link rel="stylesheet" href="../_static/copybutton.css" type="text/css" />
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/katex.min.css" type="text/css" />
<link rel="stylesheet" href="../_static/katex-math.css" type="text/css" />
<link rel="stylesheet" href="../_static/sphinx-dropdown.css" type="text/css" />
<link rel="stylesheet" href="../_static/panels-bootstrap.min.css" type="text/css" />
<link rel="stylesheet" href="../_static/css/jit.css" type="text/css" />
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
<link rel="next" title="Custom Backends" href="custom-backends.html" />
<link rel="prev" title="Getting Started" href="get-started.html" />
<!--
Search engines should not index the master version of documentation.
Stable documentation are built without release == 'master'.
-->
<meta name="robots" content="noindex">
<!-- Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=UA-117752657-2"></script>
<script>
// Google Analytics (gtag.js) bootstrap.
// Reuse window.dataLayer if something (e.g. the async gtag.js script loaded
// above) already created it; otherwise start with an empty command queue.
window.dataLayer = window.dataLayer || [];
// Each gtag(...) call appends its raw argument list to dataLayer.
// NOTE(review): keep pushing the `arguments` object itself; gtag.js is reported
// to ignore plain arrays, so do not refactor this to rest parameters.
function gtag(){dataLayer.push(arguments);}
gtag('js', new Date());
// NOTE(review): "UA-" ids are Universal Analytics properties, which Google has
// retired; confirm whether this should move to a GA4 ("G-") measurement id.
gtag('config', 'UA-117752657-2');
</script>
<!-- End Google Analytics -->
<script src="../_static/js/modernizr.min.js"></script>
<!-- Preload the theme fonts -->
<link rel="preload" href="../_static/fonts/FreightSans/freight-sans-book.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/FreightSans/freight-sans-medium.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/IBMPlexMono/IBMPlexMono-Medium.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/FreightSans/freight-sans-bold.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/FreightSans/freight-sans-medium-italic.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/IBMPlexMono/IBMPlexMono-SemiBold.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<!-- Preload the katex fonts -->
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Math-Italic.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Main-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Main-Bold.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Size1-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Size4-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Size2-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Size3-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/fonts/KaTeX_Caligraphic-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.2/css/all.css" integrity="sha384-vSIIfh2YWi9wW0r9iZe7RJPrKwp6bG+s9QZMoITbCckVJqGCCRhc+ccxNcdpHuYu" crossorigin="anonymous">
</head>
<div class="container-fluid header-holder tutorials-header" id="header-holder">
<div class="container">
<div class="header-container">
<a class="header-logo" href="https://pytorch.org/" aria-label="PyTorch"></a>
<div class="main-menu">
<ul>
<li>
<a href="https://pytorch.org/get-started">Get Started</a>
</li>
<li>
<a href="https://pytorch.org/ecosystem">Ecosystem</a>
</li>
<li>
<a href="https://pytorch.org/mobile">Mobile</a>
</li>
<li>
<a href="https://pytorch.org/blog/">Blog</a>
</li>
<li>
<a href="https://pytorch.org/tutorials">Tutorials</a>
</li>
<li class="active docs-active">
<div id="resourcesDropdownButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="resource-option with-down-orange-arrow">
Docs
</a>
<div class="resources-dropdown-menu">
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/docs/stable/index.html">
<span class="dropdown-title">PyTorch</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/audio/stable/index.html">
<span class="dropdown-title">torchaudio</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/text/stable/index.html">
<span class="dropdown-title">torchtext</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/vision/stable/index.html">
<span class="dropdown-title">torchvision</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/torcharrow">
<span class="dropdown-title">torcharrow</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/data">
<span class="dropdown-title">TorchData</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/torchrec">
<span class="dropdown-title">TorchRec</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/serve/">
<span class="dropdown-title">TorchServe</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/torchx/">
<span class="dropdown-title">TorchX</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/xla">
<span class="dropdown-title">PyTorch on XLA Devices</span>
<p></p>
</a>
</div>
</div>
</li>
<li>
<div id="resourcesDropdownButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="resource-option with-down-arrow">
Resources
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="https://pytorch.org/features">
<span class="dropdown-title">About</span>
<p>Learn about PyTorch’s features and capabilities</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/foundation">
<span class="dropdown-title">PyTorch Foundation</span>
<p>Learn about the PyTorch foundation</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/#community-module">
<span class="dropdown-title">Community</span>
<p>Join the PyTorch developer community to contribute, learn, and get your questions answered.</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/community-stories">
<span class="dropdown-title">Community Stories</span>
<p>Learn how our community solves real, everyday machine learning problems with PyTorch.</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/resources">
<span class="dropdown-title">Developer Resources</span>
<p>Find resources and get questions answered</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/events">
<span class="dropdown-title">Events</span>
<p>Find events, webinars, and podcasts</p>
</a>
<a class="nav-dropdown-item" href="https://discuss.pytorch.org/" target="_blank" rel="noopener">
<span class="dropdown-title">Forums</span>
<p>A place to discuss PyTorch code, issues, install, research</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/hub">
<span class="dropdown-title">Models (Beta)</span>
<p>Discover, publish, and reuse pre-trained models</p>
</a>
</div>
</div>
</li>
<li>
<a href="https://github.com/pytorch/pytorch">GitHub</a>
</li>
</ul>
</div>
<a class="main-menu-open-button" href="#" data-behavior="open-mobile-menu" aria-label="Open main menu"></a>
</div>
</div>
</div>
<body class="pytorch-body">
<div class="table-of-contents-link-wrapper">
<span>Table of Contents</span>
<a href="#" class="toggle-table-of-contents" data-behavior="toggle-table-of-contents" aria-label="Toggle table of contents"></a>
</div>
<nav data-toggle="wy-nav-shift" class="pytorch-left-menu" id="pytorch-left-menu">
<div class="pytorch-side-scroll">
<div class="pytorch-menu pytorch-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<div class="pytorch-left-menu-search">
<div class="version">
<a href='https://pytorch.org/docs/versions.html'>master (1.14.0a0+git876b702 ) ▼</a>
</div>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
<input type="text" name="q" placeholder="Search Docs" aria-label="Search Docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<div>
<a style="color:#F05732" href="https://pytorch.org/docs/stable/dynamo/guards-overview.html">
You are viewing unstable developer preview docs.
View the docs for the latest stable release.
</a>
</div>
<p class="caption" role="heading"><span class="caption-text">Community</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../community/build_ci_governance.html">PyTorch Governance | Build + CI</a></li>
<li class="toctree-l1"><a class="reference internal" href="../community/contribution_guide.html">PyTorch Contribution Guide</a></li>
<li class="toctree-l1"><a class="reference internal" href="../community/design.html">PyTorch Design Philosophy</a></li>
<li class="toctree-l1"><a class="reference internal" href="../community/governance.html">PyTorch Governance | Mechanics</a></li>
<li class="toctree-l1"><a class="reference internal" href="../community/persons_of_interest.html">PyTorch Governance | Maintainers</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Developer Notes</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../notes/amp_examples.html">CUDA Automatic Mixed Precision examples</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notes/autograd.html">Autograd mechanics</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notes/broadcasting.html">Broadcasting semantics</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notes/cpu_threading_torchscript_inference.html">CPU threading and TorchScript inference</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notes/cuda.html">CUDA semantics</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notes/ddp.html">Distributed Data Parallel</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notes/extending.html">Extending PyTorch</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notes/faq.html">Frequently Asked Questions</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notes/gradcheck.html">Gradcheck mechanics</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notes/hip.html">HIP (ROCm) semantics</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notes/large_scale_deployments.html">Features for large-scale deployments</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notes/modules.html">Modules</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notes/mps.html">MPS backend</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notes/multiprocessing.html">Multiprocessing best practices</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notes/numerical_accuracy.html">Numerical accuracy</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notes/randomness.html">Reproducibility</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notes/serialization.html">Serialization semantics</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notes/windows.html">Windows FAQ</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">torch.compile</span></p>
<ul class="current">
<li class="toctree-l1 current"><a class="reference internal" href="index.html">TorchDynamo Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="installation.html">Installing TorchDynamo</a></li>
<li class="toctree-l1"><a class="reference internal" href="get-started.html">Getting Started</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">Guards Overview</a></li>
<li class="toctree-l1"><a class="reference internal" href="custom-backends.html">Custom Backends</a></li>
<li class="toctree-l1"><a class="reference internal" href="deep-dive.html">TorchDynamo Deeper Dive</a></li>
<li class="toctree-l1"><a class="reference internal" href="troubleshooting.html">TorchDynamo Troubleshooting</a></li>
<li class="toctree-l1"><a class="reference internal" href="faq.html">Frequently Asked Questions</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Language Bindings</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../cpp_index.html">C++</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/javadoc/">Javadoc</a></li>
<li class="toctree-l1"><a class="reference internal" href="../deploy.html">torch::deploy</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../torch.html">torch</a></li>
<li class="toctree-l1"><a class="reference internal" href="../nn.html">torch.nn</a></li>
<li class="toctree-l1"><a class="reference internal" href="../nn.functional.html">torch.nn.functional</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tensors.html">torch.Tensor</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tensor_attributes.html">Tensor Attributes</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tensor_view.html">Tensor Views</a></li>
<li class="toctree-l1"><a class="reference internal" href="../amp.html">torch.amp</a></li>
<li class="toctree-l1"><a class="reference internal" href="../autograd.html">torch.autograd</a></li>
<li class="toctree-l1"><a class="reference internal" href="../library.html">torch.library</a></li>
<li class="toctree-l1"><a class="reference internal" href="../cuda.html">torch.cuda</a></li>
<li class="toctree-l1"><a class="reference internal" href="../backends.html">torch.backends</a></li>
<li class="toctree-l1"><a class="reference internal" href="../distributed.html">torch.distributed</a></li>
<li class="toctree-l1"><a class="reference internal" href="../distributed.algorithms.join.html">torch.distributed.algorithms.join</a></li>
<li class="toctree-l1"><a class="reference internal" href="../distributed.elastic.html">torch.distributed.elastic</a></li>
<li class="toctree-l1"><a class="reference internal" href="../fsdp.html">torch.distributed.fsdp</a></li>
<li class="toctree-l1"><a class="reference internal" href="../distributed.optim.html">torch.distributed.optim</a></li>
<li class="toctree-l1"><a class="reference internal" href="../distributed.tensor.parallel.html">torch.distributed.tensor.parallel</a></li>
<li class="toctree-l1"><a class="reference internal" href="../distributed.checkpoint.html">torch.distributed.checkpoint</a></li>
<li class="toctree-l1"><a class="reference internal" href="../distributions.html">torch.distributions</a></li>
<li class="toctree-l1"><a class="reference internal" href="../_dynamo.html">torch._dynamo</a></li>
<li class="toctree-l1"><a class="reference internal" href="../fft.html">torch.fft</a></li>
<li class="toctree-l1"><a class="reference internal" href="../futures.html">torch.futures</a></li>
<li class="toctree-l1"><a class="reference internal" href="../fx.html">torch.fx</a></li>
<li class="toctree-l1"><a class="reference internal" href="../hub.html">torch.hub</a></li>
<li class="toctree-l1"><a class="reference internal" href="../jit.html">torch.jit</a></li>
<li class="toctree-l1"><a class="reference internal" href="../linalg.html">torch.linalg</a></li>
<li class="toctree-l1"><a class="reference internal" href="../monitor.html">torch.monitor</a></li>
<li class="toctree-l1"><a class="reference internal" href="../signal.html">torch.signal</a></li>
<li class="toctree-l1"><a class="reference internal" href="../special.html">torch.special</a></li>
<li class="toctree-l1"><a class="reference internal" href="../torch.overrides.html">torch.overrides</a></li>
<li class="toctree-l1"><a class="reference internal" href="../package.html">torch.package</a></li>
<li class="toctree-l1"><a class="reference internal" href="../profiler.html">torch.profiler</a></li>
<li class="toctree-l1"><a class="reference internal" href="../nn.init.html">torch.nn.init</a></li>
<li class="toctree-l1"><a class="reference internal" href="../onnx.html">torch.onnx</a></li>
<li class="toctree-l1"><a class="reference internal" href="../onnx_diagnostics.html">torch.onnx diagnostics</a></li>
<li class="toctree-l1"><a class="reference internal" href="../optim.html">torch.optim</a></li>
<li class="toctree-l1"><a class="reference internal" href="../complex_numbers.html">Complex Numbers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../ddp_comm_hooks.html">DDP Communication Hooks</a></li>
<li class="toctree-l1"><a class="reference internal" href="../pipeline.html">Pipeline Parallelism</a></li>
<li class="toctree-l1"><a class="reference internal" href="../quantization.html">Quantization</a></li>
<li class="toctree-l1"><a class="reference internal" href="../rpc.html">Distributed RPC Framework</a></li>
<li class="toctree-l1"><a class="reference internal" href="../random.html">torch.random</a></li>
<li class="toctree-l1"><a class="reference internal" href="../masked.html">torch.masked</a></li>
<li class="toctree-l1"><a class="reference internal" href="../nested.html">torch.nested</a></li>
<li class="toctree-l1"><a class="reference internal" href="../sparse.html">torch.sparse</a></li>
<li class="toctree-l1"><a class="reference internal" href="../storage.html">torch.Storage</a></li>
<li class="toctree-l1"><a class="reference internal" href="../testing.html">torch.testing</a></li>
<li class="toctree-l1"><a class="reference internal" href="../benchmark_utils.html">torch.utils.benchmark</a></li>
<li class="toctree-l1"><a class="reference internal" href="../bottleneck.html">torch.utils.bottleneck</a></li>
<li class="toctree-l1"><a class="reference internal" href="../checkpoint.html">torch.utils.checkpoint</a></li>
<li class="toctree-l1"><a class="reference internal" href="../cpp_extension.html">torch.utils.cpp_extension</a></li>
<li class="toctree-l1"><a class="reference internal" href="../data.html">torch.utils.data</a></li>
<li class="toctree-l1"><a class="reference internal" href="../jit_utils.html">torch.utils.jit</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dlpack.html">torch.utils.dlpack</a></li>
<li class="toctree-l1"><a class="reference internal" href="../mobile_optimizer.html">torch.utils.mobile_optimizer</a></li>
<li class="toctree-l1"><a class="reference internal" href="../model_zoo.html">torch.utils.model_zoo</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tensorboard.html">torch.utils.tensorboard</a></li>
<li class="toctree-l1"><a class="reference internal" href="../type_info.html">Type Info</a></li>
<li class="toctree-l1"><a class="reference internal" href="../named_tensor.html">Named Tensors</a></li>
<li class="toctree-l1"><a class="reference internal" href="../name_inference.html">Named Tensors operator coverage</a></li>
<li class="toctree-l1"><a class="reference internal" href="../config_mod.html">torch.__config__</a></li>
</ul>
<p class="caption" role="heading"><span class="caption-text">Libraries</span></p>
<ul>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/audio/stable">torchaudio</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/data">TorchData</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/torchrec">TorchRec</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/serve">TorchServe</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/text/stable">torchtext</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/vision/stable">torchvision</a></li>
<li class="toctree-l1"><a class="reference external" href="http://pytorch.org/xla/">PyTorch on XLA Devices</a></li>
</ul>
</div>
</div>
</nav>
<div class="pytorch-container">
<div class="pytorch-page-level-bar" id="pytorch-page-level-bar">
<div class="pytorch-breadcrumbs-wrapper">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="pytorch-breadcrumbs">
<li>
<a href="../index.html">
Docs
</a> >
</li>
<li><a href="index.html">TorchDynamo Overview</a> ></li>
<li>Guards Overview</li>
<li class="pytorch-breadcrumbs-aside">
<a href="../_sources/dynamo/guards-overview.rst.txt" rel="nofollow"><img src="../_static/images/view-page-source-icon.svg" alt="View page source"></a>
</li>
</ul>
</div>
</div>
<div class="pytorch-shortcuts-wrapper" id="pytorch-shortcuts-wrapper">
Shortcuts
</div>
</div>
<section data-toggle="wy-nav-shift" id="pytorch-content-wrap" class="pytorch-content-wrap">
<div class="pytorch-content-left">
<div class="rst-content">
<div role="main" class="main-content" itemscope="itemscope" itemtype="http://schema.org/Article">
<article itemprop="articleBody" id="pytorch-article" class="pytorch-article">
<section id="guards-overview">
<h1>Guards Overview<a class="headerlink" href="#guards-overview" title="Permalink to this heading">¶</a></h1>
<p>From a UX perspective, TorchDynamo is very easy to use. The user applies
<code class="docutils literal notranslate"><span class="pre">torchdynamo.optimize</span></code> as a decorator:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@torchdynamo</span><span class="o">.</span><span class="n">optimize</span><span class="p">(</span><span class="n">my_compiler</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">fn_foo</span><span class="p">(</span><span class="n">bar</span><span class="p">):</span>
    <span class="o">...</span>
</pre></div>
</div>
<p>Where a complete example looks like this:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span>
<span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torchdynamo</span>

<span class="k">def</span> <span class="nf">my_compiler</span><span class="p">(</span><span class="n">gm</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">fx</span><span class="o">.</span><span class="n">GraphModule</span><span class="p">,</span> <span class="n">example_inputs</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span>
    <span class="nb">print</span><span class="p">(</span><span class="s2">"my_compiler() called with FX graph:"</span><span class="p">)</span>
    <span class="n">gm</span><span class="o">.</span><span class="n">graph</span><span class="o">.</span><span class="n">print_tabular</span><span class="p">()</span>
    <span class="k">return</span> <span class="n">gm</span><span class="o">.</span><span class="n">forward</span>  <span class="c1"># return a python callable</span>

<span class="nd">@torchdynamo</span><span class="o">.</span><span class="n">optimize</span><span class="p">(</span><span class="n">my_compiler</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">toy_example</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">a</span> <span class="o">/</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">a</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">b</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">:</span>
        <span class="n">b</span> <span class="o">=</span> <span class="n">b</span> <span class="o">*</span> <span class="o">-</span><span class="mi">1</span>
    <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">b</span>

<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span>
    <span class="n">toy_example</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">randn</span><span class="p">(</span><span class="mi">10</span><span class="p">))</span>
</pre></div>
</div>
<p>This allows TorchDynamo to capture the interpreted Python frames, grab
any and all relevant information, and speed things up wherever it can.
The speedup comes from a few places and can depend heavily on the backend
provided (<code class="docutils literal notranslate"><span class="pre">my_compiler</span></code> in the example above), but the speedup
that matters for this section is <strong>caching</strong>. Caching itself is not
a direct speedup, but a critical enabler that prevents
recompilation. Capturing frames with TorchDynamo adds overhead, and caching
is what lets us recover it: it keeps performance neutral while enabling
backends — the true source of our speedups.</p>
<p>Even with a pass-through, no-op backend:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">my_compiler</span><span class="p">(</span><span class="n">gm</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">fx</span><span class="o">.</span><span class="n">GraphModule</span><span class="p">,</span> <span class="n">example_inputs</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">]):</span>
    <span class="k">return</span> <span class="n">gm</span><span class="o">.</span><span class="n">forward</span>
</pre></div>
</div>
<p>We can see TorchDynamo speeding up Python execution even on
regular Python, not just PyTorch.</p>
<section id="caching-and-guards-overview">
<h2>Caching and Guards Overview<a class="headerlink" href="#caching-and-guards-overview" title="Permalink to this heading">¶</a></h2>
<p>TorchDynamo operates through caching transformed (by TorchDynamo) user
bytecode. When TorchDynamo receives a frame for evaluation, it checks if the
<strong>objects referenced in the frame have changed</strong> in certain ways, and if
not, TorchDynamo reads the previously transformed user bytecode to evaluate it.
In this section, we will focus on how we can identify whether or not the
<strong>objects referenced in the frame have changed</strong>. This is a critical
piece of functionality in TorchDynamo, because it drives the entire
invalidation lifecycle. This functionality is called <strong>guards</strong>.</p>
<p>At a very high level, the flow can be summarized like this:</p>
<ol class="arabic simple">
<li><p>TorchDynamo receives a Python frame.</p></li>
<li><p>It converts the frame (1) by passing it through instruction
translation.</p></li>
<li><p>For the objects captured in (2), TorchDynamo creates tracking objects that
are:</p>
<ul class="simple">
<li><p>tracked on an output graph, which is an internal specialization
of a <cite>torch.fx.Tracer</cite></p></li>
<li><p>guards</p></li>
</ul>
</li>
<li><p>TorchDynamo processes the guard objects created in (3), turning them into a
generated Python function, <cite>check_fn</cite>, associated with a piece of code.</p></li>
<li><p>The <cite>check_fn</cite> is evaluated whenever we encounter this code a
subsequent time - if a <cite>check_fn</cite> passes and evaluates to <cite>True</cite>, TorchDynamo
identifies the code in the cache and the code encountered here as the same, and
the cached code can be safely used. If it fails and evaluates to <cite>False</cite>, TorchDynamo
identifies the code in the cache as not valid, and it can be thrown out in
favor of a new entry, through recompilation or a graph break.</p></li>
</ol>
</section>
<section id="python-frame-evaluation-and-pep-523">
<h2>Python Frame Evaluation and PEP 523<a class="headerlink" href="#python-frame-evaluation-and-pep-523" title="Permalink to this heading">¶</a></h2>
<p>The functionality of TorchDynamo is based on
<a class="reference external" href="https://peps.python.org/pep-0523/">PEP 523</a>.</p>
<p>TorchDynamo installs a frame evaluation function on Python by using
<cite>_PyInterpreterState_SetEvalFrameFunc</cite>. This gives TorchDynamo a hook through
which Python hands control back to us during evaluation.</p>
<p>The function we have installed is <code class="docutils literal notranslate"><span class="pre">convert_frame</span></code> or
<code class="docutils literal notranslate"><span class="pre">convert_frame_assert</span></code> in the <code class="docutils literal notranslate"><span class="pre">nopython=True</span></code> case. Glossing
over that nuance for now, let’s take a look at <code class="docutils literal notranslate"><span class="pre">convert_frame_assert</span></code>,
as <code class="docutils literal notranslate"><span class="pre">convert_frame</span></code> proxies to it.</p>
<p>We can find it on <a class="reference external" href="https://github.com/pytorch/torchdynamo/blob/main/torchdynamo/convert_frame.py#L200">line 200 of convert_frame.py</a>,
with a signature as follows:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">convert_frame_assert</span><span class="p">(</span><span class="n">compiler_fn</span><span class="p">:</span> <span class="n">Callable</span><span class="p">,</span> <span class="n">one_graph</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span>
</pre></div>
</div>
<p>This function wraps the entry point of where Python invokes TorchDynamo
with a frame:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">_convert_frame_assert</span><span class="p">(</span><span class="n">frame</span><span class="p">:</span> <span class="n">types</span><span class="o">.</span><span class="n">FrameType</span><span class="p">,</span> <span class="n">cache_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span>
</pre></div>
</div>
<p>Here is what this function does:</p>
<ol class="arabic">
<li><p>Checks if it has seen this <code class="docutils literal notranslate"><span class="pre">code</span></code> (see <code class="docutils literal notranslate"><span class="pre">f_code</span></code> in the <a class="reference external" href="https://docs.python.org/3/library/inspect.html">inspect module documentation</a>) before and exits
early if it did.</p></li>
<li><p>Checks if the code is an unsupported case.</p></li>
<li><p>Checks if the <code class="docutils literal notranslate"><span class="pre">cache_size</span></code> (second arg above) crosses the limit
defined in the config, <code class="docutils literal notranslate"><span class="pre">cache_size_limit</span></code>. If it has, the function
drops the frame and logs warnings. This helps to avoid constant
recompilation of a frame as it generally means that the frame is hot
in an unexpected way and caching it produces needless overhead,
as it is likely to get evicted the next time it is encountered.</p></li>
<li><p>Passes the frame, alongside a function that creates an
<code class="docutils literal notranslate"><span class="pre">InstructionTranslator</span></code> through bytecode
transformation, via <code class="docutils literal notranslate"><span class="pre">transform_code_object</span></code>. A few crucial things
happen under the hood here:</p>
<ol class="arabic">
<li><p>New code is produced through <code class="docutils literal notranslate"><span class="pre">transform_code_object</span></code>.</p></li>
<li><p>An FX tracer named <code class="docutils literal notranslate"><span class="pre">output</span></code> is produced through
<code class="docutils literal notranslate"><span class="pre">InstructionTranslator</span></code>.</p>
<p>This can be a bit confusing,
as <code class="docutils literal notranslate"><span class="pre">InstructionTranslator</span></code> is not an <cite>fx</cite> tracer, but it’s stored
in a variable named tracer, and its output <strong>is</strong> an <cite>fx</cite> tracer.</p>
</li>
<li><p>The function produces guards and stores them on <code class="docutils literal notranslate"><span class="pre">output</span></code> above.</p></li>
<li><p>The function produces <code class="docutils literal notranslate"><span class="pre">output_instructions</span></code> and stores them on
<code class="docutils literal notranslate"><span class="pre">output</span></code> above.</p></li>
<li><p>The function maps the newly produced transformed code to the initial code it
read off the frame. This mapping is worth remembering; we will
refer to it much later on, where we cover guard failures.</p></li>
</ol>
</li>
<li><p>Using the transformed code from 4.1 and the guards from 4.3,
the function produces a <cite>GuardedCode</cite>.</p></li>
</ol>
<p>Now that we have learned about frame evaluation, let’s review
<code class="docutils literal notranslate"><span class="pre">InstructionTranslator</span></code>, and see how it turns the frame we handed
it into TorchDynamo internal types.</p>
</section>
<section id="instructiontranslator">
<h2>InstructionTranslator<a class="headerlink" href="#instructiontranslator" title="Permalink to this heading">¶</a></h2>
<p><cite>InstructionTranslator</cite> does a lot! We won’t cover the details of
everything it does, but most importantly for this document, it produces
<code class="docutils literal notranslate"><span class="pre">symbolic_locals</span></code>, which maintains a mapping from the
frame’s <code class="docutils literal notranslate"><span class="pre">f_locals</span></code> to TorchDynamo internal Variable objects (more on these
in a moment). <code class="docutils literal notranslate"><span class="pre">symbolic_locals</span></code> is filled by traversing the frame’s
locals:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="bp">self</span><span class="o">.</span><span class="n">symbolic_locals</span> <span class="o">=</span> <span class="n">collections</span><span class="o">.</span><span class="n">OrderedDict</span><span class="p">(</span>
    <span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">VariableBuilder</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">LocalSource</span><span class="p">(</span><span class="n">k</span><span class="p">))(</span><span class="n">f_locals</span><span class="p">[</span><span class="n">k</span><span class="p">]))</span>
    <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">vars</span>
    <span class="k">if</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">f_locals</span>
<span class="p">)</span>
</pre></div>
</div>
<p>The important component here is the invocation of a call
into <code class="docutils literal notranslate"><span class="pre">VariableBuilder</span></code>. <code class="docutils literal notranslate"><span class="pre">VariableBuilder</span></code>’s call implementation
proxies into a function called <code class="docutils literal notranslate"><span class="pre">_wrap</span></code>, which in turn both constructs
instances of <code class="docutils literal notranslate"><span class="pre">VariableTracker</span></code> and calls <code class="docutils literal notranslate"><span class="pre">make_guards</span></code> on them. More
on that later.</p>
<p>This mapping, in turn, is critical as each Variable has associated
guards, which are then passed to <code class="docutils literal notranslate"><span class="pre">self.output</span></code>, the instance of
<code class="docutils literal notranslate"><span class="pre">OutputGraph</span></code>, an fx tracer, mentioned in 4.2 of the section above. If
you recall, this <code class="docutils literal notranslate"><span class="pre">OutputGraph</span></code>, stored in a variable called <code class="docutils literal notranslate"><span class="pre">output</span></code>,
is where our guards are stored before being passed on to become
<code class="docutils literal notranslate"><span class="pre">GuardedCode</span></code>.</p>
<p>How does <code class="docutils literal notranslate"><span class="pre">InstructionTranslator</span></code> do this? At the heart of it, there is
a loop that is pumped, which drives a function <code class="docutils literal notranslate"><span class="pre">step</span></code>.</p>
<p><code class="docutils literal notranslate"><span class="pre">step</span></code> is just that - a single processing step, taking exactly one
instruction and doing <em>something</em> with it.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>These are real instructions processed by TorchDynamo’s
<code class="docutils literal notranslate"><span class="pre">transform_code_object</span></code>, and it is pretty cool.</p>
</div>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>This section purposely skips the details of
<a class="reference external" href="https://docs.python.org/3/library/dis.html">dis.get_instructions</a>.</p>
</div>
<p>For the example above, here is a snippet of what a few
<code class="docutils literal notranslate"><span class="pre">Instruction</span></code> objects may look like:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">Instruction</span><span class="p">(</span><span class="n">opcode</span><span class="o">=</span><span class="mi">124</span><span class="p">,</span> <span class="n">opname</span><span class="o">=</span><span class="s1">'LOAD_FAST'</span><span class="p">,</span> <span class="n">arg</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">argval</span><span class="o">=</span><span class="s1">'b'</span><span class="p">,</span> <span class="n">offset</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">starts_line</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">is_jump_target</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">target</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">Instruction</span><span class="p">(</span><span class="n">opcode</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">opname</span><span class="o">=</span><span class="s1">'LOAD_CONST'</span><span class="p">,</span> <span class="n">arg</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">argval</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">offset</span><span class="o">=</span><span class="mi">34</span><span class="p">,</span> <span class="n">starts_line</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">is_jump_target</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">target</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
<span class="n">Instruction</span><span class="p">(</span><span class="n">opcode</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">opname</span><span class="o">=</span><span class="s1">'BINARY_MULTIPLY'</span><span class="p">,</span> <span class="n">arg</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">argval</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">offset</span><span class="o">=</span><span class="mi">36</span><span class="p">,</span> <span class="n">starts_line</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">is_jump_target</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">target</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span>
</pre></div>
</div>
<p>This is the core functionality of this function. Take a look at the <code class="docutils literal notranslate"><span class="pre">opname</span></code>,
and then take a look at this little snippet from inside <code class="docutils literal notranslate"><span class="pre">step</span></code>:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">if</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inst</span><span class="o">.</span><span class="n">opname</span><span class="p">):</span>
    <span class="n">unimplemented</span><span class="p">(</span><span class="sa">f</span><span class="s2">"missing: </span><span class="si">{</span><span class="n">inst</span><span class="o">.</span><span class="n">opname</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
<span class="nb">getattr</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inst</span><span class="o">.</span><span class="n">opname</span><span class="p">)(</span><span class="n">inst</span><span class="p">)</span>
</pre></div>
</div>
<p>As we can see, the function checks if the current class, the
<code class="docutils literal notranslate"><span class="pre">InstructionTranslator</span></code>, has an attribute set matching the operator name
(for example, <code class="docutils literal notranslate"><span class="pre">LOAD_CONST</span></code>). If it does, the function invokes it, passing the
whole instruction object in. If it does not, the function drops the frame as
unimplemented.</p>
<p>For the <code class="docutils literal notranslate"><span class="pre">LOAD_CONST</span></code> example, we can see that we do indeed support it,
with a relatively straightforward definition:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">LOAD_CONST</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inst</span><span class="p">):</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">push</span><span class="p">(</span><span class="n">ConstantVariable</span><span class="p">(</span><span class="n">value</span><span class="o">=</span><span class="n">inst</span><span class="o">.</span><span class="n">argval</span><span class="p">))</span>
</pre></div>
</div>
<p>We can see that this function creates a new instance of the class
<code class="docutils literal notranslate"><span class="pre">ConstantVariable</span></code>, with a value, in our example case, -1, and then
pushes it onto the stack.</p>
<p>There are dozens of such methods - see <code class="docutils literal notranslate"><span class="pre">symbolic_convert.py</span></code> for all of
them. Generally, we implement as many matching methods to Python
bytecode instructions as possible.</p>
<p>Across both the logic downstream of <code class="docutils literal notranslate"><span class="pre">step</span></code> and the logic from invoking
<code class="docutils literal notranslate"><span class="pre">VariableBuilder</span></code> - we now have a lot of <code class="docutils literal notranslate"><span class="pre">VariableTracker</span></code>s and of
course, we’ve spoken about creating guards quite a bit. Let’s dig into
what Variables are, and get a little closer to understanding guards.</p>
</section>
<section id="variables">
<h2>Variables<a class="headerlink" href="#variables" title="Permalink to this heading">¶</a></h2>
<p>A <code class="docutils literal notranslate"><span class="pre">ConstantVariable</span></code> is an instance of <code class="docutils literal notranslate"><span class="pre">VariableTracker</span></code>.
<code class="docutils literal notranslate"><span class="pre">VariableTracker</span></code> represents a tracked Python local or stack value.</p>
<p>When it comes to representing an object inside TorchDynamo, a
<code class="docutils literal notranslate"><span class="pre">VariableTracker</span></code> does exactly what it says - it tracks a given variable.
It is an extremely flexible class, but there are a few points to keep in
mind:</p>
<ul class="simple">
<li><p>It manages the <code class="docutils literal notranslate"><span class="pre">guard</span></code> relationship around the underlying object
through:</p>
<ul>
<li><p><code class="docutils literal notranslate"><span class="pre">make_guard</span></code></p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">replace_guards</span></code></p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">add_guard(s)</span></code></p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">propagate</span></code> - <code class="docutils literal notranslate"><span class="pre">propagate(*vars:</span> <span class="pre">List[List["VariableTracker"]])</span></code> -
Perhaps the most important of all, in that it combines guards from
all the provided <code class="docutils literal notranslate"><span class="pre">VariableTracker</span></code> instances passed in. It visits
the guards and combines the guards from these onto itself.</p></li>
</ul>
</li>
<li><p>It acts as a proxy on behalf of the underlying object, implementing
methods for the rest of TorchDynamo to get information about the
tracked object:</p>
<ul>
<li><p><code class="docutils literal notranslate"><span class="pre">call_method</span></code></p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">call_function</span></code></p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">python_type</span></code></p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">as_proxy</span></code></p></li>
<li><p><code class="docutils literal notranslate"><span class="pre">is/as_python_proxy</span></code></p></li>
</ul>
</li>
<li><p>It stores the variable <code class="docutils literal notranslate"><span class="pre">source</span></code> of type <code class="docutils literal notranslate"><span class="pre">Source</span></code>, from
<code class="docutils literal notranslate"><span class="pre">torchdynamo/source.py</span></code>. This source type is a relatively
self-contained class that helps us organize and keep track of where the original
source came from, and helps provide convenience methods for things
like getting the name, and importantly for us, producing guards.</p></li>
</ul>
<p>And this class (<code class="docutils literal notranslate"><span class="pre">VariableTracker</span></code>) is built around subclassing,
somewhere between a full Abstract Base Class and a fully fleshed-out class
- it leaves many methods raising <code class="docutils literal notranslate"><span class="pre">NotImplementedError</span></code> - with reliance on
subclasses. See <code class="docutils literal notranslate"><span class="pre">torchdynamo/variables/</span></code> for the subclasses that fulfill these
contracts and implement custom behaviors.</p>
<p>Knowing what we know now, we can look at an example of how an instruction
from <code class="docutils literal notranslate"><span class="pre">dis</span></code>, <code class="docutils literal notranslate"><span class="pre">BUILD_TUPLE</span></code>, is handled:</p>
<blockquote>
<div><p><code class="docutils literal notranslate"><span class="pre">BUILD_TUPLE(count)</span></code> Creates a tuple consuming count items from the
stack, and pushes the resulting tuple onto the stack.</p>
</div></blockquote>
<p>In our case, our signature will be a <em>little</em> different due to the way
we create <code class="docutils literal notranslate"><span class="pre">Instruction</span></code> objects, but the gist of it will be the same.
Instead of passing in <code class="docutils literal notranslate"><span class="pre">count</span></code>, we pass in an object with a little
extra bookkeeping, and of course, we deal with turning regular old
Python objects into TorchDynamo notions:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">BUILD_TUPLE</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inst</span><span class="p">):</span>
    <span class="n">items</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">popn</span><span class="p">(</span><span class="n">inst</span><span class="o">.</span><span class="n">argval</span><span class="p">)</span>
    <span class="n">options</span> <span class="o">=</span> <span class="n">VariableTracker</span><span class="o">.</span><span class="n">propagate</span><span class="p">(</span><span class="n">items</span><span class="p">)</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">push</span><span class="p">(</span><span class="n">TupleVariable</span><span class="p">(</span><span class="n">items</span><span class="p">,</span> <span class="o">**</span><span class="n">options</span><span class="p">))</span>
</pre></div>
</div>
<p>Here is what this code does:</p>
<ol class="arabic simple">
<li><p>The function reads <code class="docutils literal notranslate"><span class="pre">argval</span></code>, which in this case, is
analogous to <code class="docutils literal notranslate"><span class="pre">count</span></code> in the documentation for the equivalent instruction.</p></li>
<li><p>The function pops the items off the stack with <code class="docutils literal notranslate"><span class="pre">popn</span></code>. In this case, the signature is
<code class="docutils literal notranslate"><span class="pre">def</span>  <span class="pre">popn(self,</span> <span class="pre">n:</span> <span class="pre">int)</span> <span class="pre">-></span> <span class="pre">List[TensorVariable]:</span></code>. This hints at an
underlying contract - we are returning <code class="docutils literal notranslate"><span class="pre">TensorVariables</span></code>. If we
take a closer look at <code class="docutils literal notranslate"><span class="pre">symbolic_convert.py</span></code> and
<code class="docutils literal notranslate"><span class="pre">InstructionTranslatorBase</span></code>/<code class="docutils literal notranslate"><span class="pre">InstructionTranslator</span></code>, we see that
the only things pushed onto and popped from our stack are
<code class="docutils literal notranslate"><span class="pre">VariableTracker</span></code>s.</p></li>
</ol>
<ol class="arabic simple" start="3">
<li><p>The function calls <code class="docutils literal notranslate"><span class="pre">VariableTracker.propagate</span></code>. This
takes the guards from every single item popped off the stack in 2,
recursively traverses them, and combines all the guards into
<code class="docutils literal notranslate"><span class="pre">options</span></code>, which is returned as <code class="docutils literal notranslate"><span class="pre">{"guards":</span> <span class="pre">guards}</span></code>.</p></li>
<li><p>The function then makes a new instance of a <code class="docutils literal notranslate"><span class="pre">VariableTracker</span></code>,
<code class="docutils literal notranslate"><span class="pre">TupleVariable</span></code>, out of the <code class="docutils literal notranslate"><span class="pre">items</span></code> and <code class="docutils literal notranslate"><span class="pre">options</span></code>. This then
allows us to install all the appropriate guards from the <code class="docutils literal notranslate"><span class="pre">items</span></code>
that make up the new <code class="docutils literal notranslate"><span class="pre">TupleVariable</span></code>.</p></li>
</ol>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Where did the first guards come from? Propagation
is a good technique, but we need something created before it can be
propagated. <code class="docutils literal notranslate"><span class="pre">VariableBuilder</span></code> calls
<code class="docutils literal notranslate"><span class="pre">make_guards</span></code> as it creates <code class="docutils literal notranslate"><span class="pre">VariableTracker</span></code> instances, from
<code class="docutils literal notranslate"><span class="pre">f_locals</span></code>. This in turn calls into the <code class="docutils literal notranslate"><span class="pre">source</span></code>, to have it create
guards.</p>
</div>
<p>After all this, bytecode translation is done and we are one step closer
to producing <code class="docutils literal notranslate"><span class="pre">GuardedCode</span></code>. We now understand how locals become
<code class="docutils literal notranslate"><span class="pre">VariableTracker</span></code>s, how instructions are handled, and where guards
are created. Before we can see how code and
guards are combined into a GuardedCode object, we need to dig a little
bit into those <code class="docutils literal notranslate"><span class="pre">make_guard</span></code> and <code class="docutils literal notranslate"><span class="pre">source.make_guard</span></code> calls above. We
can then understand what was going on when we made guards
alongside, and on, <code class="docutils literal notranslate"><span class="pre">VariableTracker</span></code> instances.</p>
</section>
<section id="making-guards">
<h2>Making Guards<a class="headerlink" href="#making-guards" title="Permalink to this heading">¶</a></h2>
<p>Guards are just Python objects, of the class <code class="docutils literal notranslate"><span class="pre">Guard</span></code>. Let’s look at them
in more detail.</p>
<p>Looking at the definition of the dataclass (and therefore, its constructor
signature), we see that it has a name, a source, and a create function.</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="nd">@dataclasses</span><span class="o">.</span><span class="n">dataclass</span>
<span class="k">class</span> <span class="nc">Guard</span><span class="p">:</span>
    <span class="n">name</span><span class="p">:</span> <span class="nb">str</span>
    <span class="n">source</span><span class="p">:</span> <span class="n">GuardSource</span>
    <span class="n">create_fn</span><span class="p">:</span> <span class="n">Callable</span>
</pre></div>
</div>
<p>The name should be the name of the variable.</p>
<p>The source here is an enum indicating what <em>kind</em> of source the guard
belongs to.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>Not to be confused with <code class="docutils literal notranslate"><span class="pre">Source</span></code> and the other types
in <code class="docutils literal notranslate"><span class="pre">source.py</span></code>, as stored on <code class="docutils literal notranslate"><span class="pre">VariableTracker</span></code>.</p>
</div>
<p><code class="docutils literal notranslate"><span class="pre">create_fn</span></code> provides the main functionality to transition from a simple
dataclass to actually producing valid Python code to be invoked for
knowing whether or not things have changed in between invocations, and
whether we can safely read from the code cache or not.</p>
<p>The most common code paths for getting an instance of a guard are
through <code class="docutils literal notranslate"><span class="pre">make_guards</span></code> on <code class="docutils literal notranslate"><span class="pre">VariableTracker</span></code>.
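<code class="docutils literal notranslate"><span class="pre">make_guards</span></code> → <code class="docutils literal notranslate"><span class="pre">source.make_guard</span></code> → <code class="docutils literal notranslate"><span class="pre">return</span> <span class="pre">Guard(self.name(),</span> <span class="pre">self.guard_source(),</span> <span class="pre">fn)</span></code></p>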
<p>Or, in a concrete example:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">...</span>
<span class="k">elif</span> <span class="n">istype</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="nb">range</span><span class="p">):</span>
    <span class="n">guards</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">make_guards</span><span class="p">(</span><span class="n">GuardBuilder</span><span class="o">.</span><span class="n">EQUALS_MATCH</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">RangeVariable</span><span class="p">(</span><span class="n">value</span><span class="o">=</span><span class="n">value</span><span class="p">,</span> <span class="n">guards</span><span class="o">=</span><span class="n">guards</span><span class="p">)</span>
</pre></div>
</div>
<p>Since <code class="docutils literal notranslate"><span class="pre">source</span></code> was set at the construction time of this
<code class="docutils literal notranslate"><span class="pre">VariableTracker</span></code>, all that was needed here was to provide the <code class="docutils literal notranslate"><span class="pre">fn</span></code>,
<code class="docutils literal notranslate"><span class="pre">GuardBuilder.EQUALS_MATCH</span></code> to the <code class="docutils literal notranslate"><span class="pre">create_fn</span></code> field.</p>
<p>This <code class="docutils literal notranslate"><span class="pre">create_fn</span></code> must be a method on <code class="docutils literal notranslate"><span class="pre">GuardBuilder</span></code>. The reason for
this becomes apparent in our next step. Once we have all the guards
created for a frame, we move on to <code class="docutils literal notranslate"><span class="pre">CheckFunctionManager</span></code> and
<code class="docutils literal notranslate"><span class="pre">compile_check_fn</span></code>.</p>
<p>Before the <code class="docutils literal notranslate"><span class="pre">convert_frame</span></code> function can produce a <code class="docutils literal notranslate"><span class="pre">GuardedCode</span></code>,
it needs to run the <code class="docutils literal notranslate"><span class="pre">CheckFunctionManager</span></code>, with all the guards, to
produce a <code class="docutils literal notranslate"><span class="pre">check_fn</span></code> which, in turn, is passed alongside
the code into <code class="docutils literal notranslate"><span class="pre">GuardedCode</span></code>. This is the same <code class="docutils literal notranslate"><span class="pre">check_fn</span></code> that we store in our
cache entry, and the same one we run to know whether or not to retrieve
the code stored alongside. For reference, here is that code:</p>
<div class="highlight-cpp notranslate"><div class="highlight"><pre><span></span><span class="k">static</span><span class="w"> </span><span class="n">CacheEntry</span><span class="w"> </span><span class="o">*</span><span class="nf">create_cache_entry</span><span class="p">(</span><span class="n">CacheEntry</span><span class="w"> </span><span class="o">*</span><span class="n">next</span><span class="p">,</span><span class="w"></span>
<span class="w"> </span><span class="n">PyObject</span><span class="w"> </span><span class="o">*</span><span class="n">guarded_code</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
<span class="w"> </span><span class="n">CacheEntry</span><span class="w"> </span><span class="o">*</span><span class="n">e</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="p">(</span><span class="n">CacheEntry</span><span class="w"> </span><span class="o">*</span><span class="p">)</span><span class="n">malloc</span><span class="p">(</span><span class="k">sizeof</span><span class="p">(</span><span class="n">CacheEntry</span><span class="p">));</span><span class="w"></span>
<span class="w"> </span><span class="n">DEBUG_NULL_CHECK</span><span class="p">(</span><span class="n">e</span><span class="p">);</span><span class="w"></span>
<span class="w"> </span><span class="n">e</span><span class="o">-></span><span class="n">check_fn</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">PyObject_GetAttrString</span><span class="p">(</span><span class="n">guarded_code</span><span class="p">,</span><span class="w"> </span><span class="s">"check_fn"</span><span class="p">);</span><span class="w"></span>
<span class="w"> </span><span class="n">NULL_CHECK</span><span class="p">(</span><span class="n">e</span><span class="o">-></span><span class="n">check_fn</span><span class="p">);</span><span class="w"></span>
<span class="w"> </span><span class="n">e</span><span class="o">-></span><span class="n">code</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="p">(</span><span class="n">PyCodeObject</span><span class="w"> </span><span class="o">*</span><span class="p">)</span><span class="n">PyObject_GetAttrString</span><span class="p">(</span><span class="n">guarded_code</span><span class="p">,</span><span class="w"> </span><span class="s">"code"</span><span class="p">);</span><span class="w"></span>
<span class="w"> </span><span class="n">NULL_CHECK</span><span class="p">(</span><span class="n">e</span><span class="o">-></span><span class="n">code</span><span class="p">);</span><span class="w"></span>
<span class="w"> </span><span class="n">e</span><span class="o">-></span><span class="n">next</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">next</span><span class="p">;</span><span class="w"></span>
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">e</span><span class="p">;</span><span class="w"></span>
<span class="p">}</span><span class="w"></span>
</pre></div>
</div>
<p>We now know how a <code class="docutils literal notranslate"><span class="pre">check_fn</span></code> is used, who makes it, and
what it is composed of, but not yet how it is built. How does a
list of <code class="docutils literal notranslate"><span class="pre">Guard</span></code> objects become a function we can run later on?</p>
<p>First, we iterate over these guards:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">for</span> <span class="n">guard</span> <span class="ow">in</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">guards</span> <span class="ow">or</span> <span class="p">[],</span> <span class="n">key</span><span class="o">=</span><span class="n">Guard</span><span class="o">.</span><span class="n">sort_key</span><span class="p">):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">config</span><span class="o">.</span><span class="n">guard_nn_modules</span> <span class="ow">and</span> <span class="n">guard</span><span class="o">.</span><span class="n">is_nn_module</span><span class="p">():</span>
<span class="k">continue</span>
<span class="n">guard</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">local_builder</span><span class="p">,</span> <span class="n">global_builder</span><span class="p">)</span>
</pre></div>
</div>
<p>Calling <code class="docutils literal notranslate"><span class="pre">guard.create</span></code> runs that <code class="docutils literal notranslate"><span class="pre">create_fn</span></code> we set on the <code class="docutils literal notranslate"><span class="pre">Guard</span></code>
class above (don’t confuse it with the <code class="docutils literal notranslate"><span class="pre">check_fn</span></code> we are working on
producing, the names are similar, so it can get a little confusing). In
our example above, our <code class="docutils literal notranslate"><span class="pre">create_fn</span></code> is <code class="docutils literal notranslate"><span class="pre">GuardBuilder.EQUALS_MATCH</span></code>.
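So we are now invoking it, passing in a <code class="docutils literal notranslate"><span class="pre">GuardBuilder</span></code> as <code class="docutils literal notranslate"><span class="pre">self</span></code> and the guard itself
as <code class="docutils literal notranslate"><span class="pre">guard</span></code>.</p>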
<p>The signature is: <code class="docutils literal notranslate"><span class="pre">def</span> <span class="pre">EQUALS_MATCH(self,</span> <span class="pre">guard:</span> <span class="pre">Guard):</span></code></p>
<p>And internally to that function, we can use the <code class="docutils literal notranslate"><span class="pre">name</span></code> on the guard to
get back our original object, querying it for data and type information,
which in turn gets us to the most important bit: appending code.</p>
<p>At its simplest, <code class="docutils literal notranslate"><span class="pre">EQUALS_MATCH</span></code> appends just one line of code:
<code class="docutils literal notranslate"><span class="pre">self.code.append(f"{ref}</span> <span class="pre">==</span> <span class="pre">{val!r}")</span></code>, where <code class="docutils literal notranslate"><span class="pre">ref</span></code> is the name of
the variable and <code class="docutils literal notranslate"><span class="pre">val</span></code> is the value. It might produce code like this:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">y</span> <span class="o">==</span> <span class="mi">2</span>
</pre></div>
</div>
<p>This is a basic example. But if we append a few other kinds of <code class="docutils literal notranslate"><span class="pre">GuardBuilder</span></code>
functions and then combine them all with
<code class="docutils literal notranslate"><span class="pre">and</span></code> in between each statement (as we do), we might get something
like this:</p>
<div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">___guarded_code</span><span class="o">.</span><span class="n">valid</span> <span class="ow">and</span> <span class="n">___check_type_id</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="mi">94367738391392</span><span class="p">)</span> <span class="ow">and</span> <span class="n">y</span> <span class="o">==</span> <span class="mi">2</span> <span class="ow">and</span> <span class="n">___check_tensors</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</pre></div>
</div>
<p>Here is what this code checks:</p>
<ol class="arabic simple">
<li><p>A check for <code class="docutils literal notranslate"><span class="pre">.valid</span></code></p></li>
<li><p>A type ID check</p></li>
<li><p>A value check</p></li>
<li><p>A tensor check</p></li>
</ol>
<p>This becomes the heart of our <code class="docutils literal notranslate"><span class="pre">check_fn</span></code>, which in turn
is evaluated the <strong>next</strong> time we encounter this code. It
will then check:</p>
<ol class="arabic simple">
<li><p>Is this code still valid?</p></li>
<li><p>If (1), does <code class="docutils literal notranslate"><span class="pre">y</span></code> still have a type ID of <code class="docutils literal notranslate"><span class="pre">94367738391392</span></code>?</p></li>
<li><p>If (2), is <code class="docutils literal notranslate"><span class="pre">y</span></code> still 2?</p></li>
<li><p>If (3), has tensor <code class="docutils literal notranslate"><span class="pre">x</span></code> remained unchanged in the specific ways the tensor guard checks for?</p></li>
</ol>
<p>If all of these are still true, then we can use the code cached
alongside this <code class="docutils literal notranslate"><span class="pre">check_fn</span></code>.</p>
<div class="admonition note">
<p class="admonition-title">Note</p>
<p>For a deeper dive into how and where this happens,
read <code class="docutils literal notranslate"><span class="pre">static</span> <span class="pre">PyCodeObject</span> <span class="pre">*lookup(CacheEntry</span> <span class="pre">*e,</span> <span class="pre">PyObject</span> <span class="pre">*f_locals)</span> <span class="pre">{</span></code> in
<code class="docutils literal notranslate"><span class="pre">_eval_frame.c</span></code>.</p>
</div>
<p>If not, we move on to recompiling the code anew and storing
it in the cache alongside this code, together with a whole new <code class="docutils literal notranslate"><span class="pre">check_fn</span></code>,
which will in turn be checked on yet another subsequent frame.</p>
<p>There are lots of other such functions on <code class="docutils literal notranslate"><span class="pre">GuardBuilder</span></code> which get
coalesced into, at times massive, strings which then get evaluated as
Python code and stored into <code class="docutils literal notranslate"><span class="pre">check_fn</span></code>. The example above
illustrates a simple case. To understand this functionality better, read
the other functions on <code class="docutils literal notranslate"><span class="pre">GuardBuilder</span></code>, or better yet, dump the <code class="docutils literal notranslate"><span class="pre">code</span></code> variable
in <code class="docutils literal notranslate"><span class="pre">compile_check_fn</span></code> to see what is getting produced,
especially on larger, real models.</p>
</section>
<section id="summary">
<h2>Summary<a class="headerlink" href="#summary" title="Permalink to this heading">¶</a></h2>
<p>In this section, we have reviewed:</p>
<ul class="simple">
<li><p>The role of <code class="docutils literal notranslate"><span class="pre">.valid</span></code> and invalidation around weak references (and potentially soon to be NN Module invalidations).</p></li>
<li><p>How the C++ side of guard functions (<code class="docutils literal notranslate"><span class="pre">___check_type_id</span></code>, <code class="docutils literal notranslate"><span class="pre">___check_tensors</span></code>, etc.) operate</p></li>
<li><p>What happens when guards fail.</p></li>
<li><p>What happens if we produce invalid guard code.</p></li>
</ul>
<p>We covered how user provided code wrapped in a TorchDynamo context
goes on to get traced and tracked internally, organized into <code class="docutils literal notranslate"><span class="pre">VariableTracker</span></code>s,
<code class="docutils literal notranslate"><span class="pre">Source</span></code>s, and subsequently <code class="docutils literal notranslate"><span class="pre">Guard</span></code>s, and how those <code class="docutils literal notranslate"><span class="pre">Guards</span></code> in
turn guide cache entry selection and invalidation when handling Python
code.</p>
</section>
</section>
</article>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="custom-backends.html" class="btn btn-neutral float-right" title="Custom Backends" accesskey="n" rel="next">Next <img src="../_static/images/chevron-right-orange.svg" class="next-page" alt=""></a>
<a href="get-started.html" class="btn btn-neutral" title="Getting Started" accesskey="p" rel="prev"><img src="../_static/images/chevron-right-orange.svg" class="previous-page" alt=""> Previous</a>
</div>
<hr>
<div role="contentinfo">
<p>
© Copyright 2022, PyTorch Contributors.
</p>
</div>
<div>
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</div>
</footer>
</div>
</div>
<div class="pytorch-content-right" id="pytorch-content-right">
<div class="pytorch-right-menu" id="pytorch-right-menu">
<div class="pytorch-side-scroll" id="pytorch-side-scroll-right">
<ul>
<li><a class="reference internal" href="#">Guards Overview</a><ul>
<li><a class="reference internal" href="#caching-and-guards-overview">Caching and Guards Overview</a></li>
<li><a class="reference internal" href="#python-frame-evaluation-and-pep-523">Python Frame Evaluation and PEP 523</a></li>
<li><a class="reference internal" href="#instructiontranslator">InstructionTranslator</a></li>
<li><a class="reference internal" href="#variables">Variables</a></li>
<li><a class="reference internal" href="#making-guards">Making Guards</a></li>
<li><a class="reference internal" href="#summary">Summary</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div>
</section>
</div>
<script id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
<script src="../_static/jquery.js"></script>
<script src="../_static/underscore.js"></script>
<script src="../_static/_sphinx_javascript_frameworks_compat.js"></script>
<script src="../_static/doctools.js"></script>
<script src="../_static/clipboard.min.js"></script>
<script src="../_static/copybutton.js"></script>
<script type="text/javascript" src="../_static/js/vendor/popper.min.js"></script>
<script type="text/javascript" src="../_static/js/vendor/bootstrap.min.js"></script>