forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfactory.py
469 lines (383 loc) · 16.9 KB
/
factory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
import torch
from collections import namedtuple
from .cells import lstm_cell, premul_lstm_cell, premul_lstm_cell_no_bias, flat_lstm_cell
# list[list[T]] -> list[T]
def flatten_list(lst):
result = []
for inner in lst:
result.extend(inner)
return result
'''
Define a creator as a function:
(options) -> (inputs, params, forward, backward_setup, backward)
inputs: the inputs to the returned 'forward'. One can call
forward(*inputs) directly.
params: List[Tensor] all requires_grad=True parameters.
forward: function / graph executor / module
One can call rnn(rnn_inputs) using the outputs of the creator.
backward_setup: backward_inputs = backward_setup(*outputs)
Then, we pass backward_inputs to backward. If None, then it is assumed to
be the identity function.
backward: Given `output = backward_setup(*forward(*inputs))`, performs
backpropagation. If None, then nothing happens.
fastrnns.bench times the forward and backward invocations.
'''
ModelDef = namedtuple('ModelDef', [
'inputs', 'params', 'forward', 'backward_setup', 'backward'])
def lstm_backward_setup(lstm_outputs, seed=None):
hx, _ = lstm_outputs
return simple_backward_setup(hx, seed)
def simple_backward_setup(output, seed=None):
assert isinstance(output, torch.Tensor)
if seed:
torch.manual_seed(seed)
grad_output = torch.randn_like(output)
return output, grad_output
def simple_backward(output, grad_output):
return output.backward(grad_output)
def pytorch_lstm_creator(**kwargs):
input, hidden, _, module = lstm_inputs(return_module=True, **kwargs)
return ModelDef(
inputs=[input, hidden],
params=flatten_list(module.all_weights),
forward=module,
backward_setup=lstm_backward_setup,
backward=simple_backward)
def lstm_creator(script=True, **kwargs):
input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
inputs = [input, hidden] + params[0]
return ModelDef(
inputs=inputs,
params=flatten_list(params),
forward=lstm_factory(lstm_cell, script),
backward_setup=lstm_backward_setup,
backward=simple_backward)
def lnlstm_creator(script=True, decompose_layernorm=False, **kwargs):
assert script is True
from .custom_lstms import script_lnlstm
input_size = kwargs['inputSize']
hidden_size = kwargs['hiddenSize']
seq_len = kwargs['seqLength']
batch_size = kwargs['miniBatch']
ge = script_lnlstm(input_size, hidden_size, 1,
decompose_layernorm=decompose_layernorm).cuda()
input = torch.randn(seq_len, batch_size, input_size, device='cuda')
states = [(torch.randn(batch_size, hidden_size, device='cuda'),
torch.randn(batch_size, hidden_size, device='cuda'))]
return ModelDef(
inputs=[input, states],
params=ge.parameters(),
forward=ge,
backward_setup=lstm_backward_setup,
backward=simple_backward)
def dropoutlstm_creator(script=True, **kwargs):
assert script is True
from .custom_lstms import script_lstm, LSTMState
input_size = kwargs['inputSize']
hidden_size = kwargs['hiddenSize']
seq_len = kwargs['seqLength']
batch_size = kwargs['miniBatch']
num_layers = kwargs['numLayers']
ge = script_lstm(input_size, hidden_size, num_layers, dropout=True).cuda()
input = torch.randn(seq_len, batch_size, input_size, device='cuda')
states = [LSTMState(torch.randn(batch_size, hidden_size, device='cuda'),
torch.randn(batch_size, hidden_size, device='cuda'))
for _ in range(num_layers)]
return ModelDef(
inputs=[input, states],
params=ge.parameters(),
forward=ge,
backward_setup=lstm_backward_setup,
backward=simple_backward)
def lstm_premul_creator(script=True, **kwargs):
input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
inputs = [input, hidden] + params[0]
return ModelDef(
inputs=inputs,
params=flatten_list(params),
forward=lstm_factory_premul(premul_lstm_cell, script),
backward_setup=lstm_backward_setup,
backward=simple_backward)
def lstm_premul_bias_creator(script=True, **kwargs):
input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
inputs = [input, hidden] + params[0]
return ModelDef(
inputs=inputs,
params=flatten_list(params),
forward=lstm_factory_premul_bias(premul_lstm_cell_no_bias, script),
backward_setup=lstm_backward_setup,
backward=simple_backward)
def lstm_simple_creator(script=True, **kwargs):
input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
inputs = [input] + [h[0] for h in hidden] + params[0]
return ModelDef(
inputs=inputs,
params=flatten_list(params),
forward=lstm_factory_simple(flat_lstm_cell, script),
backward_setup=lstm_backward_setup,
backward=simple_backward)
def lstm_multilayer_creator(script=True, **kwargs):
input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
inputs = [input, hidden, flatten_list(params)]
return ModelDef(
inputs=inputs,
params=flatten_list(params),
forward=lstm_factory_multilayer(lstm_cell, script),
backward_setup=lstm_backward_setup,
backward=simple_backward)
def imagenet_cnn_creator(arch, jit=True):
def creator(device='cuda', **kwargs):
model = arch().to(device)
x = torch.randn(32, 3, 224, 224, device=device)
if jit:
model = torch.jit.trace(model, x)
return ModelDef(
inputs=(x,),
params=list(model.parameters()),
forward=model,
backward_setup=simple_backward_setup,
backward=simple_backward)
return creator
def varlen_lstm_inputs(minlen=30, maxlen=100,
numLayers=1, inputSize=512, hiddenSize=512,
miniBatch=64, return_module=False, device='cuda',
seed=None, **kwargs):
if seed is not None:
torch.manual_seed(seed)
lengths = torch.randint(
low=minlen, high=maxlen, size=[miniBatch],
dtype=torch.long, device=device)
x = [torch.randn(length, inputSize, device=device)
for length in lengths]
hx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
cx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers).to(device)
if return_module:
return x, lengths, (hx, cx), lstm.all_weights, lstm
else:
# NB: lstm.all_weights format:
# wih, whh, bih, bhh = lstm.all_weights[layer]
return x, lengths, (hx, cx), lstm.all_weights, None
def varlen_lstm_backward_setup(forward_output, seed=None):
if seed:
torch.manual_seed(seed)
rnn_utils = torch.nn.utils.rnn
sequences = forward_output[0]
padded = rnn_utils.pad_sequence(sequences)
grad = torch.randn_like(padded)
return padded, grad
def varlen_pytorch_lstm_creator(**kwargs):
rnn_utils = torch.nn.utils.rnn
sequences, _, hidden, _, module = varlen_lstm_inputs(
return_module=True, **kwargs)
def forward(sequences, hidden):
packed = rnn_utils.pack_sequence(sequences, enforce_sorted=False)
out, new_hidden = module(packed, hidden)
padded, lengths = rnn_utils.pad_packed_sequence(out)
# XXX: It's more efficient to store the output in its padded form,
# but that might not be conducive to loss computation.
# Un-padding the output also makes the backward pass 2x slower...
# return [padded[:lengths[i], i, :] for i in range(lengths.size(0))]
return padded, new_hidden
return ModelDef(
inputs=[sequences, hidden],
params=flatten_list(module.all_weights),
forward=forward,
backward_setup=lstm_backward_setup,
backward=simple_backward)
def varlen_lstm_factory(cell, script):
def dynamic_rnn(sequences, hiddens, wih, whh, bih, bhh):
# type: (List[Tensor], Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]] # noqa
hx, cx = hiddens
hxs = hx.unbind(1)
cxs = cx.unbind(1)
# List of: (output, hx, cx)
outputs = []
hx_outs = []
cx_outs = []
for batch in range(len(sequences)):
output = []
hy, cy = hxs[batch], cxs[batch]
inputs = sequences[batch].unbind(0)
for seq_idx in range(len(inputs)):
hy, cy = cell(
inputs[seq_idx].unsqueeze(0), (hy, cy), wih, whh, bih, bhh)
output += [hy]
outputs += [torch.stack(output)]
hx_outs += [hy.unsqueeze(0)]
cx_outs += [cy.unsqueeze(0)]
return outputs, (hx_outs, cx_outs)
if script:
cell = torch.jit.script(cell)
dynamic_rnn = torch.jit.script(dynamic_rnn)
return dynamic_rnn
def varlen_lstm_creator(script=False, **kwargs):
sequences, _, hidden, params, _ = varlen_lstm_inputs(
return_module=False, **kwargs)
inputs = [sequences, hidden] + params[0]
return ModelDef(
inputs=inputs,
params=flatten_list(params),
forward=varlen_lstm_factory(lstm_cell, script),
backward_setup=varlen_lstm_backward_setup,
backward=simple_backward)
# cudnn_layernorm_lstm: since cudnn does not have Layernorm LSTM, we cannot benchmark
# the lowerbound directly. Instead, we only benchmark the forward pass by mimicing the
# computation of a cudnn lstm + seq_len * 3 layernorm computation. This should serve
# as a perf lowerbound for the Layernorm LSTM forward pass(given that Layernorm itself
# is invariant), the lowerbound of backward pass is hard to get since we lose the
# intermediate results, we can still optimize the layernorm implementation to make
# a faster forward lowerbound though.
def layernorm_pytorch_lstm_creator(**kwargs):
input, hidden, _, module = lstm_inputs(return_module=True, **kwargs)
batch_size = kwargs['miniBatch']
hidden_size = kwargs['hiddenSize']
ln_i = torch.nn.LayerNorm(4 * hidden_size).cuda()
ln_h = torch.nn.LayerNorm(4 * hidden_size).cuda()
ln_c = torch.nn.LayerNorm(hidden_size).cuda()
ln_input1 = torch.randn(batch_size, 4 * hidden_size, device='cuda')
def forward(input, hidden):
out, new_hidden = module(input, hidden)
# plus (seq_len * three laynorm cell computation) to mimic the lower bound of
# Layernorm cudnn LSTM in the forward pass
seq_len = len(input.unbind(0))
hy, cy = new_hidden
for i in range(seq_len):
ln_i_output = ln_i(ln_input1)
ln_h_output = ln_h(ln_input1)
cy = ln_c(cy)
return out, (hy, cy)
return ModelDef(
inputs=[input, hidden],
params=flatten_list(module.all_weights),
forward=forward,
backward_setup=lstm_backward_setup,
backward=None)
# input: lstm.all_weights format (wih, whh, bih, bhh = lstm.all_weights[layer])
# output: packed_weights with format
# packed_weights[0] is wih with size (layer, 4*hiddenSize, inputSize)
# packed_weights[1] is whh with size (layer, 4*hiddenSize, hiddenSize)
# packed_weights[2] is bih with size (layer, 4*hiddenSize)
# packed_weights[3] is bhh with size (layer, 4*hiddenSize)
def stack_weights(weights):
def unzip_columns(mat):
assert isinstance(mat, list)
assert isinstance(mat[0], list)
layers = len(mat)
columns = len(mat[0])
return [[mat[layer][col] for layer in range(layers)]
for col in range(columns)]
# XXX: script fns have problems indexing multidim lists, so we try to
# avoid them by stacking tensors
all_weights = weights
packed_weights = [torch.stack(param)
for param in unzip_columns(all_weights)]
return packed_weights
# returns: x, (hx, cx), all_weights, lstm module with all_weights as params
def lstm_inputs(seqLength=100, numLayers=1, inputSize=512, hiddenSize=512,
miniBatch=64, dropout=0.0, return_module=False, device='cuda', seed=None):
if seed is not None:
torch.manual_seed(seed)
x = torch.randn(seqLength, miniBatch, inputSize, device=device)
hx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
cx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers, dropout=dropout)
if 'cuda' in device:
lstm = lstm.cuda()
if return_module:
return x, (hx, cx), lstm.all_weights, lstm
else:
# NB: lstm.all_weights format:
# wih, whh, bih, bhh = lstm.all_weights[layer]
return x, (hx, cx), lstm.all_weights, None
def lstm_factory(cell, script):
def dynamic_rnn(input, hidden, wih, whh, bih, bhh):
# type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
hx, cx = hidden
outputs = []
inputs = input.unbind(0)
hy, cy = hx[0], cx[0]
for seq_idx in range(len(inputs)):
hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh)
outputs += [hy]
return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))
if script:
cell = torch.jit.script(cell)
dynamic_rnn = torch.jit.script(dynamic_rnn)
return dynamic_rnn
# premul: we're going to premultiply the inputs & weights
def lstm_factory_premul(premul_cell, script):
def dynamic_rnn(input, hidden, wih, whh, bih, bhh):
# type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
hx, cx = hidden
outputs = []
inputs = torch.matmul(input, wih.t()).unbind(0)
hy, cy = hx[0], cx[0]
for seq_idx in range(len(inputs)):
hy, cy = premul_cell(inputs[seq_idx], (hy, cy), whh, bih, bhh)
outputs += [hy]
return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))
if script:
premul_cell = torch.jit.script(premul_cell)
dynamic_rnn = torch.jit.script(dynamic_rnn)
return dynamic_rnn
# premul: we're going to premultiply the inputs & weights, and add bias
def lstm_factory_premul_bias(premul_cell, script):
def dynamic_rnn(input, hidden, wih, whh, bih, bhh):
# type: (Tensor, Tuple[Tensor, Tensor], Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
hx, cx = hidden
outputs = []
inpSize = input.size()
# add bias for all timesteps instead of going step-by-step, results in a single reduction kernel in the backward
# FIXME matmul(x,y) + bias currently goes through jit AD, and backward formula in AD is not optimized for this
# case. Workaround with mm and views.
inpSize = input.size()
inputs = torch.mm(input.view(-1, inpSize[2]), wih.t()) + bih
inputs = inputs.view(inpSize[0], inpSize[1], -1).unbind(0)
hy, cy = hx[0], cx[0]
for seq_idx in range(len(inputs)):
hy, cy = premul_cell(inputs[seq_idx], (hy, cy), whh, bhh)
outputs += [hy]
return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))
if script:
premul_cell = torch.jit.script(premul_cell)
dynamic_rnn = torch.jit.script(dynamic_rnn)
return dynamic_rnn
# simple: flat inputs (no tuples), no list to accumulate outputs
# useful mostly for benchmarking older JIT versions
def lstm_factory_simple(cell, script):
def dynamic_rnn(input, hx, cx, wih, whh, bih, bhh):
hy = hx # for scoping
cy = cx # for scoping
inputs = input.unbind(0)
for seq_idx in range(len(inputs)):
hy, cy = cell(inputs[seq_idx], hy, cy, wih, whh, bih, bhh)
return hy, cy
if script:
cell = torch.jit.script(cell)
dynamic_rnn = torch.jit.script(dynamic_rnn)
return dynamic_rnn
def lstm_factory_multilayer(cell, script):
def dynamic_rnn(input, hidden, params):
# type: (Tensor, Tuple[Tensor, Tensor], List[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
params_stride = 4 # NB: this assumes that biases are there
hx, cx = hidden
hy, cy = hidden # for scoping...
inputs, outputs = input.unbind(0), []
for layer in range(hx.size(0)):
hy = hx[layer]
cy = cx[layer]
base_idx = layer * params_stride
wih = params[base_idx]
whh = params[base_idx + 1]
bih = params[base_idx + 2]
bhh = params[base_idx + 3]
for seq_idx in range(len(inputs)):
hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh)
outputs += [hy]
inputs, outputs = outputs, []
return torch.stack(inputs), (hy.unsqueeze(0), cy.unsqueeze(0))
if script:
cell = torch.jit.script(cell)
dynamic_rnn = torch.jit.script(dynamic_rnn)
return dynamic_rnn