Commit 0caa8bb: upload code
seunghyukcho committed Oct 12, 2022
1 parent 949cb62 commit 0caa8bb
Showing 46 changed files with 1,697 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -127,3 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/

data/
wandb/
33 changes: 33 additions & 0 deletions arguments.py
@@ -0,0 +1,33 @@
import argparse


tasks = ['Breakout', 'NSBT']
distributions = ['EuclideanNormal', 'IsotropicHWN', 'DiagonalHWN', 'FullHWN', 'RoWN']


def get_initial_parser():
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--task', type=str, choices=tasks)
    parser.add_argument('--dist', type=str, choices=distributions)
    return parser


def add_train_args(parser):
    group = parser.add_argument_group('train')
    group.add_argument('--task', type=str, choices=tasks)
    group.add_argument('--dist', type=str, choices=distributions)
    group.add_argument('--seed', type=int, default=7777)
    group.add_argument('--latent_dim', type=int, default=2)
    group.add_argument('--beta', type=float, default=1.)
    group.add_argument('--n_epochs', type=int, default=10)
    group.add_argument('--train_batch_size', type=int, default=32)
    group.add_argument('--test_batch_size', type=int, default=32)
    group.add_argument('--lr', type=float, default=1e-5)
    group.add_argument('--device', type=str, default='cuda:0')
    group.add_argument('--eval_interval', type=int, default=10)
    group.add_argument('--log_interval', type=int, default=10)
    group.add_argument('--log_dir', type=str, default='logs/')
    group.add_argument('--train_samples', type=int, default=1)
    group.add_argument('--test_samples', type=int, default=500)
    group.add_argument('--exp_name', type=str, default='dummy')
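For context, a minimal sketch of how these two helpers might be wired together in a training entry point. This is not part of the commit: the two-stage parse (resolving --task and --dist first, then building the full parser) and the final print are assumptions for illustration; the repository's actual train script may differ.

import argparse

from arguments import get_initial_parser, add_train_args

if __name__ == '__main__':
    # First pass: only --task and --dist, so task- and distribution-specific
    # modules could be selected before the full configuration is parsed.
    initial_args, _ = get_initial_parser().parse_known_args()

    # Second pass: full training configuration. add_train_args re-declares
    # --task and --dist, so the full parser does not reuse the initial one.
    parser = argparse.ArgumentParser()
    add_train_args(parser)
    args = parser.parse_args()

    print(initial_args.task, initial_args.dist, args.exp_name)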

3 changes: 3 additions & 0 deletions distributions/DiagonalHWN/__init__.py
@@ -0,0 +1,3 @@
from .layers import EncoderLayer, DecoderLayer, EmbeddingLayer
from .distribution import Distribution
from .prior import get_prior
18 changes: 18 additions & 0 deletions distributions/DiagonalHWN/distribution.py
@@ -0,0 +1,18 @@
import torch
from torch.distributions import Normal

from ..hwn import HWN


class Distribution(HWN):
    def __init__(self, mean, covar) -> None:
        base = Normal(
            torch.zeros(
                covar.size(),
                device=covar.device
            ),
            covar
        )

        super().__init__(mean, base)

38 changes: 38 additions & 0 deletions distributions/DiagonalHWN/layers.py
@@ -0,0 +1,38 @@
import torch
import geoopt
from torch import nn
from torch.nn import functional as F

from ..utils import ExpLayer, LogLayer

EncoderLayer = ExpLayer
DecoderLayer = LogLayer


class EmbeddingLayer(nn.Module):
    def __init__(self, args, n_words):
        super().__init__()

        self.args = args
        self.latent_dim = args.latent_dim
        self.n_words = n_words
        self.initial_sigma = args.initial_sigma
        self.manifold = geoopt.manifolds.Lorentz()

        mean_initialize = torch.empty([self.n_words, self.latent_dim])
        nn.init.normal_(mean_initialize, std=args.initial_sigma)
        self.mean = nn.Embedding.from_pretrained(mean_initialize, freeze=False)

        covar_initialize = torch.empty([self.n_words, self.latent_dim])
        nn.init.normal_(covar_initialize, std=args.initial_sigma)
        self.covar = nn.Embedding.from_pretrained(covar_initialize, freeze=False)

    def forward(self, x):
        mean = self.mean(x)
        mean = F.pad(mean, (1, 0))
        mean = self.manifold.expmap0(mean)

        covar = F.softplus(self.covar(x))

        return mean, covar

18 changes: 18 additions & 0 deletions distributions/DiagonalHWN/prior.py
@@ -0,0 +1,18 @@
import torch
import geoopt
from .distribution import Distribution


def get_prior(args):
    m = geoopt.manifolds.Lorentz()

    mean = m.origin([1, args.latent_dim + 1], device=args.device)
    covar = torch.ones(
        1,
        args.latent_dim,
        device=args.device
    )

    prior = Distribution(mean, covar)
    return prior

3 changes: 3 additions & 0 deletions distributions/EuclideanNormal/__init__.py
@@ -0,0 +1,3 @@
from .layers import EncoderLayer, DecoderLayer, EmbeddingLayer
from .distribution import Distribution
from .prior import get_prior
28 changes: 28 additions & 0 deletions distributions/EuclideanNormal/distribution.py
@@ -0,0 +1,28 @@
import torch
from torch.distributions import Normal


def kl_dist(mu0, std0, mu1, std1):
    k = mu0.size(-1)
    logvar0, logvar1 = 2 * std0.log(), 2 * std1.log()

    dist = logvar1 - logvar0 + (((mu0 - mu1).pow(2) + 1e-9).log() - logvar1).exp() + (logvar0 - logvar1).exp()
    dist = dist.sum(dim=-1) - k
    return dist * 0.5


class Distribution():
    def __init__(self, mean, covar) -> None:
        self.mean = mean
        self.covar = covar

        self.base = Normal(self.mean, self.covar)

    def log_prob(self, z):
        return self.base.log_prob(z).sum(dim=-1)

    def rsample(self, N):
        return self.base.rsample([N])

    def sample(self, N):
        with torch.no_grad():
            return self.rsample(N)
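For reference, kl_dist above is the standard closed-form KL divergence between two k-dimensional diagonal Gaussians; the squared mean difference is pushed through log/exp with a 1e-9 offset for numerical stability:

D_{\mathrm{KL}}\big(\mathcal{N}(\mu_0, \operatorname{diag}\sigma_0^2)\,\|\,\mathcal{N}(\mu_1, \operatorname{diag}\sigma_1^2)\big)
  = \frac{1}{2} \sum_{i=1}^{k} \left( \log\frac{\sigma_{1,i}^2}{\sigma_{0,i}^2}
  + \frac{(\mu_{0,i}-\mu_{1,i})^2}{\sigma_{1,i}^2}
  + \frac{\sigma_{0,i}^2}{\sigma_{1,i}^2} - 1 \right)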

60 changes: 60 additions & 0 deletions distributions/EuclideanNormal/layers.py
@@ -0,0 +1,60 @@
import torch
from torch import nn
from torch.nn import functional as F


class EncoderLayer(nn.Module):
    def __init__(self, args, feature_dim) -> None:
        super().__init__()

        self.latent_dim = args.latent_dim
        self.feature_dim = feature_dim

        self.variational = nn.Linear(
            self.feature_dim,
            2 * self.latent_dim
        )

    def forward(self, feature):
        feature = self.variational(feature)
        mean, covar = torch.split(
            feature,
            [self.latent_dim, self.latent_dim],
            dim=-1
        )
        covar = F.softplus(covar)

        return mean, covar


class DecoderLayer(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, z):
        return z


class EmbeddingLayer(nn.Module):
    def __init__(self, args, n_words):
        super().__init__()

        self.args = args
        self.latent_dim = args.latent_dim
        self.n_words = n_words
        self.initial_sigma = args.initial_sigma

        mean_initialize = torch.empty([self.n_words, self.latent_dim])
        nn.init.normal_(mean_initialize, std=args.initial_sigma)
        self.mean = nn.Embedding.from_pretrained(mean_initialize, freeze=False)

        covar_initialize = torch.empty([self.n_words, self.latent_dim])
        nn.init.normal_(covar_initialize, std=args.initial_sigma)
        self.covar = nn.Embedding.from_pretrained(covar_initialize, freeze=False)

    def forward(self, x):
        mean = self.mean(x)
        covar = F.softplus(self.covar(x))

        return mean, covar

18 changes: 18 additions & 0 deletions distributions/EuclideanNormal/prior.py
@@ -0,0 +1,18 @@
import torch

from .distribution import Distribution


def get_prior(args):
    mean = torch.zeros(
        [1, args.latent_dim],
        device=args.device
    )
    covar = torch.ones(
        [1, args.latent_dim],
        device=args.device
    )

    prior = Distribution(mean, covar)
    return prior

3 changes: 3 additions & 0 deletions distributions/FullHWN/__init__.py
@@ -0,0 +1,3 @@
from .layers import EncoderLayer, DecoderLayer, EmbeddingLayer
from .distribution import Distribution
from .prior import get_prior
29 changes: 29 additions & 0 deletions distributions/FullHWN/distribution.py
@@ -0,0 +1,29 @@
import torch
from torch.distributions import MultivariateNormal

from ..hwn import HWN


class Distribution(HWN):
    def __init__(self, mean, covar) -> None:
        base = MultivariateNormal(
            torch.zeros(
                mean.size(),
                device=covar.device
            )[..., 1:],
            covar
        )

        super().__init__(mean, base)

    def log_prob(self, z):
        u = self.manifold.logmap(self.mean, z)
        v = self.manifold.transp(self.mean, self.origin, u)
        log_prob_v = self.base.log_prob(v[:, :, 1:])

        r = self.manifold.norm(u)
        log_det = (self.latent_dim - 1) * (torch.sinh(r).log() - r.log())

        log_prob_z = log_prob_v - log_det
        return log_prob_z
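The log_prob above follows the usual change-of-variables density of a hyperbolic wrapped normal on the Lorentz model: map z back to the tangent space at the mean with the logarithm map, parallel-transport the tangent vector to the origin, score it under the Euclidean base density, and subtract the log-determinant of the exponential map. With n the latent dimension and r the Lorentz norm of the tangent vector, this reads:

\log p(z) = \log p_{\text{base}}\!\big(\operatorname{PT}_{\mu \to \mu_0}(u)\big)
  - (n - 1)\,\log\frac{\sinh r}{r},
\qquad u = \log_{\mu}(z), \quad r = \lVert u \rVert_{\mathcal{L}}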

90 changes: 90 additions & 0 deletions distributions/FullHWN/layers.py
@@ -0,0 +1,90 @@
import torch
import geoopt
from torch import nn
from torch.nn import functional as F

from ..utils import LogLayer


class EncoderLayer(nn.Module):
    def __init__(self, args, feature_dim) -> None:
        super().__init__()

        self.latent_dim = args.latent_dim
        self.feature_dim = feature_dim

        self.manifold = geoopt.manifolds.Lorentz()
        self.variational = nn.Linear(
            self.feature_dim,
            self.latent_dim + self.latent_dim ** 2
        )

    def forward(self, feature):
        feature = self.variational(feature)
        mu, covar = torch.split(
            feature,
            [self.latent_dim, self.latent_dim ** 2],
            dim=-1
        )

        mu = F.pad(mu, (1, 0))
        mu = self.manifold.expmap0(mu)

        covar_size = covar.size()[:-1]
        covar = covar.view(
            *covar_size,
            self.latent_dim,
            self.latent_dim
        )
        covar = covar.matmul(covar.transpose(-1, -2))
        covar = covar + 1e-9 * torch.eye(
            self.latent_dim,
            device=covar.device
        )[None, ...]

        return mu, covar


DecoderLayer = LogLayer


class EmbeddingLayer(nn.Module):
    def __init__(self, args, n_words):
        super().__init__()

        self.args = args
        self.latent_dim = args.latent_dim
        self.n_words = n_words
        self.initial_sigma = args.initial_sigma
        self.manifold = geoopt.manifolds.Lorentz()

        mean_initialize = torch.empty([self.n_words, self.latent_dim])
        nn.init.normal_(mean_initialize, std=args.initial_sigma)
        self.mean = nn.Embedding.from_pretrained(mean_initialize, freeze=False)

        covar_initialize = torch.stack(
            [torch.eye(self.latent_dim) for _ in range(self.n_words)]
        ).view(self.n_words, -1)
        covar_initialize = covar_initialize * torch.randn(covar_initialize.size()) * self.initial_sigma
        self.covar = nn.Embedding.from_pretrained(covar_initialize, freeze=False)

    def forward(self, x):
        mean = self.mean(x)
        mean = F.pad(mean, (1, 0))
        mean = self.manifold.expmap0(mean)

        covar = self.covar(x)
        covar_size = covar.size()[:-1]
        covar = covar.view(
            *covar_size,
            self.latent_dim,
            self.latent_dim
        )
        covar = covar.matmul(covar.transpose(-1, -2))
        covar = covar + 1e-9 * torch.eye(
            self.latent_dim,
            device=covar.device
        )[None, ...]

        return mean, covar

17 changes: 17 additions & 0 deletions distributions/FullHWN/prior.py
@@ -0,0 +1,17 @@
import torch
import geoopt
from .distribution import Distribution


def get_prior(args):
    m = geoopt.manifolds.Lorentz()

    mean = m.origin([1, args.latent_dim + 1], device=args.device)
    covar = torch.eye(
        args.latent_dim,
        device=args.device
    )[None, ...]

    prior = Distribution(mean, covar)
    return prior

3 changes: 3 additions & 0 deletions distributions/IsotropicHWN/__init__.py
@@ -0,0 +1,3 @@
from .layers import EncoderLayer, DecoderLayer, EmbeddingLayer
from .distribution import Distribution
from .prior import get_prior