-
Notifications
You must be signed in to change notification settings - Fork 84
/
Copy pathvrae.py
490 lines (377 loc) · 17.6 KB
/
vrae.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
import numpy as np
import torch
from torch import nn, optim
from torch import distributions
from .base import BaseEstimator
from torch.utils.data import DataLoader
from torch.autograd import Variable
import os
class Encoder(nn.Module):
    """
    Encoder network containing an unrolled LSTM/GRU.

    :param number_of_features: number of input features
    :param hidden_size: hidden size of the RNN
    :param hidden_layer_depth: number of layers in RNN
    :param latent_length: latent vector length (stored only; the encoder itself does not use it)
    :param dropout: dropout probability, passed to the RNN's ``dropout`` argument
    :param block: 'LSTM' or 'GRU'
    :raises NotImplementedError: for any other ``block`` value
    """

    def __init__(self, number_of_features, hidden_size, hidden_layer_depth, latent_length, dropout, block='LSTM'):
        super(Encoder, self).__init__()
        self.number_of_features = number_of_features
        self.hidden_size = hidden_size
        self.hidden_layer_depth = hidden_layer_depth
        self.latent_length = latent_length
        if block == 'LSTM':
            self.model = nn.LSTM(self.number_of_features, self.hidden_size, self.hidden_layer_depth, dropout=dropout)
        elif block == 'GRU':
            self.model = nn.GRU(self.number_of_features, self.hidden_size, self.hidden_layer_depth, dropout=dropout)
        else:
            raise NotImplementedError

    def forward(self, x):
        """Forward propagation of encoder. Given input, outputs the last hidden state of the top RNN layer.

        :param x: input to the encoder, of shape (sequence_length, batch_size, number_of_features)
        :return: last hidden state of the top layer, of shape (batch_size, hidden_size)
        """
        _, hidden = self.model(x)
        # nn.LSTM returns its final state as a tuple (h_n, c_n), while nn.GRU
        # returns h_n directly. Unpacking a GRU state as a tuple would split it
        # along the layer axis and break the indexing below.
        h_n = hidden[0] if isinstance(self.model, nn.LSTM) else hidden
        # h_n is (num_layers, batch, hidden); keep only the top layer.
        return h_n[-1, :, :]
class Lambda(nn.Module):
    """Maps the encoder's final hidden state to a latent vector.

    Two linear heads produce the mean and log-variance of the latent
    distribution. Both are cached on the module as ``latent_mean`` and
    ``latent_logvar`` on every call, because the VRAE's loss reads them back
    afterwards.

    :param hidden_size: hidden size of the encoder
    :param latent_length: latent vector length
    """

    def __init__(self, hidden_size, latent_length):
        super().__init__()
        self.hidden_size = hidden_size
        self.latent_length = latent_length

        self.hidden_to_mean = nn.Linear(self.hidden_size, self.latent_length)
        self.hidden_to_logvar = nn.Linear(self.hidden_size, self.latent_length)
        for head in (self.hidden_to_mean, self.hidden_to_logvar):
            nn.init.xavier_uniform_(head.weight)

    def forward(self, cell_output):
        """Project the hidden state to a latent vector.

        In training mode the latent is sampled with the reparameterization
        ``mean + eps * exp(0.5 * logvar)`` where ``eps ~ N(0, 1)``. In eval
        mode the mean is returned directly.

        :param cell_output: last hidden state of encoder
        :return: latent vector
        """
        mean = self.hidden_to_mean(cell_output)
        logvar = self.hidden_to_logvar(cell_output)
        self.latent_mean = mean
        self.latent_logvar = logvar

        if not self.training:
            return mean

        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return noise * sigma + mean
