Commit

add support for anneal & switch from cuda to device
ilkhem committed Jun 18, 2019
1 parent 3bb01ea commit 3209d11
Showing 1 changed file with 60 additions and 38 deletions.
lib/models.py: 98 changes (60 additions & 38 deletions)
@@ -6,10 +6,12 @@
 from torch import nn
 from torch.nn import functional as F
 
+
 def weights_init(m):
     if isinstance(m, nn.Linear):
         nn.init.xavier_uniform_(m.weight.data)
 
+
 class Dist:
     def __init__(self):
         pass
@@ -22,12 +24,9 @@ def log_pdf(self, *args, **kwargs):
 
 
 class Normal(Dist):
-    def __init__(self, cuda=False):
+    def __init__(self, device='cpu'):
         super().__init__()
-        if cuda:
-            self.device = 'cuda:0'
-        else:
-            self.device = 'cpu'
+        self.device = device
         self.c = 2 * np.pi * torch.ones(1).to(self.device)
         self._dist = dist.normal.Normal(torch.zeros(1).to(self.device), torch.ones(1).to(self.device))
 
@@ -36,9 +35,13 @@ def sample(self, mu, v):
         scaled = eps.mul(v.sqrt())
         return scaled.add(mu)
 
-    def log_pdf(self, x, mu, v):
+    def log_pdf(self, x, mu, v, reduce=True):
         """compute the log-pdf of a normal distribution with diagonal covariance"""
-        return (-0.5 * (torch.log(self.c) + v.log() + (x - mu).pow(2).div(v))).sum(dim=-1)
+        lpdf = -0.5 * (torch.log(self.c) + v.log() + (x - mu).pow(2).div(v))
+        if reduce:
+            return lpdf.sum(dim=-1)
+        else:
+            return lpdf
 
     def log_pdf_full(self, x, mu, v):
         """
@@ -71,30 +74,32 @@ def _batch_slogdet(self, cov_batch: torch.Tensor):
 
 
 class Laplace(Dist):
-    def __init__(self, cuda=False):
+    def __init__(self, device='cpu'):
         super().__init__()
-        if cuda:
-            self.device = 'cuda:0'
-        else:
-            self.device = 'cpu'
+        self.device = device
         self._dist = dist.laplace.Laplace(torch.zeros(1).to(self.device), torch.ones(1).to(self.device) / np.sqrt(2))
 
     def sample(self, mu, b):
         eps = self._dist.sample(mu.size())
         scaled = eps.mul(b)
         return scaled.add(mu)
 
-    def log_pdf(self, x, mu, b):
+    def log_pdf(self, x, mu, b, reduce=True):
         """compute the log-pdf of a laplace distribution with diagonal covariance"""
-        return (-torch.log(2 * b) - (x - mu).abs().div(b)).sum(dim=-1)
+        lpdf = -torch.log(2 * b) - (x - mu).abs().div(b)
+        if reduce:
+            return lpdf.sum(dim=-1)
+        else:
+            return lpdf
 
 
 class MLP(nn.Module):
-    def __init__(self, input_dim, output_dim, hidden_dim, n_layers, activation='none', slope=.1):
+    def __init__(self, input_dim, output_dim, hidden_dim, n_layers, activation='none', slope=.1, device='cpu'):
         super().__init__()
         self.input_dim = input_dim
         self.output_dim = output_dim
         self.n_layers = n_layers
+        self.device = device
         if isinstance(hidden_dim, Number):
             self.hidden_dim = [hidden_dim] * (self.n_layers - 1)
         elif isinstance(hidden_dim, list):
@@ -130,6 +135,7 @@ def __init__(self, input_dim, output_dim, hidden_dim, n_layers, activation='none
                 _fc_list.append(nn.Linear(self.hidden_dim[i - 1], self.hidden_dim[i]))
             _fc_list.append(nn.Linear(self.hidden_dim[self.n_layers - 2], self.output_dim))
         self.fc = nn.ModuleList(_fc_list)
+        self.to(self.device)
 
     @staticmethod
     def xtanh(x, alpha=.1):
@@ -148,7 +154,7 @@ def forward(self, x):
 
 class iVAE(nn.Module):
     def __init__(self, latent_dim, data_dim, aux_dim, prior=None, decoder=None, encoder=None,
-                 n_layers=3, hidden_dim=50, activation='lrelu', slope=.1, cuda=False):
+                 n_layers=3, hidden_dim=50, activation='lrelu', slope=.1, device='cpu', anneal=False):
         super().__init__()
 
         self.data_dim = data_dim
@@ -158,46 +164,38 @@ def __init__(self, latent_dim, data_dim, aux_dim, prior=None, decoder=None, enco
         self.n_layers = n_layers
         self.activation = activation
         self.slope = slope
+        self.anneal_params = anneal
 
         if prior is None:
-            self.prior_dist = Normal(cuda=cuda)
+            self.prior_dist = Normal(device=device)
         else:
             self.prior_dist = prior
 
         if decoder is None:
-            self.decoder_dist = Normal(cuda=cuda)
+            self.decoder_dist = Normal(device=device)
         else:
             self.decoder_dist = decoder
 
         if encoder is None:
-            self.encoder_dist = Normal(cuda=cuda)
+            self.encoder_dist = Normal(device=device)
         else:
             self.encoder_dist = encoder
 
-        if cuda:
-            device = 'cuda'
-        else:
-            device = 'cpu'
-
         # prior_params
         self.prior_mean = torch.zeros(1).to(device)
-        self.logl = MLP(aux_dim, latent_dim, hidden_dim, n_layers, activation=activation, slope=slope)
+        self.logl = MLP(aux_dim, latent_dim, hidden_dim, n_layers, activation=activation, slope=slope, device=device)
         # decoder params
-        self.f = MLP(latent_dim, data_dim, hidden_dim, n_layers, activation=activation, slope=slope)
+        self.f = MLP(latent_dim, data_dim, hidden_dim, n_layers, activation=activation, slope=slope, device=device)
         self.decoder_var = .01 * torch.ones(1).to(device)
         # encoder params
-        self.g = MLP(data_dim + aux_dim, latent_dim, hidden_dim, n_layers, activation=activation, slope=slope)
-        self.logv = MLP(data_dim + aux_dim, latent_dim, hidden_dim, n_layers, activation=activation, slope=slope)
-
-        if cuda:
-            self.cuda()
+        self.g = MLP(data_dim + aux_dim, latent_dim, hidden_dim, n_layers, activation=activation, slope=slope,
+                     device=device)
+        self.logv = MLP(data_dim + aux_dim, latent_dim, hidden_dim, n_layers, activation=activation, slope=slope,
+                        device=device)
 
         self.apply(weights_init)
 
-        self.a = 1
-        self.b = 1
-        self.c = 1
-        self.d = 1
+        self._training_hyperparams = [1, 1, 1, 1, 1]
 
     def encoder_params(self, x, u):
         xu = torch.cat((x, u), 1)
@@ -221,8 +219,32 @@ def forward(self, x, u):
         return decoder_params, encoder_params, z, prior_params
 
     def elbo(self, x, u):
-        decoder_params, encoder_params, z, prior_params = self.forward(x, u)
+        decoder_params, (g, v), z, prior_params = self.forward(x, u)
         log_px_z = self.decoder_dist.log_pdf(x, *decoder_params)
-        log_qz_xu = self.encoder_dist.log_pdf(z, *encoder_params)
+        log_qz_xu = self.encoder_dist.log_pdf(z, g, v)
         log_pz_u = self.prior_dist.log_pdf(z, *prior_params)
-        return (log_px_z + log_pz_u - log_qz_xu).mean(), z
+
+        if self.anneal_params:
+            a, b, c, d, N = self._training_hyperparams
+            M = z.size(0)
+            log_qz_tmp = self.encoder_dist.log_pdf(z.view(M, 1, self.latent_dim), g.view(1, M, self.latent_dim),
+                                                   v.view(1, M, self.latent_dim), reduce=False)
+            log_qz = torch.logsumexp(log_qz_tmp.sum(dim=-1), dim=1, keepdim=False) - np.log(M * N)
+            log_qz_i = (torch.logsumexp(log_qz_tmp, dim=1, keepdim=False) - np.log(M * N)).sum(dim=-1)
+
+            return (a * log_px_z - b * (log_qz_xu - log_qz) - c * (log_qz - log_qz_i) - d * (
+                    log_qz_i - log_pz_u)).mean(), z
+
+        else:
+            return (log_px_z + log_pz_u - log_qz_xu).mean(), z
+
+    def anneal(self, N, max_iters, iter):
+        thr = int(max_iters / 1.6)
+        a = 0.5 / self.decoder_var.item()
+        self._training_hyperparams[-1] = N
+        self._training_hyperparams[0] = min(2 * a, a + a * iter / thr)
+        self._training_hyperparams[1] = max(1, a * .3 * (1 - iter / thr))
+        self._training_hyperparams[2] = min(1, iter / thr)
+        self._training_hyperparams[3] = max(1, a * .5 * (1 - iter / thr))
+        if iter > thr:
+            self.anneal_params = False
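
For context, a minimal usage sketch (not part of this commit) of how the new device argument and annealing schedule might be driven from a training loop; the toy tensors, dimensions, learning rate, and iteration budget below are illustrative assumptions.

# Hypothetical usage sketch -- not part of the commit. Toy data, dimensions,
# learning rate, and iteration budget are assumptions for illustration only.
import torch
from torch.utils.data import DataLoader, TensorDataset

from lib.models import iVAE

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Toy dataset: 1000 samples of 20-d observations x with 5-d auxiliary variables u.
x = torch.randn(1000, 20)
u = torch.randn(1000, 5)
loader = DataLoader(TensorDataset(x, u), batch_size=64, shuffle=True)

# `device` replaces the old `cuda` flag; `anneal=True` enables the annealed ELBO.
model = iVAE(latent_dim=10, data_dim=20, aux_dim=5, device=device, anneal=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

max_iters, it = 10000, 0
while it < max_iters:
    for xb, ub in loader:
        xb, ub = xb.to(device), ub.to(device)
        # Refresh the annealed weights a, b, c, d (and dataset size N) each step;
        # once it exceeds max_iters / 1.6 the model falls back to the plain ELBO.
        model.anneal(len(loader.dataset), max_iters, it)
        optimizer.zero_grad()
        elbo, z = model.elbo(xb, ub)
        (-elbo).backward()  # maximize the ELBO by minimizing its negative
        optimizer.step()
        it += 1
        if it >= max_iters:
            break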
