# test_memory.py (forked from Lightning-AI/pytorch-lightning)
import pytest
import torch
import torch.nn as nn

from pytorch_lightning import LightningModule
from pytorch_lightning.core.memory import UNKNOWN_SIZE, ModelSummary
from tests.base.models import ParityModuleRNN

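# The tests below exercise the summary through two equivalent entry points
# (a sketch of just the calls used in this file, not the full API):
#
#     summary = model.summarize(mode=ModelSummary.MODE_FULL)     # all submodules
#     summary = ModelSummary(model, mode=ModelSummary.MODE_TOP)  # direct children only
#
# Both run a forward pass with `model.example_input_array` through registered
# hooks to record per-layer input and output shapes.
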
class EmptyModule(LightningModule):
    """ A module that has no layers """

    def __init__(self):
        super().__init__()
        self.parameter = torch.rand(3, 3, requires_grad=True)
        self.example_input_array = torch.zeros(1, 2, 3, 4, 5)

    def forward(self, *args, **kwargs):
        return {'loss': self.parameter.sum()}

class UnorderedModel(LightningModule):
    """ A model in which the layers are not defined in order of execution """

    def __init__(self):
        super().__init__()
        # note: the definition order is intentionally scrambled for this test
        self.layer2 = nn.Linear(10, 2)
        self.combine = nn.Linear(7, 9)
        self.layer1 = nn.Linear(3, 5)
        self.relu = nn.ReLU()
        # this layer is unused, therefore its input/output shapes are unknown
        self.unused = nn.Conv2d(1, 1, 1)

        self.example_input_array = (torch.rand(2, 3), torch.rand(2, 10))

    def forward(self, x, y):
        out1 = self.layer1(x)
        out2 = self.layer2(y)
        out = self.relu(torch.cat((out1, out2), 1))
        out = self.combine(out)
        return out

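# Note: all of UnorderedModel's layers are direct children (no nesting), so
# MODE_FULL and MODE_TOP should produce identical rows for it. The summary
# lists layers in *definition* order, while the recorded shapes come from the
# forward pass; the shape tests below rely on exactly that behavior.
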
class MixedDtypeModel(LightningModule):
    """ The parameters and inputs of this model have different dtypes. """

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 20)  # expects dtype long as input
        self.reduce = nn.Linear(20, 1)  # dtype: float
        self.example_input_array = torch.tensor([[0, 2, 1], [3, 5, 3]])  # dtype: long

    def forward(self, x):
        return self.reduce(self.embed(x))

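# Note: the summary is expected to forward `example_input_array` with its
# original (long) dtype rather than casting it to the parameter dtype (float);
# nn.Embedding would reject a float input.
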
@pytest.mark.parametrize(['mode'], [
    pytest.param(ModelSummary.MODE_FULL),
    pytest.param(ModelSummary.MODE_TOP),
])
def test_empty_model_summary_shapes(mode):
    """ Test that the summary works for models that have no submodules. """
    model = EmptyModule()
    summary = model.summarize(mode=mode)
    assert summary.in_sizes == []
    assert summary.out_sizes == []
    assert summary.param_nums == []

@pytest.mark.parametrize(['mode'], [
    pytest.param(ModelSummary.MODE_FULL),
    pytest.param(ModelSummary.MODE_TOP),
])
@pytest.mark.parametrize(['device'], [
    pytest.param(torch.device('cpu')),
    pytest.param(torch.device('cuda', 0)),
])
# the skipif guards both parametrized cases, so the CPU case also only runs
# on a CUDA-capable machine
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Test requires GPU.")
def test_linear_model_summary_shapes(device, mode):
    """ Test that the model summary correctly computes the input- and output shapes. """
    model = UnorderedModel().to(device)
    model.train()
    summary = model.summarize(mode=mode)
    assert summary.in_sizes == [
        [2, 10],  # layer 2
        [2, 7],  # combine
        [2, 3],  # layer 1
        [2, 7],  # relu
        UNKNOWN_SIZE,  # unused
    ]
    assert summary.out_sizes == [
        [2, 2],  # layer 2
        [2, 9],  # combine
        [2, 5],  # layer 1
        [2, 7],  # relu
        UNKNOWN_SIZE,  # unused
    ]
    assert model.training
    assert model.device == device

def test_mixed_dtype_model_summary():
    """ Test that the model summary works with models that have mixed input- and parameter dtypes. """
    model = MixedDtypeModel()
    summary = model.summarize()
    assert summary.in_sizes == [
        [2, 3],  # embed
        [2, 3, 20],  # reduce
    ]
    assert summary.out_sizes == [
        [2, 3, 20],  # embed
        [2, 3, 1],  # reduce
    ]

@pytest.mark.parametrize(['mode'], [
    pytest.param(ModelSummary.MODE_FULL),
    pytest.param(ModelSummary.MODE_TOP),
])
def test_hooks_removed_after_summarize(mode):
    """ Test that all hooks were properly removed after summary, even ones that were not run. """
    model = UnorderedModel()
    summary = ModelSummary(model, mode=mode)
    # hooks should be removed: `hooks_dict_ref()` dereferences the weakref to
    # the module's `_forward_hooks` dict, so an absent handle id means the
    # hook was deregistered
    for _, layer in summary.summarize().items():
        handle = layer._hook_handle
        assert handle.id not in handle.hooks_dict_ref()

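# The RNN expectations below assume ParityModuleRNN wraps a single-layer LSTM:
# its forward output is a tuple (output, (h_n, c_n)), so the summary records a
# nested out_size of [b, t, h] for the output plus [1, b, h] for each of the
# two hidden states.
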
@pytest.mark.parametrize(['mode'], [
    pytest.param(ModelSummary.MODE_FULL),
    pytest.param(ModelSummary.MODE_TOP),
])
def test_rnn_summary_shapes(mode):
    """ Test that the model summary works for RNNs. """
    model = ParityModuleRNN()
    b = 3
    t = 5
    i = model.rnn.input_size
    h = model.rnn.hidden_size
    o = model.linear_out.out_features
    model.example_input_array = torch.zeros(b, t, i)
    summary = model.summarize(mode=mode)
    assert summary.in_sizes == [
        [b, t, i],  # rnn
        [b, t, h],  # linear
    ]
    assert summary.out_sizes == [
        [[b, t, h], [[1, b, h], [1, b, h]]],  # rnn
        [b, t, o],  # linear
    ]

@pytest.mark.parametrize(['mode'], [
    pytest.param(ModelSummary.MODE_FULL),
    pytest.param(ModelSummary.MODE_TOP),
])
def test_summary_parameter_count(mode):
    """ Test that the summary counts the number of parameters in every submodule. """
    model = UnorderedModel()
    summary = model.summarize(mode=mode)
    assert summary.param_nums == [
        model.layer2.weight.numel() + model.layer2.bias.numel(),
        model.combine.weight.numel() + model.combine.bias.numel(),
        model.layer1.weight.numel() + model.layer1.bias.numel(),
        0,  # ReLU
        model.unused.weight.numel() + model.unused.bias.numel(),
    ]

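# Layer types are reported as the bare class names, again following
# UnorderedModel's scrambled definition order: three Linears, the ReLU, then
# the unused Conv2d.
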
@pytest.mark.parametrize(['mode'], [
    pytest.param(ModelSummary.MODE_FULL),
    pytest.param(ModelSummary.MODE_TOP),
])
def test_summary_layer_types(mode):
    """ Test that the summary displays the layer names correctly. """
    model = UnorderedModel()
    summary = model.summarize(mode=mode)
    assert summary.layer_types == [
        'Linear',
        'Linear',
        'Linear',
        'ReLU',
        'Conv2d',
    ]

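# Shape inference for `example_input_array`, as encoded by the expectations
# below: a tensor maps to its size, a tuple/list of tensors to a nested list
# of sizes, a tuple of non-tensors to one UNKNOWN_SIZE per element, and
# scalars, empty containers, and dicts to UNKNOWN_SIZE.
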
@pytest.mark.parametrize(['mode'], [
    pytest.param(ModelSummary.MODE_FULL),
    pytest.param(ModelSummary.MODE_TOP),
])
@pytest.mark.parametrize(['example_input', 'expected_size'], [
    pytest.param([], UNKNOWN_SIZE),
    pytest.param((1, 2, 3), [UNKNOWN_SIZE] * 3),
    pytest.param(torch.tensor(0), UNKNOWN_SIZE),
    pytest.param(dict(tensor=torch.zeros(1, 2, 3)), UNKNOWN_SIZE),
    pytest.param(torch.zeros(2, 3, 4), [2, 3, 4]),
    pytest.param([torch.zeros(2, 3), torch.zeros(4, 5)], [[2, 3], [4, 5]]),
    pytest.param((torch.zeros(2, 3), torch.zeros(4, 5)), [[2, 3], [4, 5]]),
])
def test_example_input_array_types(example_input, expected_size, mode):
    """ Test the types of example inputs supported for display in the summary. """

    class DummyModule(nn.Module):
        def forward(self, *args, **kwargs):
            return None

    class DummyLightningModule(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = DummyModule()

        # this LightningModule and submodule accept any type of input
        def forward(self, *args, **kwargs):
            return self.layer(*args, **kwargs)

    model = DummyLightningModule()
    model.example_input_array = example_input
    summary = model.summarize(mode=mode)
    assert summary.in_sizes == [expected_size]