class Decoder(nn.Module):
    """Converts a latent vector back into a sequence.

    The RNN is fed an all-zero input at every time step. The latent vector
    enters only through the initial hidden state, which is copied to every
    layer. For LSTM the initial cell state is zero.

    :param sequence_length: length of the sequence to generate
    :param batch_size: default batch size; kept for backward compatibility,
        ``forward`` accepts latent batches of any size
    :param hidden_size: hidden size of the RNN
    :param hidden_layer_depth: number of layers in RNN
    :param latent_length: latent vector length
    :param output_size: number of features produced at each time step
    :param dtype: tensor type (e.g. ``torch.FloatTensor`` / ``torch.cuda.FloatTensor``)
        used for the pre-built ``decoder_inputs`` / ``c_0`` attributes
    :param block: GRU/LSTM - use the same which you've used in the encoder
    :raises NotImplementedError: for any other ``block`` value
    """

    def __init__(self, sequence_length, batch_size, hidden_size, hidden_layer_depth, latent_length, output_size, dtype, block='LSTM'):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.hidden_layer_depth = hidden_layer_depth
        self.latent_length = latent_length
        self.output_size = output_size
        self.dtype = dtype
        if block == 'LSTM':
            self.model = nn.LSTM(1, self.hidden_size, self.hidden_layer_depth)
        elif block == 'GRU':
            self.model = nn.GRU(1, self.hidden_size, self.hidden_layer_depth)
        else:
            raise NotImplementedError
        self.latent_to_hidden = nn.Linear(self.latent_length, self.hidden_size)
        self.hidden_to_output = nn.Linear(self.hidden_size, self.output_size)
        # Zero constants for the default batch size, kept as public attributes
        # for backward compatibility. forward() builds its own zeros sized to the
        # incoming batch. They are constants, so they do not require gradients.
        self.decoder_inputs = torch.zeros(self.sequence_length, self.batch_size, 1).type(self.dtype)
        self.c_0 = torch.zeros(self.hidden_layer_depth, self.batch_size, self.hidden_size).type(self.dtype)
        nn.init.xavier_uniform_(self.latent_to_hidden.weight)
        nn.init.xavier_uniform_(self.hidden_to_output.weight)

    def forward(self, latent):
        """Converts latent to hidden to output.

        :param latent: latent batch of shape (batch, latent_length); any batch size is accepted
        :return: reconstruction of shape (sequence_length, batch, output_size)
        """
        batch = latent.size(0)
        h_state = self.latent_to_hidden(latent)
        # Same initial hidden state for every RNN layer.
        h_0 = torch.stack([h_state for _ in range(self.hidden_layer_depth)])
        # Zeros are created on the latent's device/dtype and sized to this batch.
        inputs = h_state.new_zeros(self.sequence_length, batch, 1)
        if isinstance(self.model, nn.LSTM):
            c_0 = h_state.new_zeros(self.hidden_layer_depth, batch, self.hidden_size)
            decoder_output, _ = self.model(inputs, (h_0, c_0))
        elif isinstance(self.model, nn.GRU):
            decoder_output, _ = self.model(inputs, h_0)
        else:
            raise NotImplementedError
        return self.hidden_to_output(decoder_output)
