dae.py
"""Implementation of a Deep Autoencoder"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import *


class DAE(nn.Module):
    """A Deep Autoencoder that takes a list of RBMs as input"""

    def __init__(self, models):
        """Create a deep autoencoder based on a list of RBM models

        Parameters
        ----------
        models: list[RBM]
            a list of RBM models to use for autoencoding
        """
        super(DAE, self).__init__()

        # extract weights from each pretrained RBM
        encoders = []
        encoder_biases = []
        decoders = []
        decoder_biases = []
        for model in models:
            encoders.append(nn.Parameter(model.W.clone()))
            encoder_biases.append(nn.Parameter(model.h_bias.clone()))
            decoders.append(nn.Parameter(model.W.clone()))
            decoder_biases.append(nn.Parameter(model.v_bias.clone()))

        # build encoders and decoders; the decoder stack mirrors the encoder
        # stack in reverse order, initialized from the same RBM weights
        # (which decode() uses transposed)
        self.encoders = nn.ParameterList(encoders)
        self.encoder_biases = nn.ParameterList(encoder_biases)
        self.decoders = nn.ParameterList(reversed(decoders))
        self.decoder_biases = nn.ParameterList(reversed(decoder_biases))

    def forward(self, v):
        """Forward step

        Parameters
        ----------
        v: Tensor
            input tensor

        Returns
        -------
        Tensor
            a reconstruction of v from the autoencoder
        """
        # encode
        p_h = self.encode(v)
        # decode
        p_v = self.decode(p_h)
        return p_v

    def encode(self, v):  # for visualization, encode without sigmoid
        """Encode input

        Parameters
        ----------
        v: Tensor
            visible input tensor

        Returns
        -------
        Tensor
            the activations of the last layer
        """
        p_v = v
        activation = v
        for i in range(len(self.encoders)):
            W = self.encoders[i]
            h_bias = self.encoder_biases[i]
            activation = torch.mm(p_v, W) + h_bias
            p_v = torch.sigmoid(activation)
        # for the last layer, return the activation directly rather than
        # passing it through the sigmoid
        return activation

    def decode(self, h):
        """Decode hidden layer

        Parameters
        ----------
        h: Tensor
            activations from the last hidden layer

        Returns
        -------
        Tensor
            reconstruction of the original input based on h
        """
        p_h = h
        for i in range(len(self.decoders)):
            W = self.decoders[i]
            v_bias = self.decoder_biases[i]
            # decoder weights are the RBM weights applied transposed
            activation = torch.mm(p_h, W.t()) + v_bias
            p_h = torch.sigmoid(activation)
        return p_h
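
# Usage sketch (illustrative, not from the original repo): the DAE constructor
# expects each RBM to expose W (visible-by-hidden weight matrix), h_bias, and
# v_bias tensors. Assuming pretrained RBMs rbm1, rbm2, rbm3 whose layer sizes
# chain together:
#
#     dae = DAE([rbm1, rbm2, rbm3])
#     codes = dae.encode(batch)   # pre-sigmoid activations of the code layer
#     recon = dae(batch)          # reconstruction of batch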


class Naive_DAE(nn.Module):
    """A naive implementation of the DAE to be trained without RBMs"""

    def __init__(self, layers):
        """Initialize the DAE

        Parameters
        ----------
        layers: list[int]
            the number of dimensions in each layer of the DAE
        """
        super(Naive_DAE, self).__init__()
        self.layers = layers
        encoders = []
        decoders = []
        prev_layer = layers[0]
        for layer in layers[1:]:
            encoders.append(
                nn.Linear(in_features=prev_layer, out_features=layer))
            decoders.append(
                nn.Linear(in_features=layer, out_features=prev_layer))
            prev_layer = layer
        self.encoders = nn.ModuleList(encoders)
        self.decoders = nn.ModuleList(reversed(decoders))
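
    # Shape sketch (illustrative): Naive_DAE([784, 256, 64]) builds encoders
    # 784 -> 256 -> 64 and decoders 64 -> 256 -> 784, so the decoder stack
    # mirrors the encoder stack in reverse.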

    def forward(self, x):
        """Forward step

        Parameters
        ----------
        x: Tensor
            input tensor

        Returns
        -------
        Tensor
            a reconstructed version of x
        """
        x_encoded = self.encode(x)
        x_reconstructed = self.decode(x_encoded)
        return x_reconstructed

    def encode(self, x):
        """Encode the input x

        Parameters
        ----------
        x: Tensor
            input to encode

        Returns
        -------
        Tensor
            encoded input
        """
        for i, enc in enumerate(self.encoders):
            if i == len(self.encoders) - 1:
                # keep the code layer linear, matching DAE.encode
                x = enc(x)
            else:
                x = torch.sigmoid(enc(x))
        return x

    def decode(self, x):
        """Decode the representation x

        Parameters
        ----------
        x: Tensor
            input to decode

        Returns
        -------
        Tensor
            decoded input
        """
        for dec in self.decoders:
            x = torch.sigmoid(dec(x))
        return x
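

# Minimal smoke test (a sketch, not part of the original training pipeline):
# checks tensor shapes on random data and runs one illustrative MSE
# optimization step. Layer sizes are assumptions for flattened 28x28 images.
if __name__ == "__main__":
    dae = Naive_DAE([784, 256, 64])
    x = torch.rand(8, 784)     # batch of 8 flattened images
    codes = dae.encode(x)      # (8, 64), linear code layer
    recon = dae(x)             # (8, 784), sigmoid outputs in (0, 1)
    print(codes.shape, recon.shape)

    # one gradient step with reconstruction loss
    optimizer = torch.optim.Adam(dae.parameters(), lr=1e-3)
    loss = F.mse_loss(recon, x)
    loss.backward()
    optimizer.step()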