Commit 86e764f — Braden Hoagland committed on May 11, 2020 (1 parent: d4fe2e6).
Showing 7 changed files with 315 additions and 0 deletions.
data.py
@@ -0,0 +1,57 @@
import numpy as np
import torch
from torch.distributions import MultivariateNormal


##########################
# DISTRIBUTION FUNCTIONS #
##########################

# build a diagonal covariance matrix from per-dimension standard deviations
def set_cov_matrix(stds):
    n = len(stds)
    cov_matrix = torch.zeros((n, n))
    for i in range(n):
        cov_matrix[i][i] = stds[i]**2
    return cov_matrix

def make_normal_dist(mean, std):
    return MultivariateNormal(torch.FloatTensor(mean), set_cov_matrix(std))


#############
# GOOD DATA #
#############
# μ = [[1, 6], [-5, -10]]
# σ = [[2, 3], [4, 2]]
μ = [[1, 6]]
σ = [[2, 3]]
good_dists = [make_normal_dist(mean, std) for mean, std in zip(μ, σ)]


##################
# ANOMALOUS DATA #
##################
μ = [[15, 5]]
σ = [[1, 2]]
bad_dists = [make_normal_dist(mean, std) for mean, std in zip(μ, σ)]


######################
# SAMPLING FUNCTIONS #
######################

# sample a batch from a given list of distributions, split evenly between them
def sample_from_dists(batch_size, dists):
    data = []
    n = len(dists)
    for dist in dists:
        data.append(dist.sample(torch.Size([batch_size // n])))
    data = torch.stack(data)
    return data.view(-1, data.shape[2])

# sample from the good distributions, then replace each point with an anomalous
# sample with probability bad_data_prob
def sample_data(batch_size, bad_data_prob):
    data = sample_from_dists(batch_size, good_dists)
    for i in range(len(data)):
        if np.random.random() < bad_data_prob:
            data[i] = sample_from_dists(1, bad_dists)
    return data
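
A minimal usage sketch for this sampler, assuming the file above is saved as data.py (the `from data import sample_data` import later in this commit suggests that name); the batch size and anomaly probability below are just illustrative:

# sketch: draw a small batch and check shapes and cluster means
import torch
from data import sample_data, sample_from_dists, good_dists, bad_dists

batch = sample_data(128, bad_data_prob=0.02)
print(batch.shape)             # torch.Size([128, 2])

good = sample_from_dists(64, good_dists)
bad = sample_from_dists(64, bad_dists)
print(good.mean(dim=0))        # roughly the good-cluster mean [1, 6]
print(bad.mean(dim=0))         # roughly the anomalous mean [15, 5]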
model.py
@@ -0,0 +1,88 @@
import torch.nn as nn
import torch.optim as optim
from torch import randn

# TODO: remove these
####################
n_in = 2        # input dimension
n_h = 64        # hidden layer width
n_latent = 2    # latent dimension of the auto-encoder
n_noise = 2     # noise dimension fed to the generator
####################

class AutoEncoder(nn.Module):
    def __init__(self, lr):
        super().__init__()

        self.encode = nn.Sequential(
            nn.Linear(n_in, n_h),
            nn.Tanh(),
            nn.Linear(n_h, n_h),
            nn.Tanh(),
            nn.Linear(n_h, n_latent)
        )

        self.decode = nn.Sequential(
            nn.Linear(n_latent, n_h),
            nn.Tanh(),
            nn.Linear(n_h, n_h),
            nn.Tanh(),
            nn.Linear(n_h, n_in)
        )

        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, x):
        return self.decode(self.encode(x))

    def minimize(self, loss):
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()


class Classifier(nn.Module):
    def __init__(self, lr):
        super().__init__()

        self.main = nn.Sequential(
            nn.Linear(n_latent, n_h),
            nn.Tanh(),
            nn.Linear(n_h, n_h),
            nn.Tanh(),
            nn.Linear(n_h, 1),
            nn.Sigmoid()
        )

        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, x):
        return self.main(x)

    # gradient ascent: negate the objective and take a minimizing step
    def maximize(self, loss):
        self.optimizer.zero_grad()
        (-loss).backward()
        self.optimizer.step()


class Generator(nn.Module):
    def __init__(self, lr):
        super().__init__()

        self.main = nn.Sequential(
            nn.Linear(n_noise, n_h),
            nn.Tanh(),
            nn.Linear(n_h, n_h),
            nn.Tanh(),
            nn.Linear(n_h, n_in)
        )

        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, batch_size):
        return self.main(randn(batch_size, n_noise))

    def minimize(self, loss):
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
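
A minimal sketch of how these three modules fit together on a dummy batch; the class names and methods come from the file above, while the learning rate and batch size are illustrative assumptions (the training script below happens to use lr = 3e-4):

# sketch: wire the networks together and take one auto-encoder step
import torch
from model import AutoEncoder, Classifier, Generator

lr = 3e-4                                # assumed learning rate
ae, clf, gen = AutoEncoder(lr), Classifier(lr), Generator(lr)

x = torch.randn(16, 2)                   # dummy 2-D input batch
recon = ae(x)                            # (16, 2) reconstruction
probs = clf(ae.encode(x))                # (16, 1) probabilities in (0, 1)
fake = gen(16)                           # (16, 2) generated points

ae.minimize(((recon - x) ** 2).mean())   # one gradient step on MSE loss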
3 binary files not shown (likely the saved model_params/ checkpoints).
main training script
@@ -0,0 +1,123 @@
import torch

from data import sample_data
from model import AutoEncoder, Classifier, Generator
from visualize import scatter, line, heatmap


# hyperparameters
bad_data_prob = 0.02
batch_size = 128
lr = 3e-4


def train(network_class, network_name, epochs, train_step, vis=None, use_saved_model=True):
    net = network_class(lr)

    save_path = f'model_params/{network_name}'

    # if a saved model exists, use that
    if use_saved_model:
        try:
            saved_params = torch.load(save_path)
            net.load_state_dict(saved_params)
            return net
        # if not, train a new one and save it
        except FileNotFoundError:
            pass

    # training loop
    for epoch in range(epochs):
        # take an optimization step and visualize every 100 epochs
        train_step(net)
        if vis is not None and epoch % 100 == 99:
            with torch.no_grad(): vis(net)

    # save final model
    torch.save(net.state_dict(), save_path)

    return net


# evaluate fn on a grid over [-20, 20) x [-20, 20) and plot the result as a heatmap
def map(fn, name):
    with torch.no_grad():
        x_range = torch.arange(-20, 20)
        y_range = torch.arange(-20, 20)
        arr = torch.zeros((len(x_range), len(y_range)))
        for i in range(len(x_range)):
            for j in range(len(y_range)):
                x, y = x_range[i], y_range[j]
                arr[i][j] = fn(torch.FloatTensor([x, y]))
        heatmap(arr, name, x_range.tolist(), y_range.tolist())


#########################
# AUTO-ENCODER TRAINING #
#########################

def auto_encoder_step(auto_encoder):
    inp = sample_data(batch_size, bad_data_prob)
    out = auto_encoder(inp)

    # minimizing step on MSE reconstruction loss
    loss = ((out - inp) ** 2).mean()
    auto_encoder.minimize(loss)

    # plot loss
    line(loss.item(), 'AE loss')


######################
# GENERATOR TRAINING #
######################

def generator_step(auto_encoder, generator):
    data = sample_data(batch_size, bad_data_prob)

    # improve generator
    loss = ((data - generator(batch_size))**2).mean()
    generator.minimize(loss)

    # plot objective
    line(loss.item(), 'Generator Loss')


def generator_vis(auto_encoder, generator):
    with torch.no_grad(): scatter(auto_encoder.decode(generator(500)), 'Generated', color=[255, 0, 0])


#######################
# CLASSIFIER TRAINING #
#######################

def classifier_step(auto_encoder, classifier):
    # sample data and use the AE to encode it
    data = auto_encoder.encode(sample_data(batch_size, bad_data_prob))

    # improve classifier
    obj = torch.log(classifier(data)).mean()
    classifier.maximize(obj)

    # plot objective
    line(obj.item(), 'Classifier Objective')


def classifier_vis(auto_encoder, classifier):
    map(lambda x: classifier(auto_encoder.encode(x)), 'Classification Probabilities')


#################
# FULL PIPELINE #
#################

# plot a large sample of the data (bad points included)
scatter(sample_data(500, bad_data_prob))

# train AE, generator, and classifier
auto_encoder = train(AutoEncoder, 'auto_encoder', 1000, auto_encoder_step)
generator = train(Generator, 'generator', 1000, lambda x: generator_step(auto_encoder, x), lambda x: generator_vis(auto_encoder, x), use_saved_model=False)
classifier = train(Classifier, 'classifier', 1000, lambda x: classifier_step(auto_encoder, x), lambda x: classifier_vis(auto_encoder, x))

# testing: AE reconstruction error and classifier output over the input grid
map(lambda x: torch.norm(auto_encoder(x) - x, 2), 'AE Error')
map(lambda x: classifier(auto_encoder.encode(x)), 'Classification Probabilities')
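
Once the auto-encoder is trained, its reconstruction error can double as a simple per-point anomaly score. A minimal sketch that could be appended to the script above; the mean-plus-two-standard-deviations cutoff is an arbitrary illustrative choice, not something fixed by this commit:

# sketch: flag points whose reconstruction error is unusually large
with torch.no_grad():
    test_points = sample_data(500, bad_data_prob)
    errors = torch.norm(auto_encoder(test_points) - test_points, dim=1)
    threshold = errors.mean() + 2 * errors.std()     # assumed cutoff
    flagged = test_points[errors > threshold]
    print(f'flagged {len(flagged)} of {len(test_points)} points as anomalous')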
visualize.py
@@ -0,0 +1,47 @@
import numpy as np
from visdom import Visdom

viz = Visdom()


def scatter(points, win='data', color=[0, 0, 0]):
    viz.scatter(
        X=points,
        win=win,
        opts=dict(
            title=win,
            markersize=4,
            markerborderwidth=0,
            markercolor=np.array([color]),
            xtickmin=-20,
            xtickmax=20,
            ytickmin=-20,
            ytickmax=20
        )
    )


# accumulate a history of values per window and redraw the full curve on each call
line_data = {}
def line(point, win):
    if win not in line_data: line_data[win] = []
    line_data[win].append(point)
    viz.line(
        X=np.arange(len(line_data[win])),
        Y=np.array(line_data[win]),
        win=win,
        opts=dict(
            title=win
        )
    )


def heatmap(points, win, x_labels, y_labels):
    viz.heatmap(
        X=points.numpy().transpose(),
        win=win,
        opts=dict(
            title=win,
            columnnames=x_labels,
            rownames=y_labels
        )
    )
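
These helpers assume a Visdom server is already running locally (http://localhost:8097 by default). A minimal sketch of the setup and a couple of calls; the window names and dummy data are illustrative assumptions:

# sketch: start the dashboard, then push plots to it
#   $ pip install visdom
#   $ python -m visdom.server
import torch
from visualize import scatter, line

scatter(torch.randn(100, 2) * 3, win='demo')     # 100 random 2-D points
for step in range(10):
    line(float(step) ** 0.5, win='demo loss')    # builds up a 10-point curve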