Commit 86e764f: initial stuff
Braden Hoagland committed May 11, 2020
1 parent d4fe2e6
Showing 7 changed files with 315 additions and 0 deletions.
57 changes: 57 additions & 0 deletions data.py
@@ -0,0 +1,57 @@
import numpy as np
import torch
from torch.distributions import MultivariateNormal


##########################
# DISTRIBUTION FUNCTIONS #
##########################
# build a diagonal covariance matrix from a list of per-dimension stds
def set_cov_matrix(stds):
    n = len(stds)
    cov_matrix = torch.zeros((n, n))
    for i in range(n):
        cov_matrix[i][i] = stds[i]**2
    return cov_matrix

def make_normal_dist(mean, std):
    return MultivariateNormal(torch.FloatTensor(mean), set_cov_matrix(std))
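# Example (sketch): a 2-D Gaussian with mean (1, 6) and per-axis stds (2, 3):
#   dist = make_normal_dist([1, 6], [2, 3])
#   dist.sample(torch.Size([4]))  # -> tensor of shape (4, 2)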


#############
# GOOD DATA #
#############
# μ = [[1, 6], [-5, -10]]
# σ = [[2, 3], [4, 2]]
μ = [[1, 6]]
σ = [[2, 3]]
good_dists = [make_normal_dist(mean, std) for mean, std in zip(μ, σ)]


##################
# ANOMALOUS DATA #
##################
μ = [[15, 5]]
σ = [[1, 2]]
bad_dists = [make_normal_dist(mean, std) for mean, std in zip(μ, σ)]


######################
# SAMPLING FUNCTIONS #
######################

# sample from a given array of distributions, splitting the batch evenly among them
def sample_from_dists(batch_size, dists):
    data = []
    n = len(dists)
    for dist in dists:
        data.append(dist.sample(torch.Size([batch_size // n])))
    # stack to (n, batch_size // n, dim), then flatten to (batch_size, dim)
    data = torch.stack(data)
    return data.view(-1, data.shape[2])

# sample from both distributions, with the given probability of choosing bad data
def sample_data(batch_size, bad_data_prob):
    data = sample_from_dists(batch_size, good_dists)
    for i in range(len(data)):
        if np.random.random() < bad_data_prob:
            # take a single anomalous sample and drop its batch dimension
            data[i] = sample_from_dists(1, bad_dists)[0]
    return data
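

# Quick sanity check (a minimal sketch; runs only when this file is executed directly):
if __name__ == '__main__':
    batch = sample_data(8, bad_data_prob=0.5)
    print(batch.shape)  # torch.Size([8, 2])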
88 changes: 88 additions & 0 deletions model.py
@@ -0,0 +1,88 @@
import torch.nn as nn
import torch.optim as optim
from torch import randn

# TODO: remove these
####################
n_in = 2
n_h = 64
n_latent = 2
n_noise = 2
####################

class AutoEncoder(nn.Module):
    def __init__(self, lr):
        super().__init__()

        self.encode = nn.Sequential(
            nn.Linear(n_in, n_h),
            nn.Tanh(),
            nn.Linear(n_h, n_h),
            nn.Tanh(),
            nn.Linear(n_h, n_latent)
        )

        self.decode = nn.Sequential(
            nn.Linear(n_latent, n_h),
            nn.Tanh(),
            nn.Linear(n_h, n_h),
            nn.Tanh(),
            nn.Linear(n_h, n_in)
        )

        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, x):
        return self.decode(self.encode(x))

    def minimize(self, loss):
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()


class Classifier(nn.Module):
    def __init__(self, lr):
        super().__init__()

        self.main = nn.Sequential(
            nn.Linear(n_latent, n_h),
            nn.Tanh(),
            nn.Linear(n_h, n_h),
            nn.Tanh(),
            nn.Linear(n_h, 1),
            nn.Sigmoid()
        )

        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    def forward(self, x):
        return self.main(x)

    # gradient ascent step: descend on the negated objective
    def maximize(self, loss):
        self.optimizer.zero_grad()
        (-loss).backward()
        self.optimizer.step()


class Generator(nn.Module):
    def __init__(self, lr):
        super().__init__()

        self.main = nn.Sequential(
            nn.Linear(n_noise, n_h),
            nn.Tanh(),
            nn.Linear(n_h, n_h),
            nn.Tanh(),
            nn.Linear(n_h, n_in)
        )

        self.optimizer = optim.Adam(self.parameters(), lr=lr)

    # sample fresh noise and map it to data space
    def forward(self, batch_size):
        return self.main(randn(batch_size, n_noise))

    def minimize(self, loss):
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
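

# Shape check (a minimal sketch): each network maps 2-D points through 64-unit MLPs.
if __name__ == '__main__':
    import torch
    ae, clf, gen = AutoEncoder(lr=3e-4), Classifier(lr=3e-4), Generator(lr=3e-4)
    x = torch.zeros(5, n_in)
    print(ae(x).shape)              # torch.Size([5, 2]), reconstruction
    print(clf(ae.encode(x)).shape)  # torch.Size([5, 1]), probability per point
    print(gen(5).shape)             # torch.Size([5, 2]), generated points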
Binary file added model_params/auto_encoder
Binary file added model_params/classifier
Binary file added model_params/generator
123 changes: 123 additions & 0 deletions train.py
@@ -0,0 +1,123 @@
import torch

from data import sample_data
from model import AutoEncoder, Classifier, Generator
from visualize import scatter, line, heatmap


# hyperparameters
bad_data_prob = 0.02
batch_size = 128
lr = 3e-4


def train(network_class, network_name, epochs, train_step, vis=None, use_saved_model=True):
    net = network_class(lr)

    save_path = f'model_params/{network_name}'

    # if a saved model exists, use that
    if use_saved_model:
        try:
            saved_params = torch.load(save_path)
            net.load_state_dict(saved_params)
            return net
        # if not, train a new one and save it
        except FileNotFoundError:
            pass

    # training loop
    for epoch in range(epochs):
        # take an optimization step, and visualize every 100 epochs
        train_step(net)
        if vis is not None and epoch % 100 == 99:
            with torch.no_grad():
                vis(net)

    # save final model
    torch.save(net.state_dict(), save_path)

    return net


# evaluate fn at each integer grid point in [-20, 20)^2 and plot the values as a heatmap
def grid_heatmap(fn, name):
    with torch.no_grad():
        x_range = torch.arange(-20, 20)
        y_range = torch.arange(-20, 20)
        arr = torch.zeros((len(x_range), len(y_range)))
        for i in range(len(x_range)):
            for j in range(len(y_range)):
                x, y = x_range[i], y_range[j]
                arr[i][j] = fn(torch.FloatTensor([x, y]))
        heatmap(arr, name, x_range.tolist(), y_range.tolist())
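# Note: the loop above calls fn once per grid cell (1600 single-point calls). A batched
# sketch, assuming fn maps an (N, 2) batch to N values, would be:
#   pts = torch.cartesian_prod(x_range.float(), y_range.float())
#   arr = fn(pts).view(len(x_range), len(y_range))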


#########################
# AUTO-ENCODER TRAINING #
#########################

def auto_encoder_step(auto_encoder):
    inp = sample_data(batch_size, bad_data_prob)
    out = auto_encoder(inp)

    # minimizing step on MSE loss
    loss = ((out - inp) ** 2).mean()
    auto_encoder.minimize(loss)

    # plot loss
    line(loss.item(), 'AE loss')


######################
# GENERATOR TRAINING #
######################

def generator_step(auto_encoder, generator):
    data = sample_data(batch_size, bad_data_prob)

    # improve generator
    loss = ((data - generator(batch_size))**2).mean()
    generator.minimize(loss)

    # plot objective
    line(loss.item(), 'Generator Loss')


def generator_vis(auto_encoder, generator):
    with torch.no_grad():
        scatter(auto_encoder.decode(generator(500)), 'Generated', color=[255, 0, 0])


#######################
# CLASSIFIER TRAINING #
#######################

def classifier_step(auto_encoder, classifier):
    # sample data and use the AE to encode it (detached so no gradients reach the AE)
    data = auto_encoder.encode(sample_data(batch_size, bad_data_prob)).detach()

    # improve classifier
    obj = torch.log(classifier(data)).mean()
    classifier.maximize(obj)

    # plot objective
    line(obj.item(), 'Classifier Objective')


def classifier_vis(auto_encoder, classifier):
    grid_heatmap(lambda x: classifier(auto_encoder.encode(x)), 'Classification Probabilities')


#################
# FULL PIPELINE #
#################

# plot a large sample of the data (bad points included)
scatter(sample_data(500, bad_data_prob))

# train AE and classifier
auto_encoder = train(AutoEncoder, 'auto_encoder', 1000, auto_encoder_step)
generator = train(Generator, 'generator', 1000, lambda x: generator_step(auto_encoder, x), lambda x: generator_vis(auto_encoder, x), use_saved_model=False)
classifier = train(Classifier, 'classifier', 1000, lambda x: classifier_step(auto_encoder, x), lambda x: classifier_vis(auto_encoder, x))

# testing
grid_heatmap(lambda x: torch.norm(auto_encoder(x) - x, 2), 'AE Error')
grid_heatmap(lambda x: classifier(auto_encoder.encode(x)), 'Classification Probabilities')
47 changes: 47 additions & 0 deletions visualize.py
@@ -0,0 +1,47 @@
import numpy as np
from visdom import Visdom

viz = Visdom()
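# Visdom() connects to a running visdom server (default: http://localhost:8097);
# start one first with `python -m visdom.server`.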


def scatter(points, win='data', color=[0, 0, 0]):
    viz.scatter(
        X=points,
        win=win,
        opts=dict(
            title=win,
            markersize=4,
            markerborderwidth=0,
            markercolor=np.array([color]),
            xtickmin=-20,
            xtickmax=20,
            ytickmin=-20,
            ytickmax=20
        )
    )


line_data = {}
# append a point to the named series and redraw the whole line plot
def line(point, win):
    if win not in line_data:
        line_data[win] = []
    line_data[win].append(point)
    viz.line(
        X=np.arange(len(line_data[win])),
        Y=np.array(line_data[win]),
        win=win,
        opts=dict(
            title=win
        )
    )


def heatmap(points, win, x_labels, y_labels):
    viz.heatmap(
        X=points.numpy().transpose(),
        win=win,
        opts=dict(
            title=win,
            columnnames=x_labels,
            rownames=y_labels
        )
    )
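

# Smoke test (a minimal sketch; requires a running visdom server):
if __name__ == '__main__':
    import torch
    scatter(torch.randn(100, 2) * 5, 'demo scatter')
    for v in [3.0, 2.0, 1.5, 1.2]:
        line(v, 'demo line')
    heatmap(torch.rand(10, 10), 'demo heatmap', list(range(10)), list(range(10)))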
