forked from pytorch/pytorch.github.io
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlarge_scale_deployments.html
867 lines (683 loc) · 49.1 KB
/
large_scale_deployments.html
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en" > <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en" > <!--<![endif]-->
<head>
<meta name="robots" content="noindex">
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Features for large-scale deployments — PyTorch 1.12 documentation</title>
<link rel="canonical" href="https://pytorch.org/docs/stable/notes/large_scale_deployments.html"/>
<link rel="stylesheet" href="../_static/css/theme.css" type="text/css" />
<!-- <link rel="stylesheet" href="../_static/pygments.css" type="text/css" /> -->
<link rel="stylesheet" href="../_static/copybutton.css" type="text/css" />
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.css" type="text/css" />
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/katex.min.css" type="text/css" />
<link rel="stylesheet" href="../_static/katex-math.css" type="text/css" />
<link rel="stylesheet" href="../_static/panels-main.c949a650a448cc0ae9fd3441c0e17fb0.css" type="text/css" />
<link rel="stylesheet" href="../_static/panels-variables.06eb56fa6e07937060861dad626602ad.css" type="text/css" />
<link rel="stylesheet" href="../_static/css/jit.css" type="text/css" />
<link rel="index" title="Index" href="../genindex.html" />
<link rel="search" title="Search" href="../search.html" />
<link rel="next" title="Modules" href="modules.html" />
<link rel="prev" title="HIP (ROCm) semantics" href="hip.html" />
<!-- Google Analytics -->
<script async src="https://www.googletagmanager.com/gtag/js?id=UA-117752657-2"></script>
<script>
// Reuse an existing dataLayer queue if one is already defined; otherwise start an empty one.
window.dataLayer = window.dataLayer || [];
// Appends each call's `arguments` object (not a copied array) to the dataLayer queue.
function gtag(){dataLayer.push(arguments);}
gtag('js', new Date());
// Same ID as in the gtag.js loader URL above.
gtag('config', 'UA-117752657-2');
</script>
<!-- End Google Analytics -->
<script src="../_static/js/modernizr.min.js"></script>
<!-- Preload the theme fonts -->
<link rel="preload" href="../_static/fonts/FreightSans/freight-sans-book.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/FreightSans/freight-sans-medium.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/IBMPlexMono/IBMPlexMono-Medium.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/FreightSans/freight-sans-bold.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/FreightSans/freight-sans-medium-italic.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="../_static/fonts/IBMPlexMono/IBMPlexMono-SemiBold.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<!-- Preload the katex fonts -->
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/fonts/KaTeX_Math-Italic.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/fonts/KaTeX_Main-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/fonts/KaTeX_Main-Bold.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/fonts/KaTeX_Size1-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/fonts/KaTeX_Size4-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/fonts/KaTeX_Size2-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/fonts/KaTeX_Size3-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="preload" href="https://cdn.jsdelivr.net/npm/katex@0.10.0-beta/dist/fonts/KaTeX_Caligraphic-Regular.woff2" as="font" type="font/woff2" crossorigin="anonymous">
<link rel="stylesheet" href="https://use.fontawesome.com/releases/v5.15.2/css/all.css" integrity="sha384-vSIIfh2YWi9wW0r9iZe7RJPrKwp6bG+s9QZMoITbCckVJqGCCRhc+ccxNcdpHuYu" crossorigin="anonymous">
</head>
<div class="container-fluid header-holder tutorials-header" id="header-holder">
<div class="container">
<div class="header-container">
<a class="header-logo" href="https://pytorch.org/" aria-label="PyTorch"></a>
<div class="main-menu">
<ul>
<li>
<a href="https://pytorch.org/get-started">Get Started</a>
</li>
<li>
<a href="https://pytorch.org/ecosystem">Ecosystem</a>
</li>
<li>
<a href="https://pytorch.org/mobile">Mobile</a>
</li>
<li>
<a href="https://pytorch.org/blog/">Blog</a>
</li>
<li>
<a href="https://pytorch.org/tutorials">Tutorials</a>
</li>
<li class="active docs-active">
<div id="resourcesDropdownButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="resource-option with-down-orange-arrow">
Docs
</a>
<div class="resources-dropdown-menu">
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/docs/stable/index.html">
<span class="dropdown-title">PyTorch</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/audio/stable/index.html">
<span class="dropdown-title">torchaudio</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/text/stable/index.html">
<span class="dropdown-title">torchtext</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/vision/stable/index.html">
<span class="dropdown-title">torchvision</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/torchrec">
<span class="dropdown-title">TorchRec</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/data">
<span class="dropdown-title">TorchData</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/serve/">
<span class="dropdown-title">TorchServe</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/torchx/">
<span class="dropdown-title">TorchX</span>
<p></p>
</a>
<a class="doc-dropdown-option nav-dropdown-item" href="https://pytorch.org/xla">
<span class="dropdown-title">PyTorch on XLA Devices</span>
<p></p>
</a>
</div>
</div>
</li>
<li>
<div id="resourcesDropdownButton" data-toggle="resources-dropdown" class="resources-dropdown">
<a class="resource-option with-down-arrow">
Resources
</a>
<div class="resources-dropdown-menu">
<a class="nav-dropdown-item" href="https://pytorch.org/features">
<span class="dropdown-title">About</span>
<p>Learn about PyTorch’s features and capabilities</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/#community-module">
<span class="dropdown-title">Community</span>
<p>Join the PyTorch developer community to contribute, learn, and get your questions answered.</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/resources">
<span class="dropdown-title">Developer Resources</span>
<p>Find resources and get questions answered</p>
</a>
<a class="nav-dropdown-item" href="https://discuss.pytorch.org/" target="_blank">
<span class="dropdown-title">Forums</span>
<p>A place to discuss PyTorch code, issues, install, research</p>
</a>
<a class="nav-dropdown-item" href="https://pytorch.org/hub">
<span class="dropdown-title">Models (Beta)</span>
<p>Discover, publish, and reuse pre-trained models</p>
</a>
</div>
</div>
</li>
<li>
<a href="https://github.com/pytorch/pytorch">GitHub</a>
</li>
</ul>
</div>
<a class="main-menu-open-button" href="#" data-behavior="open-mobile-menu" aria-label="Open menu"></a>
</div>
</div>
</div>
<body class="pytorch-body">
<div class="table-of-contents-link-wrapper">
<span>Table of Contents</span>
<a href="#" class="toggle-table-of-contents" data-behavior="toggle-table-of-contents" aria-label="Toggle table of contents"></a>
</div>
<nav data-toggle="wy-nav-shift" class="pytorch-left-menu" id="pytorch-left-menu">
<div class="pytorch-side-scroll">
<div class="pytorch-menu pytorch-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
<div class="pytorch-left-menu-search">
<div class="version">
<a href='https://pytorch.org/docs/versions.html'>1.12 ▼</a>
</div>
<div role="search">
<form id="rtd-search-form" class="wy-form" action="../search.html" method="get">
<input type="text" name="q" placeholder="Search Docs" aria-label="Search Docs" />
<input type="hidden" name="check_keywords" value="yes" />
<input type="hidden" name="area" value="default" />
</form>
</div>
</div>
<p class="caption" role="heading"><span class="caption-text">Community</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../community/build_ci_governance.html">PyTorch Governance | Build + CI</a></li>
<li class="toctree-l1"><a class="reference internal" href="../community/contribution_guide.html">PyTorch Contribution Guide</a></li>
<li class="toctree-l1"><a class="reference internal" href="../community/design.html">PyTorch Design Philosophy</a></li>
<li class="toctree-l1"><a class="reference internal" href="../community/governance.html">PyTorch Governance | Mechanics</a></li>
<li class="toctree-l1"><a class="reference internal" href="../community/persons_of_interest.html">PyTorch Governance | Maintainers</a></li>
</ul>
<p class="caption"><span class="caption-text">Notes</span></p>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="amp_examples.html">CUDA Automatic Mixed Precision examples</a></li>
<li class="toctree-l1"><a class="reference internal" href="autograd.html">Autograd mechanics</a></li>
<li class="toctree-l1"><a class="reference internal" href="broadcasting.html">Broadcasting semantics</a></li>
<li class="toctree-l1"><a class="reference internal" href="cpu_threading_torchscript_inference.html">CPU threading and TorchScript inference</a></li>
<li class="toctree-l1"><a class="reference internal" href="cuda.html">CUDA semantics</a></li>
<li class="toctree-l1"><a class="reference internal" href="ddp.html">Distributed Data Parallel</a></li>
<li class="toctree-l1"><a class="reference internal" href="extending.html">Extending PyTorch</a></li>
<li class="toctree-l1"><a class="reference internal" href="faq.html">Frequently Asked Questions</a></li>
<li class="toctree-l1"><a class="reference internal" href="gradcheck.html">Gradcheck mechanics</a></li>
<li class="toctree-l1"><a class="reference internal" href="hip.html">HIP (ROCm) semantics</a></li>
<li class="toctree-l1 current"><a class="current reference internal" href="#">Features for large-scale deployments</a></li>
<li class="toctree-l1"><a class="reference internal" href="modules.html">Modules</a></li>
<li class="toctree-l1"><a class="reference internal" href="mps.html">MPS backend</a></li>
<li class="toctree-l1"><a class="reference internal" href="multiprocessing.html">Multiprocessing best practices</a></li>
<li class="toctree-l1"><a class="reference internal" href="numerical_accuracy.html">Numerical accuracy</a></li>
<li class="toctree-l1"><a class="reference internal" href="randomness.html">Reproducibility</a></li>
<li class="toctree-l1"><a class="reference internal" href="serialization.html">Serialization semantics</a></li>
<li class="toctree-l1"><a class="reference internal" href="windows.html">Windows FAQ</a></li>
</ul>
<p class="caption"><span class="caption-text">Language Bindings</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../cpp_index.html">C++</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/javadoc/">Javadoc</a></li>
<li class="toctree-l1"><a class="reference internal" href="../deploy.html">torch::deploy</a></li>
</ul>
<p class="caption"><span class="caption-text">Python API</span></p>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../torch.html">torch</a></li>
<li class="toctree-l1"><a class="reference internal" href="../nn.html">torch.nn</a></li>
<li class="toctree-l1"><a class="reference internal" href="../nn.functional.html">torch.nn.functional</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tensors.html">torch.Tensor</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tensor_attributes.html">Tensor Attributes</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tensor_view.html">Tensor Views</a></li>
<li class="toctree-l1"><a class="reference internal" href="../amp.html">torch.amp</a></li>
<li class="toctree-l1"><a class="reference internal" href="../autograd.html">torch.autograd</a></li>
<li class="toctree-l1"><a class="reference internal" href="../library.html">torch.library</a></li>
<li class="toctree-l1"><a class="reference internal" href="../cuda.html">torch.cuda</a></li>
<li class="toctree-l1"><a class="reference internal" href="../backends.html">torch.backends</a></li>
<li class="toctree-l1"><a class="reference internal" href="../distributed.html">torch.distributed</a></li>
<li class="toctree-l1"><a class="reference internal" href="../distributed.algorithms.join.html">torch.distributed.algorithms.join</a></li>
<li class="toctree-l1"><a class="reference internal" href="../distributed.elastic.html">torch.distributed.elastic</a></li>
<li class="toctree-l1"><a class="reference internal" href="../fsdp.html">torch.distributed.fsdp</a></li>
<li class="toctree-l1"><a class="reference internal" href="../distributed.optim.html">torch.distributed.optim</a></li>
<li class="toctree-l1"><a class="reference internal" href="../distributions.html">torch.distributions</a></li>
<li class="toctree-l1"><a class="reference internal" href="../fft.html">torch.fft</a></li>
<li class="toctree-l1"><a class="reference internal" href="../futures.html">torch.futures</a></li>
<li class="toctree-l1"><a class="reference internal" href="../fx.html">torch.fx</a></li>
<li class="toctree-l1"><a class="reference internal" href="../hub.html">torch.hub</a></li>
<li class="toctree-l1"><a class="reference internal" href="../jit.html">torch.jit</a></li>
<li class="toctree-l1"><a class="reference internal" href="../linalg.html">torch.linalg</a></li>
<li class="toctree-l1"><a class="reference internal" href="../monitor.html">torch.monitor</a></li>
<li class="toctree-l1"><a class="reference internal" href="../special.html">torch.special</a></li>
<li class="toctree-l1"><a class="reference internal" href="../torch.overrides.html">torch.overrides</a></li>
<li class="toctree-l1"><a class="reference internal" href="../package.html">torch.package</a></li>
<li class="toctree-l1"><a class="reference internal" href="../profiler.html">torch.profiler</a></li>
<li class="toctree-l1"><a class="reference internal" href="../nn.init.html">torch.nn.init</a></li>
<li class="toctree-l1"><a class="reference internal" href="../onnx.html">torch.onnx</a></li>
<li class="toctree-l1"><a class="reference internal" href="../optim.html">torch.optim</a></li>
<li class="toctree-l1"><a class="reference internal" href="../complex_numbers.html">Complex Numbers</a></li>
<li class="toctree-l1"><a class="reference internal" href="../ddp_comm_hooks.html">DDP Communication Hooks</a></li>
<li class="toctree-l1"><a class="reference internal" href="../pipeline.html">Pipeline Parallelism</a></li>
<li class="toctree-l1"><a class="reference internal" href="../quantization.html">Quantization</a></li>
<li class="toctree-l1"><a class="reference internal" href="../rpc.html">Distributed RPC Framework</a></li>
<li class="toctree-l1"><a class="reference internal" href="../random.html">torch.random</a></li>
<li class="toctree-l1"><a class="reference internal" href="../nested.html">torch.nested</a></li>
<li class="toctree-l1"><a class="reference internal" href="../sparse.html">torch.sparse</a></li>
<li class="toctree-l1"><a class="reference internal" href="../storage.html">torch.Storage</a></li>
<li class="toctree-l1"><a class="reference internal" href="../testing.html">torch.testing</a></li>
<li class="toctree-l1"><a class="reference internal" href="../benchmark_utils.html">torch.utils.benchmark</a></li>
<li class="toctree-l1"><a class="reference internal" href="../bottleneck.html">torch.utils.bottleneck</a></li>
<li class="toctree-l1"><a class="reference internal" href="../checkpoint.html">torch.utils.checkpoint</a></li>
<li class="toctree-l1"><a class="reference internal" href="../cpp_extension.html">torch.utils.cpp_extension</a></li>
<li class="toctree-l1"><a class="reference internal" href="../data.html">torch.utils.data</a></li>
<li class="toctree-l1"><a class="reference internal" href="../dlpack.html">torch.utils.dlpack</a></li>
<li class="toctree-l1"><a class="reference internal" href="../mobile_optimizer.html">torch.utils.mobile_optimizer</a></li>
<li class="toctree-l1"><a class="reference internal" href="../model_zoo.html">torch.utils.model_zoo</a></li>
<li class="toctree-l1"><a class="reference internal" href="../tensorboard.html">torch.utils.tensorboard</a></li>
<li class="toctree-l1"><a class="reference internal" href="../type_info.html">Type Info</a></li>
<li class="toctree-l1"><a class="reference internal" href="../named_tensor.html">Named Tensors</a></li>
<li class="toctree-l1"><a class="reference internal" href="../name_inference.html">Named Tensors operator coverage</a></li>
<li class="toctree-l1"><a class="reference internal" href="../config_mod.html">torch.__config__</a></li>
</ul>
<p class="caption"><span class="caption-text">Libraries</span></p>
<ul>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/audio/stable">torchaudio</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/data">TorchData</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/torchrec">TorchRec</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/serve">TorchServe</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/text/stable">torchtext</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/vision/stable">torchvision</a></li>
<li class="toctree-l1"><a class="reference external" href="https://pytorch.org/xla/">PyTorch on XLA Devices</a></li>
</ul>
</div>
</div>
</nav>
<div class="pytorch-container">
<div class="pytorch-page-level-bar" id="pytorch-page-level-bar">
<div class="pytorch-breadcrumbs-wrapper">
<div role="navigation" aria-label="breadcrumbs navigation">
<ul class="pytorch-breadcrumbs">
<li>
<a href="../index.html">
Docs
</a> &gt;
</li>
<li>Features for large-scale deployments</li>
<li class="pytorch-breadcrumbs-aside">
<a href="../_sources/notes/large_scale_deployments.rst.txt" rel="nofollow"><img src="../_static/images/view-page-source-icon.svg" alt="View page source"></a>
</li>
</ul>
</div>
</div>
<div class="pytorch-shortcuts-wrapper" id="pytorch-shortcuts-wrapper">
Shortcuts
</div>
</div>
<section data-toggle="wy-nav-shift" id="pytorch-content-wrap" class="pytorch-content-wrap">
<div class="pytorch-content-left">
<div class="rst-content">
<div role="main" class="main-content" itemscope="itemscope" itemtype="http://schema.org/Article">
<article itemprop="articleBody" id="pytorch-article" class="pytorch-article">
<div class="section" id="features-for-large-scale-deployments">
<h1>Features for large-scale deployments<a class="headerlink" href="#features-for-large-scale-deployments" title="Permalink to this headline">¶</a></h1>
<div class="contents local topic" id="contents">
<ul class="simple">
<li><p><a class="reference internal" href="#fleet-wide-operator-profiling" id="id1">Fleet-wide operator profiling</a></p></li>
<li><p><a class="reference internal" href="#api-usage-logging" id="id2">API usage logging</a></p></li>
<li><p><a class="reference internal" href="#attaching-metadata-to-saved-torchscript-models" id="id3">Attaching metadata to saved TorchScript models</a></p></li>
<li><p><a class="reference internal" href="#build-environment-considerations" id="id4">Build environment considerations</a></p></li>
<li><p><a class="reference internal" href="#common-extension-points" id="id5">Common extension points</a></p></li>
</ul>
</div>
<p>This note talks about several extension points and tricks that might be useful
when running PyTorch within a larger system or operating multiple systems using
PyTorch in a larger organization.</p>
<p>It doesn’t cover the topic of deploying models to production. For that, check
<a class="reference internal" href="../jit.html#module-torch.jit" title="torch.jit"><code class="xref py py-mod docutils literal notranslate"><span class="pre">torch.jit</span></code></a> or one of the corresponding tutorials.</p>
<p>The note assumes that you either build PyTorch from source in your
organization or have the ability to statically link additional code to be loaded
when PyTorch is used. Therefore, many of the hooks are exposed as C++ APIs that
can be triggered once in a centralized place, e.g. in static initialization
code.</p>
<div class="section" id="fleet-wide-operator-profiling">
<h2><a class="toc-backref" href="#id1">Fleet-wide operator profiling</a><a class="headerlink" href="#fleet-wide-operator-profiling" title="Permalink to this headline">¶</a></h2>
<p>PyTorch comes with <code class="xref py py-mod docutils literal notranslate"><span class="pre">torch.autograd.profiler</span></code>, which is capable of measuring the time
taken by individual operators on demand. One can use the same mechanism to do
“always ON” measurements for any process running PyTorch. It might be useful for
gathering information about PyTorch workloads running in a given process or
across the entire set of machines.</p>
<p>New callbacks for any operator invocation can be added with
<code class="docutils literal notranslate"><span class="pre">torch::addGlobalCallback</span></code>. Hooks will be called with a
<code class="docutils literal notranslate"><span class="pre">torch::RecordFunction</span></code> struct that describes the invocation
context (e.g. <cite>name</cite>). If enabled, <code class="docutils literal notranslate"><span class="pre">RecordFunction::inputs()</span></code> contains the arguments
of the function represented as the <code class="docutils literal notranslate"><span class="pre">torch::IValue</span></code> variant type. Note that input
logging is relatively expensive and thus has to be enabled explicitly.</p>
<p>The operator callbacks also have access to the <code class="docutils literal notranslate"><span class="pre">c10::ThreadLocalDebugInfo::get()</span></code>
interface that returns a pointer to the struct holding the debug information.
This debug information can be set earlier by using an <code class="docutils literal notranslate"><span class="pre">at::DebugInfoGuard</span></code> object.
Debug information is propagated through the forward (including async <code class="docutils literal notranslate"><span class="pre">fork</span></code>
tasks) and backward passes and can be useful for passing some extra information
about the execution environment (e.g. model ID) from the higher layers of the
application down to the operator callbacks.</p>
<p>Invoking callbacks adds some overhead, so usually it’s useful to just randomly
sample operator invocations. This can be enabled on per-callback basis with an
optional sampling rate passed into <code class="docutils literal notranslate"><span class="pre">torch::addGlobalCallback</span></code>.</p>
<p>Note, that <code class="docutils literal notranslate"><span class="pre">addGlobalCallback</span></code> is not thread-safe and can be called only when no
PyTorch operator is running. Usually, it’s a good idea to call it once during
initialization.</p>
<p>Here’s an example:</p>
<div class="highlight-cpp notranslate"><div class="highlight"><pre><span></span><span class="c1">// Called somewhere in the program beginning</span>
<span class="kt">void</span><span class="w"> </span><span class="nf">init</span><span class="p">()</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
<span class="w"> </span><span class="c1">// Sample one in a hundred operator runs randomly</span>
<span class="w"> </span><span class="n">addGlobalCallback</span><span class="p">(</span><span class="w"></span>
<span class="w"> </span><span class="n">RecordFunctionCallback</span><span class="p">(</span><span class="w"></span>
<span class="w"> </span><span class="o">&</span><span class="n">onFunctionEnter</span><span class="p">,</span><span class="w"></span>
<span class="w"> </span><span class="o">&</span><span class="n">onFunctionExit</span><span class="p">)</span><span class="w"></span>
<span class="w"> </span><span class="p">.</span><span class="n">needsInputs</span><span class="p">(</span><span class="nb">true</span><span class="p">)</span><span class="w"></span>
<span class="w"> </span><span class="p">.</span><span class="n">samplingProb</span><span class="p">(</span><span class="mf">0.01</span><span class="p">)</span><span class="w"></span>
<span class="w"> </span><span class="p">);</span><span class="w"></span>
<span class="w"> </span><span class="c1">// Note, to enable observers in the model calling thread,</span>
<span class="w"> </span><span class="c1">// call enableRecordFunction() in the thread before running a model</span>
<span class="p">}</span><span class="w"></span>
<span class="kt">void</span><span class="w"> </span><span class="nf">onFunctionEnter</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">RecordFunction</span><span class="o">&</span><span class="w"> </span><span class="n">fn</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">cerr</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="s">"Before function "</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="n">fn</span><span class="p">.</span><span class="n">name</span><span class="p">()</span><span class="w"></span>
<span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="s">" with "</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="n">fn</span><span class="p">.</span><span class="n">inputs</span><span class="p">().</span><span class="n">size</span><span class="p">()</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="s">" inputs"</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">endl</span><span class="p">;</span><span class="w"></span>
<span class="p">}</span><span class="w"></span>
<span class="kt">void</span><span class="w"> </span><span class="nf">onFunctionExit</span><span class="p">(</span><span class="k">const</span><span class="w"> </span><span class="n">RecordFunction</span><span class="o">&</span><span class="w"> </span><span class="n">fn</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">cerr</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="s">"After function "</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="n">fn</span><span class="p">.</span><span class="n">name</span><span class="p">();</span><span class="w"></span>
<span class="p">}</span><span class="w"></span>
</pre></div>
</div>
</div>
<div class="section" id="api-usage-logging">
<h2><a class="toc-backref" href="#id2">API usage logging</a><a class="headerlink" href="#api-usage-logging" title="Permalink to this headline">¶</a></h2>
<p>When running in a broader ecosystem, for example in a managed job scheduler, it’s
often useful to track which binaries invoke particular PyTorch APIs. There
exists simple instrumentation injected at several important API points that
triggers a given callback. Because usually PyTorch is invoked in one-off python
scripts, the callback fires only once for a given process for each of the APIs.</p>
<p><code class="docutils literal notranslate"><span class="pre">c10::SetAPIUsageLogger</span></code> can be used to register an API usage instrumentation
handler. The argument passed to the handler is an “api key” identifying the usage point, for
example <code class="docutils literal notranslate"><span class="pre">python.import</span></code> for PyTorch extension import or
<code class="docutils literal notranslate"><span class="pre">torch.script.compile</span></code> if TorchScript compilation was triggered.</p>
<div class="highlight-cpp notranslate"><div class="highlight"><pre><span></span><span class="n">SetAPIUsageLogger</span><span class="p">([](</span><span class="k">const</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="o">&</span><span class="w"> </span><span class="n">event_name</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
<span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">cerr</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="s">"API was used: "</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="n">event_name</span><span class="w"> </span><span class="o"><<</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">endl</span><span class="p">;</span><span class="w"></span>
<span class="p">});</span><span class="w"></span>
</pre></div>
</div>
<p>Note for developers: new API trigger points can be added in code with
<code class="docutils literal notranslate"><span class="pre">C10_LOG_API_USAGE_ONCE("my_api")</span></code> in C++ or
<code class="docutils literal notranslate"><span class="pre">torch._C._log_api_usage_once("my.api")</span></code> in Python.</p>
</div>
<div class="section" id="attaching-metadata-to-saved-torchscript-models">
<h2><a class="toc-backref" href="#id3">Attaching metadata to saved TorchScript models</a><a class="headerlink" href="#attaching-metadata-to-saved-torchscript-models" title="Permalink to this headline">¶</a></h2>
<p>TorchScript modules can be saved as an archive file that bundles serialized
parameters and module code as TorchScript (see <a class="reference internal" href="../generated/torch.jit.save.html#torch.jit.save" title="torch.jit.save"><code class="xref py py-meth docutils literal notranslate"><span class="pre">torch.jit.save()</span></code></a>). It’s
often convenient to bundle additional information together with the model, for
example, description of model producer or auxiliary artifacts.</p>
<p>It can be achieved by passing the <code class="docutils literal notranslate"><span class="pre">_extra_files</span></code> argument to
<a class="reference internal" href="../generated/torch.jit.save.html#torch.jit.save" title="torch.jit.save"><code class="xref py py-meth docutils literal notranslate"><span class="pre">torch.jit.save()</span></code></a> and <code class="docutils literal notranslate"><span class="pre">torch::jit::load</span></code> to store and retrieve
arbitrary binary blobs during the saving and loading process. Since TorchScript files are
regular ZIP archives, extra information gets stored as regular files inside
the archive’s <code class="docutils literal notranslate"><span class="pre">extra/</span></code> directory.</p>
<p>There’s also a global hook that allows attaching extra files to any TorchScript
archive produced in the current process. It might be useful to tag models with
producer metadata, akin to JPEG metadata produced by digital cameras. Example
usage might look like:</p>
<div class="highlight-cpp notranslate"><div class="highlight"><pre><span></span><span class="n">SetExportModuleExtraFilesHook</span><span class="p">([](</span><span class="k">const</span><span class="w"> </span><span class="n">Module</span><span class="o">&</span><span class="p">)</span><span class="w"> </span><span class="p">{</span><span class="w"></span>
<span class="w"> </span><span class="n">ExtraFilesMap</span><span class="w"> </span><span class="n">files</span><span class="p">;</span><span class="w"></span>
<span class="w"> </span><span class="n">files</span><span class="p">[</span><span class="s">"producer_info.json"</span><span class="p">]</span><span class="w"> </span><span class="o">=</span><span class="w"> </span><span class="n">std</span><span class="o">::</span><span class="n">string</span><span class="p">(</span><span class="s">"{</span><span class="se">\"</span><span class="s">user</span><span class="se">\"</span><span class="s">: </span><span class="se">\"</span><span class="s">"</span><span class="p">)</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="n">getenv</span><span class="p">(</span><span class="s">"USER"</span><span class="p">)</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="s">"</span><span class="se">\"</span><span class="s">}"</span><span class="p">;</span><span class="w"></span>
<span class="w"> </span><span class="k">return</span><span class="w"> </span><span class="n">files</span><span class="p">;</span><span class="w"></span>
<span class="p">});</span><span class="w"></span>
</pre></div>
</div>
</div>
<div class="section" id="build-environment-considerations">
<h2><a class="toc-backref" href="#id4">Build environment considerations</a><a class="headerlink" href="#build-environment-considerations" title="Permalink to this headline">¶</a></h2>
<p>TorchScript’s compilation needs to have access to the original python files as
it uses python’s <code class="docutils literal notranslate"><span class="pre">inspect.getsource</span></code> call. In certain production environments
it might require explicitly deploying <code class="docutils literal notranslate"><span class="pre">.py</span></code> files along with precompiled
<code class="docutils literal notranslate"><span class="pre">.pyc</span></code>.</p>
</div>
<div class="section" id="common-extension-points">
<h2><a class="toc-backref" href="#id5">Common extension points</a><a class="headerlink" href="#common-extension-points" title="Permalink to this headline">¶</a></h2>
<p>PyTorch APIs are generally loosely coupled and it’s easy to replace a component
with a specialized version. Common extension points include:</p>
<ul class="simple">
<li><p>Custom operators implemented in C++ - see <a class="reference external" href="https://pytorch.org/tutorials/advanced/cpp_extension.html">tutorial for more details</a>.</p></li>
<li><p>Custom data reading can often be integrated directly by invoking the corresponding Python library. Existing functionality of <a class="reference internal" href="../data.html#module-torch.utils.data" title="torch.utils.data"><code class="xref py py-mod docutils literal notranslate"><span class="pre">torch.utils.data</span></code></a> can be utilized by extending <a class="reference internal" href="../data.html#torch.utils.data.Dataset" title="torch.utils.data.Dataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">Dataset</span></code></a> or <a class="reference internal" href="../data.html#torch.utils.data.IterableDataset" title="torch.utils.data.IterableDataset"><code class="xref py py-class docutils literal notranslate"><span class="pre">IterableDataset</span></code></a>.</p></li>
</ul>
</div>
</div>
</article>
</div>
<footer>
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
<a href="modules.html" class="btn btn-neutral float-right" title="Modules" accesskey="n" rel="next">Next <img src="../_static/images/chevron-right-orange.svg" class="next-page" alt=""></a>
<a href="hip.html" class="btn btn-neutral" title="HIP (ROCm) semantics" accesskey="p" rel="prev"><img src="../_static/images/chevron-right-orange.svg" class="previous-page" alt=""> Previous</a>
</div>
<hr>
<div role="contentinfo">
<p>
© Copyright 2022, PyTorch Contributors.
</p>
</div>
<div>
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a <a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a> provided by <a href="https://readthedocs.org">Read the Docs</a>.
</div>
</footer>
</div>
</div>
<div class="pytorch-content-right" id="pytorch-content-right">
<div class="pytorch-right-menu" id="pytorch-right-menu">
<div class="pytorch-side-scroll" id="pytorch-side-scroll-right">
<ul>
<li><a class="reference internal" href="#">Features for large-scale deployments</a><ul>
<li><a class="reference internal" href="#fleet-wide-operator-profiling">Fleet-wide operator profiling</a></li>
<li><a class="reference internal" href="#api-usage-logging">API usage logging</a></li>
<li><a class="reference internal" href="#attaching-metadata-to-saved-torchscript-models">Attaching metadata to saved TorchScript models</a></li>
<li><a class="reference internal" href="#build-environment-considerations">Build environment considerations</a></li>
<li><a class="reference internal" href="#common-extension-points">Common extension points</a></li>
</ul>
</li>
</ul>
</div>
</div>
</div>
</section>
</div>
<script type="text/javascript" id="documentation_options" data-url_root="../" src="../_static/documentation_options.js"></script>
<script src="../_static/jquery.js"></script>
<script src="../_static/underscore.js"></script>
<script src="../_static/doctools.js"></script>
<script src="../_static/clipboard.min.js"></script>
<script src="../_static/copybutton.js"></script>
<script type="text/javascript" src="../_static/js/vendor/popper.min.js"></script>
<script type="text/javascript" src="../_static/js/vendor/bootstrap.min.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/list.js/1.5.0/list.min.js"></script>
<script type="text/javascript" src="../_static/js/theme.js"></script>
<script type="text/javascript">
// When the DOM is ready, switch on the Sphinx RTD theme's navigation.
// The callback is named only so it is identifiable in stack traces.
jQuery(function enableThemeNavigation() {
  SphinxRtdTheme.Navigation.enable(true);
});
</script>
<script type="text/javascript">
// Global list of section titles, exposed for other scripts on the page
// (presumably the theme's sidebar code — the consumer is not visible here).
var collapsedSections = [
  "Notes",
  "Language Bindings",
  "Libraries",
  "Community"
];
</script>
<img height="1" width="1" style="border-style:none;" alt="" src="https://www.googleadservices.com/pagead/conversion/795629140/?label=txkmCPmdtosBENSssfsC&guid=ON&script=0"/>
<!-- Begin Footer -->
<div class="container-fluid docs-tutorials-resources" id="docs-tutorials-resources">
<div class="container">
<div class="row">
<div class="col-md-4 text-center">
<h2>Docs</h2>
<p>Access comprehensive developer documentation for PyTorch</p>
<a class="with-right-arrow" href="https://pytorch.org/docs/stable/index.html">View Docs</a>
</div>
<div class="col-md-4 text-center">
<h2>Tutorials</h2>
<p>Get in-depth tutorials for beginners and advanced developers</p>
<a class="with-right-arrow" href="https://pytorch.org/tutorials">View Tutorials</a>
</div>
<div class="col-md-4 text-center">
<h2>Resources</h2>
<p>Find development resources and get your questions answered</p>
<a class="with-right-arrow" href="https://pytorch.org/resources">View Resources</a>
</div>
</div>
</div>
</div>
<footer class="site-footer">
<div class="container footer-container">
<div class="footer-logo-wrapper">
<a href="https://pytorch.org/" class="footer-logo" aria-label="PyTorch"></a>
</div>
<div class="footer-links-wrapper">
<div class="footer-links-col">
<ul>
<li class="list-title"><a href="https://pytorch.org/">PyTorch</a></li>
<li><a href="https://pytorch.org/get-started">Get Started</a></li>
<li><a href="https://pytorch.org/features">Features</a></li>
<li><a href="https://pytorch.org/ecosystem">Ecosystem</a></li>
<li><a href="https://pytorch.org/blog/">Blog</a></li>
<li><a href="https://github.com/pytorch/pytorch/blob/master/CONTRIBUTING.md">Contributing</a></li>
</ul>
</div>
<div class="footer-links-col">
<ul>
<li class="list-title"><a href="https://pytorch.org/resources">Resources</a></li>
<li><a href="https://pytorch.org/tutorials">Tutorials</a></li>
<li><a href="https://pytorch.org/docs/stable/index.html">Docs</a></li>
<li><a href="https://discuss.pytorch.org" target="_blank">Discuss</a></li>
<li><a href="https://github.com/pytorch/pytorch/issues" target="_blank">Github Issues</a></li>
<li><a href="https://pytorch.org/assets/brand-guidelines/PyTorch-Brand-Guidelines.pdf" target="_blank">Brand Guidelines</a></li>
</ul>
</div>
<div class="footer-links-col">
<ul>
<li class="list-title">Stay up to date</li>
<li><a href="https://www.facebook.com/pytorch" target="_blank">Facebook</a></li>
<li><a href="https://twitter.com/pytorch" target="_blank">Twitter</a></li>
<li><a href="https://www.youtube.com/pytorch" target="_blank">YouTube</a></li>
<li><a href="https://www.linkedin.com/company/pytorch" target="_blank">LinkedIn</a></li>
</ul>
</div>
<div class="footer-links-col">
<ul>
<li class="list-title">PyTorch Podcasts</li>
<li><a href="https://open.spotify.com/show/6UzHKeiy368jKfQMKKvJY5" target="_blank">Spotify</a></li>
<li><a href="https://podcasts.apple.com/us/podcast/pytorch-developer-podcast/id1566080008" target="_blank">Apple</a></li>
<li><a href="https://www.google.com/podcasts?feed=aHR0cHM6Ly9mZWVkcy5zaW1wbGVjYXN0LmNvbS9PQjVGa0lsOA%3D%3D" target="_blank">Google</a></li>
<li><a href="https://music.amazon.com/podcasts/7a4e6f0e-26c2-49e9-a478-41bd244197d0/PyTorch-Developer-Podcast?" target="_blank">Amazon</a></li>
</ul>
</div>
</div>
<hr size="20" color="white" />
<div class="privacy-policy">
<p class="privacy-policy-links"><a href="https://www.linuxfoundation.org/terms/" target="_blank">Terms</a> | <a href="https://www.linuxfoundation.org/privacy-policy/" target="_blank">Privacy</a></p>
</div>
<hr size="20" color="white" />
<div class="copyright">
<p>© Copyright The Linux Foundation. The PyTorch Foundation is a project of The Linux Foundation.
For web site terms of use, trademark policy and other policies applicable to The PyTorch Foundation please see
<a href="https://www.linuxfoundation.org/policies/" style="color:#ee4c2c">www.linuxfoundation.org/policies/</a>. The PyTorch Foundation supports the PyTorch open source
project, which has been established as PyTorch Project a Series of LF Projects, LLC. For policies applicable to the PyTorch Project a Series of LF Projects, LLC,
please see <a href="https://www.lfprojects.org/policies/" style="color:#ee4c2c">www.lfprojects.org/policies/</a>.</p>
</div>
</div>
</footer>
<div class="cookie-banner-wrapper">
<div class="container">
<p class="gdpr-notice">To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. As the current maintainers of this site, Facebook’s Cookies Policy applies. Learn more, including about available controls: <a href="https://www.facebook.com/policies/cookies/">Cookies Policy</a>.</p>
<img class="close-button" src="../_static/images/pytorch-x.svg" alt="Close">
</div>
</div>
<!-- End Footer -->
<!-- Begin Mobile Menu -->
<div class="mobile-main-menu">
<div class="container-fluid">
<div class="container">
<div class="mobile-main-menu-header-container">
<a class="header-logo" href="https://pytorch.org/" aria-label="PyTorch"></a>
<a class="main-menu-close-button" href="#" data-behavior="close-mobile-menu" aria-label="Close menu"></a>
</div>
</div>
</div>
<div class="mobile-main-menu-links-container">
<div class="main-menu">
<ul>
<li>
<a href="https://pytorch.org/get-started">Get Started</a>
</li>
<li>
<a href="https://pytorch.org/ecosystem">Ecosystem</a>
</li>
<li>
<a href="https://pytorch.org/mobile">Mobile</a>
</li>
<li>
<a href="https://pytorch.org/hub">PyTorch Hub</a>
</li>
<li>
<a href="https://pytorch.org/blog/">Blog</a>
</li>
<li>
<a href="https://pytorch.org/tutorials">Tutorials</a>
</li>
<li class="resources-mobile-menu-title active">
Docs
</li>
<ul class="resources-mobile-menu-items">
<li>
<a href="https://pytorch.org/docs/stable/index.html">PyTorch</a>
</li>
<li>
<a href="https://pytorch.org/audio/stable/index.html">torchaudio</a>
</li>
<li>
<a href="https://pytorch.org/text/stable/index.html">torchtext</a>
</li>
<li>
<a href="https://pytorch.org/vision/stable/index.html">torchvision</a>
</li>
<li>
<a href="https://pytorch.org/serve/">TorchServe</a>
</li>
<li>
<a href="https://pytorch.org/torchx/">TorchX</a>
</li>
<li>
<a href="https://pytorch.org/xla">PyTorch on XLA Devices</a>
</li>
</ul>
<li class="resources-mobile-menu-title">
Resources
</li>
<ul class="resources-mobile-menu-items">
<li>
<a href="https://pytorch.org/resources">Developer Resources</a>
</li>
<li>
<a href="https://pytorch.org/features">About</a>
</li>
<li>
<a href="https://pytorch.org/hub">Models (Beta)</a>
</li>
<li>
<a href="https://pytorch.org/#community-module">Community</a>
</li>
<li>
<a href="https://discuss.pytorch.org/">Forums</a>
</li>
</ul>
<li>
<a href="https://github.com/pytorch/pytorch">Github</a>
</li>
</ul>
</div>
</div>
</div>
<!-- End Mobile Menu -->
<script type="text/javascript" src="../_static/js/vendor/anchor.min.js"></script>
<script type="text/javascript">
// Runs once the DOM is ready: calls bind() on each theme component in turn,
// then marks links that wrap inline code.
$(function initPytorchDocsPage() {
  mobileMenu.bind();
  mobileTOC.bind();
  pytorchAnchors.bind();
  sideMenus.bind();
  scrollToAnchor.bind();
  highlightNavigation.bind();
  mainMenuDropdown.bind();
  filterTags.bind();

  // Links cannot be created inside code blocks, so instead flag every link
  // in the article that contains a code span with the "has-code" class.
  $("article.pytorch-article a span.pre").closest("a").addClass("has-code");
});
</script>
</body>
</html>