forked from utiasSTARS/so3_learning
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added 1D uncertainty with updated hydranet-sigma
- Loading branch information
Showing
19 changed files
with
1,485 additions
and
24 deletions.
There are no files selected for viewing
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# deep-uncertainty |
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import pandas as pd | ||
import os | ||
|
||
f = pd.read_csv('figs/stats.csv', mangle_dupe_cols=True) | ||
#dropout_nll = np.asarray(f['Dropout+AC0-NLL']).reshape((-1,1)) | ||
#dropout_mse = np.asarray(f['Dropout+AC0-MSE']).reshape((-1,1)) | ||
#ensemble_nll = np.asarray(f['Ensemble+AC0-NLL']).reshape((-1,1)) | ||
#ensemble_mse = np.asarray(f['Ensemble+AC0-MSE']).reshape((-1,1)) | ||
#sigma_nll = np.asarray(f['Sigma+AC0-NLL']).reshape((-1,1)) | ||
#sigma_mse = np.asarray(f['Sigma+AC0-MSE']).reshape((-1,1)) | ||
#hydranet_nll = np.asarray(f['HydraNet+AC0-NLL']).reshape((-1,1)) | ||
#hydranet_mse = np.asarray(f['HydraNet+AC0-MSE']).reshape((-1,1)) | ||
#hydra_sigma_nll = np.asarray(f['HydraNet+AC0-Sigma+AC0-NLL']).reshape((-1,1)) | ||
#hydra_sigma_mse = np.asarray(f['HydraNet+AC0-Sigma+AC0-MSE']).reshape((-1,1)) | ||
dropout_nll = np.asarray(f['Dropout-NLL']).reshape((-1,1)) | ||
dropout_mse = np.asarray(f['Dropout-MSE']).reshape((-1,1)) | ||
ensemble_nll = np.asarray(f['Ensemble-NLL']).reshape((-1,1)) | ||
ensemble_mse = np.asarray(f['Ensemble-MSE']).reshape((-1,1)) | ||
sigma_nll = np.asarray(f['Sigma-NLL']).reshape((-1,1)) | ||
sigma_mse = np.asarray(f['Sigma-MSE']).reshape((-1,1)) | ||
hydranet_nll = np.asarray(f['HydraNet-NLL']).reshape((-1,1)) | ||
hydranet_mse = np.asarray(f['HydraNet-MSE']).reshape((-1,1)) | ||
hydra_sigma_nll = np.asarray(f['HydraNet-Sigma-NLL']).reshape((-1,1)) | ||
hydra_sigma_mse = np.asarray(f['HydraNet-Sigma-MSE']).reshape((-1,1)) | ||
|
||
nll_losses = np.hstack((dropout_nll, ensemble_nll, sigma_nll, hydranet_nll, hydra_sigma_nll)) | ||
plt.figure() | ||
plt.clf() | ||
plt.rc('text', usetex=True) | ||
plt.rc('font', family='serif') | ||
plt.boxplot(nll_losses) | ||
#plt.ylim(0,100) | ||
plt.title('NLL') | ||
plt.grid() | ||
plt.xticks([1, 2, 3, 4, 5], ['Dropout', 'Ensemble', 'Sigma', 'HydraNet', 'HydraNet-Sigma']) | ||
plt.xticks(rotation=20) | ||
plt.yscale('log') | ||
plt.savefig('figs/uncertainty-NLL.pdf', format='pdf', dpi=800, bbox_inches='tight') | ||
|
||
mse_losses = np.hstack((dropout_mse, ensemble_mse, sigma_mse, hydranet_mse, hydra_sigma_mse)) | ||
plt.figure() | ||
plt.clf() | ||
plt.rc('text', usetex=True) | ||
plt.rc('font', family='serif') | ||
plt.boxplot(mse_losses) | ||
#plt.ylim(0,100) | ||
plt.title('MSE') | ||
plt.grid() | ||
plt.xticks([1, 2, 3, 4, 5], ['Dropout', 'Ensemble', 'Sigma', 'HydraNet', 'HydraNet-Sigma']) | ||
plt.xticks(rotation=20) | ||
plt.yscale('log') | ||
plt.savefig('figs/uncertainty-MSE.pdf', format='pdf', dpi=800, bbox_inches='tight') |
Large diffs are not rendered by default.
Oops, something went wrong.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Rep,Dropout-NLL,Dropout-MSE,Ensemble-NLL,Ensemble-MSE,Sigma-NLL,Sigma-MSE,HydraNet-NLL,HydraNet-MSE,HydraNet-Sigma-NLL,HydraNet-Sigma-MSE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
Rep,Dropout-NLL,Dropout-MSE,Ensemble-NLL,Ensemble-MSE,Sigma-NLL,Sigma-MSE,HydraNet-NLL,HydraNet-MSE,HydraNet-Sigma-NLL,HydraNet-Sigma-MSE | ||
0,379.7791515456072,4.620935588897405,2.1391301208346998,1.4016311981774348,7.448746292007388,6.689223514927125,70.88247250307367,4.503404974706821,9.618296405322212,5.888777043944422 | ||
1,414.86480524092303,5.067042160184711,2.3885199275524176,1.3329222645449306,8.221807982834287,7.544757217181355,40.80861204353162,4.42987879318454,10.02975610951294,6.066285856747918 | ||
2,775.5840755283302,4.4648964113513125,2.4787516778433196,3.9383119338136017,7.531148093246316,6.2334702965074005,41.43443568790157,6.668626792292481,9.768880647697896,6.12391266805442 | ||
3,309.3894760368025,4.00854234131287,1.955213270678324,2.074559781829711,7.787646360110071,7.497382485824534,67.26920002442839,5.480905873724492,9.482844276699177,7.254852569705431 | ||
4,133.02373369107522,3.9481237012595103,1.934146068195844,1.669322863773667,7.963543928391613,7.277166111592238,65.62253835829712,4.496994558106445,7.4329653641830475,6.195428699395999 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
Rep,Dropout-NLL,Dropout-MSE,Ensemble-NLL,Ensemble-MSE,Sigma-NLL,Sigma-MSE,HydraNet-NLL,HydraNet-MSE,HydraNet-Sigma-NLL,HydraNet-Sigma-MSE | ||
0,365.7113729831465,4.372538431306081,1.768381719532294,2.309441149243982,7.5900531989858,7.053998384499737,36.88368071874429,8.88816725914986,11.103690824541212,5.930373454560031 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import numpy as np | ||
import torch | ||
|
||
#Setup and train a simple neural network | ||
def build_NN(dropout_p, num_outputs): | ||
if dropout_p > 0: | ||
NN = torch.nn.Sequential( | ||
torch.nn.Linear(1, 20), | ||
torch.nn.SELU(), | ||
torch.nn.AlphaDropout(p=dropout_p), | ||
torch.nn.Linear(20, 20), | ||
torch.nn.SELU(), | ||
torch.nn.AlphaDropout(p=dropout_p), | ||
torch.nn.Linear(20, 20), | ||
torch.nn.SELU(), | ||
torch.nn.AlphaDropout(p=dropout_p), | ||
torch.nn.Linear(20, num_outputs), | ||
) | ||
else: | ||
NN = torch.nn.Sequential( | ||
torch.nn.Linear(1, 20), | ||
torch.nn.SELU(), | ||
torch.nn.Linear(20, 20), | ||
torch.nn.SELU(), | ||
torch.nn.Linear(20, 20), | ||
torch.nn.SELU(), | ||
torch.nn.Linear(20, num_outputs) | ||
) | ||
return NN | ||
|
||
#NN with multiple heads | ||
def build_hydra(num_heads, num_outputs=1, direct_variance_head=False): | ||
class HydraHead(torch.nn.Module): | ||
def __init__(self, n_o): | ||
super(HydraHead, self).__init__() | ||
|
||
self.head_net = torch.nn.Sequential( | ||
torch.nn.Linear(20, 20), | ||
torch.nn.SELU(), | ||
torch.nn.Linear(20, n_o)) | ||
|
||
def forward(self, x): | ||
return self.head_net(x) | ||
|
||
class HydraNet(torch.nn.Module): | ||
def __init__(self, num_heads, num_outputs, direct_variance_head=False): | ||
super(HydraNet, self).__init__() | ||
self.shared_net = torch.nn.Sequential( | ||
torch.nn.Linear(1, 20), | ||
torch.nn.SELU(), | ||
torch.nn.Linear(20, 20), | ||
torch.nn.SELU() | ||
) | ||
#Initialize the heads | ||
self.num_heads = num_heads | ||
self.num_outputs = num_outputs | ||
self.heads = torch.nn.ModuleList([HydraHead(n_o=num_outputs) for h in range(num_heads)]) | ||
|
||
if direct_variance_head: | ||
self.direct_variance_head = HydraHead(n_o=1) | ||
else: | ||
self.direct_variance_head = None | ||
|
||
def forward(self, x): | ||
y = self.shared_net(x) | ||
y_out = [head_net(y) for head_net in self.heads] | ||
|
||
#Append the direct variance to the end of the heads | ||
if self.direct_variance_head is not None: | ||
y_out.append(self.direct_variance_head(y)) | ||
|
||
return torch.cat(y_out, 1) | ||
|
||
net = HydraNet(num_heads, num_outputs, direct_variance_head) | ||
return net | ||
|
||
|
||
#NLL loss for single-headed NN | ||
class GaussianLoss(torch.nn.Module): | ||
def __init__(self): | ||
super(GaussianLoss, self).__init__() | ||
|
||
#Based on negative log of normal distribution | ||
def forward(self, input, target): | ||
mean = input[:, 0] | ||
sigma2 = torch.log(1. + torch.exp(input[:, 1])) + 1e-6 | ||
#sigma2 = torch.nn.functional.softplus(input[:, 1]) + 1e-4 | ||
loss = torch.mean(0.5*(mean - target.squeeze())*((mean - target.squeeze())/sigma2) + 0.5*torch.log(sigma2)) | ||
return loss | ||
|
||
#NLL loss for HydraNet | ||
class GaussianHydraLoss(torch.nn.Module): | ||
def __init__(self): | ||
super(GaussianHydraLoss, self).__init__() | ||
|
||
#Based on negative log of normal distribution | ||
def forward(self, input, target): | ||
|
||
mean = input[:, :-1] | ||
sigma2 = torch.log(1. + torch.exp(input[:, [-1]])) + 1e-6 #torch.abs(input[:, [-1]]) + 1e-6 | ||
|
||
sigma2 = sigma2.repeat([1, mean.shape[1]]) | ||
#sigma2 = torch.nn.functional.softplus(input[:, :, 1]) + 1e-4 | ||
loss = 0.5*(mean - target)*((mean - target)/sigma2) + 0.5*torch.log(sigma2) | ||
return loss.mean() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
import numpy as np | ||
import torch | ||
from torch import optim | ||
import matplotlib | ||
matplotlib.use('Agg') | ||
import matplotlib.pyplot as plt | ||
from train_and_test import * | ||
import time | ||
import csv | ||
from visualize import * | ||
import os | ||
|
||
os.environ['OMP_NUM_THREADS'] = '4' | ||
torch.set_num_threads(4) | ||
|
||
def gen_train_data(N=100): | ||
mu, sigma = 0, 0.03 | ||
alpha, beta = 4, 13 | ||
w = np.random.normal(mu, sigma, N) | ||
x = np.concatenate((np.random.uniform(0,.6,int(N/2)), np.random.uniform(0.8,1.0,int(N/2)))) | ||
y = x + np.sin(alpha*(x+w)) + np.sin(beta*(x+w)) + w | ||
#y = x**3 + w | ||
return (x, y) | ||
|
||
def gen_test_data(N=100): | ||
x = np.linspace(-2,2, N) | ||
alpha, beta = 4, 13 | ||
y = x + np.sin(alpha*x) + np.sin(beta*x) | ||
#y = x**3 | ||
return (x, y) | ||
|
||
def main(): | ||
num_reps = 100 | ||
minibatch_samples = 50 | ||
train_samples = 1000 | ||
test_samples =100 | ||
num_epochs = 3000 | ||
target_noise_sigma = 0.05 | ||
|
||
use_cuda = False | ||
|
||
|
||
stats_list = [] | ||
csv_header = ["Rep","Dropout-NLL", "Dropout-MSE", | ||
"Ensemble-NLL", "Ensemble-MSE", | ||
"Sigma-NLL", "Sigma-MSE", | ||
"HydraNet-NLL", "HydraNet-MSE", | ||
"HydraNet-Sigma-NLL", "HydraNet-Sigma-MSE"] | ||
|
||
for rep in range(num_reps): | ||
|
||
print('Performing repetition {}/{}...'.format(rep+1, num_reps)) | ||
|
||
(x_train, y_train) = gen_train_data(train_samples) | ||
(x_test, y_test) = gen_test_data(test_samples) | ||
exp_data = ExperimentalData(x_train, y_train, x_test, y_test) | ||
|
||
|
||
visualize_data_only(x_train, y_train, x_test, y_test, filename='data.pdf') | ||
#return | ||
|
||
print('Starting dropout training') | ||
start = time.time() | ||
# (dropout_model, _) = train_nn_dropout(x_train, y_train, num_epochs=num_epochs, use_cuda=use_cuda) | ||
(dropout_model, _) = train_nn_dropout(x_train, y_train, minibatch_samples, num_epochs=num_epochs, use_cuda=use_cuda) | ||
(y_pred_dropout, sigma_pred_dropout) = test_nn_dropout(x_test, dropout_model, use_cuda=use_cuda) | ||
nll_dropout = compute_nll(y_test, y_pred_dropout, sigma_pred_dropout) | ||
mse_dropout = compute_mse(y_test, y_pred_dropout) | ||
print('Dropout, NLL: {:.3f} | MSE: {:.3f}'.format(nll_dropout, mse_dropout)) | ||
visualize(x_train, y_train, x_test, y_test, y_pred_dropout, sigma_pred_dropout, nll_dropout, mse_dropout, rep,'dropout_{}.png'.format(rep)) | ||
end = time.time() | ||
print('Completed in {:.3f} seconds.'.format(end - start)) | ||
# | ||
# | ||
print('Starting ensemble training') | ||
|
||
start = time.time() | ||
ensemble_models = train_nn_ensemble_bootstrap(x_train, y_train, minibatch_samples, num_epochs=num_epochs, use_cuda=use_cuda, target_noise_sigma=target_noise_sigma) | ||
(y_pred_bs, sigma_pred_bs) = test_nn_ensemble_bootstrap(x_test, ensemble_models, use_cuda=use_cuda) | ||
nll_bs = compute_nll(y_test, y_pred_bs, sigma_pred_bs) | ||
mse_bs = compute_mse(y_test, y_pred_bs) | ||
print('Ensemble BS, NLL: {:.3f} | MSE: {:.3f}'.format(nll_bs, mse_bs)) | ||
visualize(x_train, y_train, x_test, y_test, y_pred_bs, sigma_pred_bs, nll_bs, mse_bs, rep,'ensemble_{}.png'.format(rep)) | ||
end = time.time() | ||
print('Completed in {:.3f} seconds.'.format(end - start)) | ||
# | ||
# | ||
print('Starting sigma training') | ||
|
||
start = time.time() | ||
(sigma_model, _) = train_nn_sigma(exp_data, minibatch_samples, num_epochs=num_epochs, use_cuda=use_cuda) | ||
(y_pred_sigma, sigma_pred_sigma) = test_nn_sigma(x_test, sigma_model, use_cuda=use_cuda) | ||
nll_sigma = compute_nll(y_test, y_pred_sigma, sigma_pred_sigma) | ||
mse_sigma = compute_mse(y_test, y_pred_sigma) | ||
print('Sigma, NLL: {:.3f} | MSE: {:.3f}'.format(nll_sigma, mse_sigma)) | ||
visualize(x_train, y_train, x_test, y_test, y_pred_sigma, sigma_pred_sigma, nll_sigma, mse_sigma, rep,'sigma_{}.png'.format(rep)) | ||
end = time.time() | ||
print('Completed in {:.3f} seconds.'.format(end - start)) | ||
|
||
print('Starting hydranet training with target_noise_sigma={}'.format(target_noise_sigma)) | ||
# | ||
start = time.time() | ||
(hydranet_model, _) = train_hydranet(exp_data, minibatch_samples, num_heads=10, num_epochs=num_epochs, use_cuda=use_cuda, target_noise_sigma=target_noise_sigma) | ||
(y_pred_hydranet, sigma_pred_hydranet) = test_hydranet(x_test, hydranet_model, use_cuda=use_cuda) | ||
nll_hydranet = compute_nll(y_test, y_pred_hydranet, sigma_pred_hydranet) | ||
mse_hydranet = compute_mse(y_test, y_pred_hydranet) | ||
print('HydraNet, NLL: {:.3f} | MSE: {:.3f}'.format(nll_hydranet, mse_hydranet)) | ||
visualize(x_train, y_train, x_test, y_test, y_pred_hydranet, sigma_pred_hydranet, nll_hydranet, mse_hydranet, rep,'hydranet_{}.png'.format(rep)) | ||
end = time.time() | ||
print('Completed in {:.3f} seconds.'.format(end - start)) | ||
|
||
print('Starting hydranet-sigma training') | ||
# | ||
start = time.time() | ||
(hydranetsigma_model, _) = train_hydranet_sigma(exp_data, minibatch_samples, num_heads=10, num_epochs=num_epochs, use_cuda=use_cuda, target_noise_sigma=target_noise_sigma) | ||
(y_pred_hydranetsigma, sigma_pred_hydranetsigma) = test_hydranet_sigma(x_test, hydranetsigma_model, use_cuda=use_cuda) | ||
nll_hydranetsigma = compute_nll(y_test, y_pred_hydranetsigma, sigma_pred_hydranetsigma) | ||
mse_hydranetsigma = compute_mse(y_test, y_pred_hydranetsigma) | ||
|
||
print('HydraNet-Sigma, NLL: {:.3f} | MSE: {:.3f}'.format(nll_hydranetsigma, mse_hydranetsigma)) | ||
visualize(x_train, y_train, x_test, y_test, y_pred_hydranetsigma, sigma_pred_hydranetsigma, nll_hydranetsigma, mse_hydranetsigma, rep,'hydranetsigma_{}.png'.format(rep)) | ||
end = time.time() | ||
print('Completed in {:.3f} seconds.'.format(end - start)) | ||
|
||
stats_list.append([rep, nll_dropout, mse_dropout, nll_bs, mse_bs, nll_sigma, mse_sigma, nll_hydranet, mse_hydranet, nll_hydranetsigma, mse_hydranetsigma]) | ||
|
||
|
||
csv_filename = 'figs/stats_target_noise_sigma_{}.csv'.format(target_noise_sigma) | ||
with open(csv_filename, "w") as f: | ||
writer = csv.writer(f) | ||
writer.writerow(csv_header) | ||
writer.writerows(stats_list) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Oops, something went wrong.