Commit 243acd7

jbschlosser and brianjo authored
Recipe for skipping parameter init (#1559)
* Recipe for skipping parameter init

Co-authored-by: Joel Benjamin Schlosser <[email protected]>
Co-authored-by: Brian Johnson <[email protected]>
1 parent 817e3ae commit 243acd7

File tree

1 file changed: +127 -0 lines changed

prototype_source/skip_param_init.rst

Skipping Module Parameter Initialization
========================================

Introduction
------------

When a module is created, its learnable parameters are initialized according
to a default initialization scheme associated with the module type. For example, the `weight`
parameter for a :class:`torch.nn.Linear` module is initialized from a
`uniform(-1/sqrt(in_features), 1/sqrt(in_features))` distribution. If some other initialization
scheme is desired, this has traditionally required re-initializing the parameters
after module instantiation:

::

    from torch import nn

    # Initializes weight from the default distribution: uniform(-1/sqrt(10), 1/sqrt(10)).
    m = nn.Linear(10, 5)

    # Re-initialize weight from a different distribution.
    nn.init.orthogonal_(m.weight)

In this case, the initialization done during construction is wasted computation, and it may be non-trivial if
the `weight` parameter is large.
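
For a rough sense of this cost, the small sketch below (the layer size and timing are purely illustrative)
measures the construction of a large layer whose default initialization is then immediately discarded:

::

    import time

    from torch import nn

    # Time just the construction; the default initialization it performs is the
    # work that the subsequent re-initialization throws away.
    start = time.perf_counter()
    m = nn.Linear(8192, 8192)
    print(f"construction (incl. default init): {time.perf_counter() - start:.3f}s")

    # The freshly initialized values are immediately overwritten.
    nn.init.orthogonal_(m.weight)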

Skipping Initialization
-----------------------

It is now possible to skip parameter initialization during module construction, avoiding
wasted computation. This is easily accomplished using the :func:`torch.nn.utils.skip_init` function:

::

    from torch import nn
    from torch.nn.utils import skip_init

    m = skip_init(nn.Linear, 10, 5)

    # Example: Do custom, non-default parameter initialization.
    nn.init.orthogonal_(m.weight)

This can be applied to any module that satisfies the conditions described in the
:ref:`Updating` section below. Note that all modules provided by
`torch.nn` satisfy these conditions and thus support skipping init.
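
The same pattern applies to other `torch.nn` module types; as a further illustration (the choice of module
and of initializers here is arbitrary), constructor arguments can also be passed by keyword:

::

    from torch import nn
    from torch.nn.utils import skip_init

    # Keyword arguments are forwarded to the module constructor.
    conv = skip_init(nn.Conv2d, in_channels=3, out_channels=16, kernel_size=3)

    # The parameters exist with the correct shapes but hold uninitialized memory,
    # so initialize them explicitly before use.
    nn.init.kaiming_normal_(conv.weight)
    nn.init.zeros_(conv.bias)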

.. _Updating:

Updating Modules to Support Skipping Initialization
---------------------------------------------------

Due to the way :func:`torch.nn.utils.skip_init` is implemented (see :ref:`Details`), there are
two requirements that a module must meet to be compatible with the function.
You can opt in to the parameter initialization skipping functionality for your custom module
simply by adhering to these requirements:

1. The module must accept a `device` kwarg in its constructor that is passed to any parameters
   or buffers created during construction.

2. The module must not perform any computation on parameters or buffers in its constructor except
   initialization (i.e. functions from `torch.nn.init`).

The following example demonstrates a module updated to support the `device`
kwarg by passing it along to any created parameters, buffers, or submodules:

::

    import torch
    from torch import nn

    class MyModule(torch.nn.Module):
        def __init__(self, foo, bar, device=None):
            super().__init__()

            # ==== Case 1: Module creates parameters directly. ====
            # Pass device along to any created parameters.
            self.param1 = nn.Parameter(torch.empty((foo, bar), device=device))
            self.register_parameter('param2', nn.Parameter(torch.empty(bar, device=device)))

            # To ensure support for the meta device, avoid using ops except those in
            # torch.nn.init on parameters in your module's constructor.
            with torch.no_grad():
                nn.init.kaiming_uniform_(self.param1)
                nn.init.uniform_(self.param2)

            # ==== Case 2: Module creates submodules. ====
            # Pass device along recursively. All submodules will need to support
            # it as well; this is the case for all torch.nn provided modules.
            self.fc = nn.Linear(bar, 5, device=device)

            # This also works with containers.
            self.linears = nn.Sequential(
                nn.Linear(5, 5, device=device),
                nn.Linear(5, 1, device=device)
            )

            # ==== Case 3: Module creates buffers. ====
            # Pass device along during buffer tensor creation.
            self.register_buffer('some_buffer', torch.ones(7, device=device))

        ...
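
With these requirements met, the custom module above can be used with :func:`torch.nn.utils.skip_init` just
like the built-in modules. A minimal usage sketch (the argument values are arbitrary):

::

    from torch.nn.utils import skip_init

    # Construction skips the initialization work done in __init__.
    m = skip_init(MyModule, 4, 3)

    # Parameters and buffers exist with the correct shapes but hold uninitialized
    # values, so initialize them explicitly before use.
    nn.init.orthogonal_(m.param1)
    nn.init.zeros_(m.param2)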

.. _Details:

Implementation Details
----------------------

Behind the scenes, the :func:`torch.nn.utils.skip_init` function is implemented in terms of a two-step pattern:

::

    # 1. Initialize module on the meta device; all torch.nn.init ops have
    # no-op behavior on the meta device.
    m = nn.Linear(10, 5, device='meta')

    # 2. Materialize an uninitialized (empty) form of the module on the CPU device.
    # The result of this is a module instance with uninitialized parameters.
    m.to_empty(device='cpu')

It works by instantiating the module onto a "meta" device, which has tensor shape information
but does not allocate any storage. The `torch.nn.init` ops are specially implemented for this meta device
so that they have no-op behavior. This results in the parameter initialization logic being essentially skipped.
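
For a concrete feel for the meta device, here is a small sketch (separate from the recipe itself): a meta
tensor carries shape and dtype information but no data, and the `torch.nn.init` ops accept it without doing
any work:

::

    import torch
    from torch import nn

    # A meta tensor has shape/dtype metadata but no backing storage.
    t = torch.empty((10, 5), device='meta')
    print(t.shape)      # torch.Size([10, 5])
    print(t.is_meta)    # True

    # torch.nn.init ops are no-ops on meta tensors, so "initializing" here is free.
    nn.init.kaiming_uniform_(t)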

Note that this pattern only works for modules that properly support a `device` kwarg during construction, as
described in :ref:`Updating`.