def _assert_no_grad(tensor):
assert not tensor.requires_grad, \
"nn criterions don't compute the gradient w.r.t. targets - please " \
"mark these tensors as not requiring gradients"
class VRAE(BaseEstimator, nn.Module):
    """Variational recurrent auto-encoder. This module is used for dimensionality reduction of timeseries

    :param sequence_length: length of the input sequence
    :param number_of_features: number of input features
    :param hidden_size: hidden size of the RNN
    :param hidden_layer_depth: number of layers in RNN
    :param latent_length: latent vector length
    :param batch_size: number of timeseries in a single batch
    :param learning_rate: the learning rate of the module
    :param block: GRU/LSTM to be used as a basic building block
    :param n_epochs: Number of iterations/epochs
    :param dropout_rate: The probability of a node being dropped-out
    :param optimizer: 'Adam' or 'SGD'
    :param loss: 'SmoothL1Loss' or 'MSELoss' (both summed over all elements), or a custom
        loss object (e.g. a subclass of ``torch.nn.modules.loss._Loss``) called as ``loss(x_decoded, x)``
    :param boolean cuda: to be run on GPU or not (falls back to CPU when CUDA is unavailable)
    :param print_every: The number of iterations after which loss should be printed
    :param boolean clip: Gradient clipping to overcome explosion
    :param max_grad_norm: The grad-norm to be clipped
    :param dload: Download directory where models are to be dumped
    :raises ValueError: for an unrecognized optimizer or loss
    """

    def __init__(self, sequence_length, number_of_features, hidden_size=90, hidden_layer_depth=2, latent_length=20,
                 batch_size=32, learning_rate=0.005, block='LSTM',
                 n_epochs=5, dropout_rate=0., optimizer='Adam', loss='MSELoss',
                 cuda=False, print_every=100, clip=True, max_grad_norm=5, dload='.'):
        super(VRAE, self).__init__()

        self.dtype = torch.FloatTensor
        self.use_cuda = cuda
        # Requesting CUDA on a machine without it silently falls back to CPU.
        if not torch.cuda.is_available() and self.use_cuda:
            self.use_cuda = False
        if self.use_cuda:
            self.dtype = torch.cuda.FloatTensor

        self.encoder = Encoder(number_of_features=number_of_features,
                               hidden_size=hidden_size,
                               hidden_layer_depth=hidden_layer_depth,
                               latent_length=latent_length,
                               dropout=dropout_rate,
                               block=block)
        self.lmbd = Lambda(hidden_size=hidden_size,
                           latent_length=latent_length)
        self.decoder = Decoder(sequence_length=sequence_length,
                               batch_size=batch_size,
                               hidden_size=hidden_size,
                               hidden_layer_depth=hidden_layer_depth,
                               latent_length=latent_length,
                               output_size=number_of_features,
                               block=block,
                               dtype=self.dtype)

        self.sequence_length = sequence_length
        self.hidden_size = hidden_size
        self.hidden_layer_depth = hidden_layer_depth
        self.latent_length = latent_length
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.n_epochs = n_epochs
        self.print_every = print_every
        self.clip = clip
        self.max_grad_norm = max_grad_norm
        self.is_fitted = False
        self.dload = dload

        # Move to GPU before building the optimizer so it holds the GPU parameters.
        if self.use_cuda:
            self.cuda()

        if optimizer == 'Adam':
            self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        elif optimizer == 'SGD':
            self.optimizer = optim.SGD(self.parameters(), lr=learning_rate)
        else:
            raise ValueError('Not a recognized optimizer')

        if isinstance(loss, str):
            # reduction='sum' is the non-deprecated spelling of size_average=False.
            if loss == 'SmoothL1Loss':
                self.loss_fn = nn.SmoothL1Loss(reduction='sum')
            elif loss == 'MSELoss':
                self.loss_fn = nn.MSELoss(reduction='sum')
            else:
                raise ValueError('Not a recognized loss')
        elif callable(loss):
            # Custom loss object; a module loss is moved to the GPU alongside the model.
            if self.use_cuda and isinstance(loss, nn.Module):
                loss = loss.cuda()
            self.loss_fn = loss
        else:
            raise ValueError('Not a recognized loss')

    def __repr__(self):
        return """VRAE(n_epochs={n_epochs},batch_size={batch_size},cuda={cuda})""".format(
            n_epochs=self.n_epochs,
            batch_size=self.batch_size,
            cuda=self.use_cuda)

    def forward(self, x):
        """
        Forward propagation which involves one pass from inputs to encoder to lambda to decoder

        :param x: input tensor of shape (sequence_length, batch_size, number_of_features)
        :return: the decoded output, latent vector
        """
        cell_output = self.encoder(x)
        latent = self.lmbd(cell_output)
        x_decoded = self.decoder(latent)
        return x_decoded, latent

    def _rec(self, x_decoded, x, loss_fn):
        """
        Compute the loss given output x decoded, input x and the specified loss function

        Reads the latent mean/log-variance cached on ``self.lmbd`` by the most recent forward pass.

        :param x_decoded: output of the decoder
        :param x: input to the encoder
        :param loss_fn: loss function specified
        :return: joint loss, reconstruction loss and kl-divergence loss
        """
        latent_mean, latent_logvar = self.lmbd.latent_mean, self.lmbd.latent_logvar
        # NOTE(review): the KL term is averaged (torch.mean) while the built-in
        # reconstruction losses are summed, so their relative weight depends on
        # batch and tensor sizes - confirm this weighting is intended.
        kl_loss = -0.5 * torch.mean(1 + latent_logvar - latent_mean.pow(2) - latent_logvar.exp())
        recon_loss = loss_fn(x_decoded, x)
        return kl_loss + recon_loss, recon_loss, kl_loss

    def compute_loss(self, X):
        """
        Given input tensor, forward propagate and compute the loss (the caller runs backward).

        :param X: Input tensor of shape (sequence_length, batch_size, number_of_features)
        :return: total loss, reconstruction loss, kl-divergence loss and the (type-converted) input
        """
        # The input is data, not a parameter: it does not need gradients.
        x = X.type(self.dtype)
        x_decoded, _ = self(x)
        loss, recon_loss, kl_loss = self._rec(x_decoded, x.detach(), self.loss_fn)
        return loss, recon_loss, kl_loss, x

    def _train(self, train_loader):
        """
        Run one epoch over ``train_loader``: one optimizer step per batch.

        :param train_loader: input train loader with shuffle
        :return: average total loss over the epoch's batches (nan if there were none)
        """
        self.train()
        epoch_loss = 0.0
        num_batches = 0
        for t, X in enumerate(train_loader):
            # Index first element of array to return tensor
            X = X[0]
            # required to swap axes, since dataloader gives output in (batch_size x seq_len x num_of_features)
            X = X.permute(1, 0, 2)

            self.optimizer.zero_grad()
            loss, recon_loss, kl_loss, _ = self.compute_loss(X)
            loss.backward()
            if self.clip:
                torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=self.max_grad_norm)

            epoch_loss += loss.item()
            num_batches += 1
            self.optimizer.step()

            if (t + 1) % self.print_every == 0:
                print('Batch %d, loss = %.4f, recon_loss = %.4f, kl_loss = %.4f' % (t + 1, loss.item(),
                                                                                    recon_loss.item(), kl_loss.item()))

        # Divide by the number of batches (not the last 0-based index), and
        # avoid a ZeroDivisionError on an empty epoch.
        average_loss = epoch_loss / num_batches if num_batches else float('nan')
        print('Average loss: {:.4f}'.format(average_loss))
        return average_loss

    def fit(self, dataset, save=False):
        """
        Calls `_train` function over a fixed number of epochs, specified by `n_epochs`

        :param dataset: `Dataset` object whose items are indexable with the series tensor at position 0
        :param bool save: If true, dumps the trained model parameters to `dload`/model.pth
        :return: None
        """
        # Full batches only, so every training step uses the configured batch size.
        train_loader = DataLoader(dataset=dataset,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  drop_last=True)
        for i in range(self.n_epochs):
            print('Epoch: %s' % i)
            self._train(train_loader)

        self.is_fitted = True
        if save:
            self.save('model.pth')

    def _batch_transform(self, x):
        """
        Passes the given input tensor into encoder and lambda function

        :param x: input batch tensor of shape (sequence_length, batch, number_of_features)
        :return: latent vectors as a numpy array of shape (batch, latent_length)
        """
        x = x.type(self.dtype)
        return self.lmbd(self.encoder(x)).detach().cpu().numpy()

    def _batch_reconstruct(self, x):
        """
        Passes the given input tensor into encoder, lambda and decoder function

        :param x: input batch tensor of shape (sequence_length, batch, number_of_features)
        :return: reconstructed output as a numpy array
        """
        x = x.type(self.dtype)
        x_decoded, _ = self(x)
        return x_decoded.detach().cpu().numpy()

    def reconstruct(self, dataset, save=False):
        """
        Given input dataset, creates dataloader, runs dataloader on `_batch_reconstruct`.
        Prerequisite is that model has to be fit. Every sample is reconstructed,
        including a trailing partial batch.

        :param dataset: input dataset whose reconstructions are to be obtained
        :param bool save: If true, dumps the reconstruction to `dload`/x_decoded.pkl
        :return: array of shape (sequence_length, num_samples, number_of_features)
        :raises RuntimeError: if the model has not been fit
        """
        if not self.is_fitted:
            raise RuntimeError('Model needs to be fit')

        self.eval()
        test_loader = DataLoader(dataset=dataset,
                                 batch_size=self.batch_size,
                                 shuffle=False,
                                 drop_last=False)  # Don't shuffle for test_loader
        x_decoded = []
        with torch.no_grad():
            for x in test_loader:
                x = x[0].permute(1, 0, 2)
                n = x.size(1)
                # Pad a trailing partial batch up to batch_size so the decoder
                # always sees the configured batch size, then drop the padded
                # rows. In eval mode rows are processed independently (Lambda
                # returns the mean), so padding does not affect real rows.
                if n < self.batch_size:
                    padding = x.new_zeros(x.size(0), self.batch_size - n, x.size(2))
                    x = torch.cat([x, padding], dim=1)
                x_decoded.append(self._batch_reconstruct(x)[:, :n])

        x_decoded = np.concatenate(x_decoded, axis=1)
        if save:
            os.makedirs(self.dload, exist_ok=True)
            # Separate file so reconstructions do not overwrite transform()'s z_run.pkl.
            x_decoded.dump(os.path.join(self.dload, 'x_decoded.pkl'))
        return x_decoded

    def transform(self, dataset, save=False):
        """
        Given input dataset, creates dataloader, runs dataloader on `_batch_transform`.
        Prerequisite is that model has to be fit. Every sample is encoded,
        including a trailing partial batch.

        :param dataset: input dataset whose latent vectors are to be obtained
        :param bool save: If true, dumps the latent vectors to `dload`/z_run.pkl
        :return: array of shape (num_samples, latent_length)
        :raises RuntimeError: if the model has not been fit
        """
        if not self.is_fitted:
            raise RuntimeError('Model needs to be fit')

        self.eval()
        # The encoder and lambda accept any batch size, so the last partial batch is kept.
        test_loader = DataLoader(dataset=dataset,
                                 batch_size=self.batch_size,
                                 shuffle=False,
                                 drop_last=False)  # Don't shuffle for test_loader
        z_run = []
        with torch.no_grad():
            for x in test_loader:
                x = x[0].permute(1, 0, 2)
                z_run.append(self._batch_transform(x))

        z_run = np.concatenate(z_run, axis=0)
        if save:
            os.makedirs(self.dload, exist_ok=True)
            z_run.dump(os.path.join(self.dload, 'z_run.pkl'))
        return z_run

    def fit_transform(self, dataset, save=False):
        """
        Combines the `fit` and `transform` functions above

        :param dataset: Dataset on which fit and transform have to be performed
        :param bool save: If true, dumps the model and latent vectors as pickle file
        :return: latent vectors for input dataset
        """
        self.fit(dataset, save=save)
        return self.transform(dataset, save=save)

    def save(self, file_name):
        """
        Saves the model's state_dict with ``torch.save``; `dload` serves as the directory
        (created, including parents, if missing).

        :param file_name: the filename to be saved as
        :return: None
        """
        os.makedirs(self.dload, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(self.dload, file_name))

    def load(self, PATH):
        """
        Loads the model's parameters from the path mentioned and marks the model as fitted.

        :param PATH: path to a state_dict saved by `save`
        :return: None
        """
        # NOTE: torch.load unpickles the file - only load checkpoints from trusted sources.
        # A CPU model maps GPU-saved tensors onto the CPU.
        map_location = None if self.use_cuda else 'cpu'
        self.load_state_dict(torch.load(PATH, map_location=map_location))
        # Only mark as fitted once loading actually succeeded.
        self.is_fitted = True