alexlstm.py
import torch as th
import torch.nn as nn
from torchvision import models

class AlexLSTM(nn.Module):
    """AlexNet feature extractor feeding a stacked LSTM regressor."""

    def __init__(self, n_layers=3, h_size=1600):
        super(AlexLSTM, self).__init__()
        self.h_size = h_size
        self.n_layers = n_layers
        # Pretrained AlexNet (newer torchvision spells this
        # models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)).
        alexnet = models.alexnet(pretrained=True)
        # Keep only the convolutional trunk. Using .features pins the
        # intended behaviour: on newer torchvision, children()[:-1]
        # would also keep the adaptive average-pool layer.
        self.conv = alexnet.features
        # The LSTM input size (12288) must match the flattened size of
        # the conv features, which depends on the frame resolution.
        self.lstm = nn.LSTM(12288, h_size, dropout=0.3, num_layers=n_layers)
        self.fc = nn.Sequential(
            nn.Linear(h_size, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        # x: (batch_size, 3, timesteps, H, W); the frames must be sized
        # so that the conv features flatten to 12288 per frame.
        batch_size, timesteps = x.size(0), x.size(2)
        state = self._init_state(b_size=batch_size)
        # Run each frame through the conv trunk and flatten it.
        convs = []
        for t in range(timesteps):
            conv = self.conv(x[:, :, t, :, :])
            conv = conv.view(batch_size, -1)
            convs.append(conv)
        # Stack to (timesteps, batch, features), the default nn.LSTM layout.
        convs = th.stack(convs, 0)
        lstm, _ = self.lstm(convs, state)
        # Drop the first timestep's output and score the remaining steps.
        logit = self.fc(lstm[1:])
        logit = logit.transpose(1, 0).squeeze(2)  # (batch_size, seq_len - 1)
        return logit

    def _init_state(self, b_size=1):
        # Small random initial hidden and cell states, created on the same
        # device/dtype as the parameters (the deprecated Variable wrapper
        # is not needed on modern PyTorch).
        weight = next(self.parameters()).data
        return (
            weight.new_empty(self.n_layers, b_size, self.h_size).normal_(0.0, 0.01),
            weight.new_empty(self.n_layers, b_size, self.h_size).normal_(0.0, 0.01),
        )
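

# A minimal smoke test -- a sketch assuming 224x304 frames, a resolution at
# which the AlexNet trunk's features flatten to exactly 12288 (the LSTM input
# size above); the batch size and clip length below are arbitrary.
if __name__ == "__main__":
    model = AlexLSTM(n_layers=3, h_size=1600)
    model.eval()  # disable dropout for the shape check

    # Probe a single frame: the flattened conv feature size must be 12288.
    frame = th.randn(1, 3, 224, 304)  # assumed resolution, not from the original
    with th.no_grad():
        n_feats = model.conv(frame).view(1, -1).size(1)
    print("flattened conv features:", n_feats)  # 12288

    # Dummy clip shaped (batch, channels, timesteps, H, W).
    clip = th.randn(2, 3, 5, 224, 304)
    with th.no_grad():
        out = model(clip)
    print("output shape:", out.shape)  # th.Size([2, 4]) -> (batch, timesteps - 1)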