AutoIAFGuide bug? (pyro-ppl#1792)
* Fixed minor bug in MADE mask encoding

* More flexible MADE implementation

* Forgot to import numpy

* Change dimension to dim in arguments

* PEP8 and isort

* Split long line for PEP8

* Fixed trailing whitespace

* Newline at end of file?

* Rewrote MADE in Torch, and adjusted test

* Variable name changes, change to how mask is produced

* Turning off skip connections by default

* Fixed PEP8

* Added in test for masks

* Fixed bugs in mask test

* Removed debug lines

* Test for skip connections mask

* Neeraj's type comments

* PEP8

* Comment about mask structure and checking hidden dim

* Fixed assert

* Changed assert to raising a ValueError exception

* Reverting to using type_as when creating mask

* Type of arange

* Debugging decoupling AutoregressiveNN from IAF

* Updated everything using InverseAutoregressiveFlow

* PEP8

* Changed how MADE is imported

* Syntax error in IAF docs

* Adjusted IAF docs a little

* Fixed doctest problem

* Fixed some of the doc issues after code review

* AutoRegressiveNN takes parameter dimension list

* Fixed bug when only one parameter of dim > 1

* Fixed torch.unbind error

* PEP8

* Added nonlinearity example to docstring

* Describe default param_dims and range of output in AutoRegressiveNN

* New versions of IAF

* Changed to iaf.py

* Separated out flow tests from nn ones and new test for inverse for both types of IAF

* Docstrings and PEP8

* Fixed test

* Masked Autoregressive Flow

* Fixed MAF

* PEP8

* Typo

* 'NeuralAutoregressiveFlow'

* Bug fix in AutoIAFGuide

* Better fix to the bug

* Removed update_module_params
stefanwebb authored and fehiepsi committed Mar 27, 2019
1 parent 4b664ee commit c98a2c2
Showing 2 changed files with 26 additions and 9 deletions.
16 changes: 15 additions & 1 deletion pyro/contrib/autoguide/__init__.py
@@ -72,6 +72,7 @@ class AutoGuide(object):
     :param callable model: a pyro model
     :param str prefix: a prefix that will be prefixed to all param internal sites
     """
+
     def __init__(self, model, prefix="auto"):
         self.master = None
         self.model = model
@@ -140,6 +141,7 @@ class AutoGuideList(AutoGuide):
     :param callable model: a Pyro model
     :param str prefix: a prefix that will be prefixed to all param internal sites
     """
+
     def __init__(self, model, prefix="auto"):
         super(AutoGuideList, self).__init__(model, prefix)
         self.parts = []
@@ -232,6 +234,7 @@ def my_local_median(*args, **kwargs)
     :param callable median: an optional callable returning a dict mapping
         sample site name to computed median tensor.
     """
+
     def __init__(self, model, guide, median=lambda *args, **kwargs: {}):
         super(AutoCallable, self).__init__(model, prefix="")
         self._guide = guide
@@ -265,6 +268,7 @@ class AutoDelta(AutoGuide):
         pyro.param("auto_concentration", torch.ones(k),
                    constraint=constraints.positive)
     """
+
     def __call__(self, *args, **kwargs):
         """
         An automatic guide with the same ``*args, **kwargs`` as the base ``model``.
@@ -316,6 +320,7 @@ class AutoContinuous(AutoGuide):
     Alp Kucukelbir, Dustin Tran, Rajesh Ranganath, Andrew Gelman, David M.
     Blei
     """
+
     def _setup_prototype(self, *args, **kwargs):
         super(AutoContinuous, self)._setup_prototype(*args, **kwargs)
         self._unconstrained_shapes = {}
@@ -456,6 +461,7 @@ class AutoMultivariateNormal(AutoContinuous):
         pyro.param("auto_scale_tril", torch.tril(torch.rand(latent_dim)),
                    constraint=constraints.lower_cholesky)
     """
+
     def get_posterior(self, *args, **kwargs):
         """
         Returns a MultivariateNormal posterior distribution.
@@ -493,6 +499,7 @@ class AutoDiagonalNormal(AutoContinuous):
         pyro.param("auto_scale", torch.ones(latent_dim),
                    constraint=constraints.positive)
     """
+
     def get_posterior(self, *args, **kwargs):
         """
         Returns a diagonal Normal posterior distribution.
@@ -537,6 +544,7 @@ class AutoLowRankMultivariateNormal(AutoContinuous):
     :param int rank: the rank of the low-rank part of the covariance matrix
     :param str prefix: a prefix that will be prefixed to all param internal sites
     """
+
     def __init__(self, model, prefix="auto", rank=1):
         if not isinstance(rank, numbers.Number) or not rank > 0:
             raise ValueError("Expected rank > 0 but got {}".format(rank))
@@ -580,8 +588,10 @@ class AutoIAFNormal(AutoContinuous):
     :param int hidden_dim: number of hidden dimensions in the IAF
     :param str prefix: a prefix that will be prefixed to all param internal sites
     """
+
     def __init__(self, model, hidden_dim=None, prefix="auto"):
         self.hidden_dim = hidden_dim
+        self.arn = None
         super(AutoIAFNormal, self).__init__(model, prefix)
 
     def get_posterior(self, *args, **kwargs):
@@ -593,7 +603,10 @@ def get_posterior(self, *args, **kwargs):
             raise ValueError('latent dim = 1. Consider using AutoDiagonalNormal instead')
         if self.hidden_dim is None:
             self.hidden_dim = self.latent_dim
-        iaf = dist.InverseAutoregressiveFlow(AutoRegressiveNN(self.latent_dim, [self.hidden_dim]))
+        if self.arn is None:
+            self.arn = AutoRegressiveNN(self.latent_dim, [self.hidden_dim])
+
+        iaf = dist.InverseAutoregressiveFlow(self.arn)
         pyro.module("{}_iaf".format(self.prefix), iaf)
         iaf_dist = dist.TransformedDistribution(dist.Normal(0., 1.).expand([self.latent_dim]), [iaf])
         return iaf_dist
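
This hunk is the heart of the fix: get_posterior previously built a brand-new AutoRegressiveNN on every invocation, so the flow's parameters were re-initialized each time the guide ran; caching the network on self.arn builds it once and reuses it across calls. A minimal usage sketch of the fixed guide (the toy model, data, and hyperparameters below are illustrative, not part of this commit):

import torch

import pyro
import pyro.distributions as dist
from pyro.contrib.autoguide import AutoIAFNormal
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam


def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 10.))
    scale = pyro.sample("scale", dist.LogNormal(0., 1.))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, scale), obs=data)


data = 2.0 + 0.5 * torch.randn(100)
guide = AutoIAFNormal(model, hidden_dim=16)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())
for step in range(1000):
    svi.step(data)  # guide.arn is built on the first step and reused afterwards
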
@@ -662,6 +675,7 @@ class AutoDiscreteParallel(AutoGuide):
     A discrete mean-field guide that learns a latent discrete distribution for
     each discrete site in the model.
     """
+
     def _setup_prototype(self, *args, **kwargs):
         # run the model so we can inspect its structure
         model = config_enumerate(self.model)
19 changes: 11 additions & 8 deletions pyro/distributions/lkj.py
@@ -21,6 +21,7 @@ class _CorrCholesky(Constraint):
     Euclidean norm of each row is 1, such that `torch.mm(omega, omega.t())` will
     have unit diagonal.
     """
+
     def check(self, value):
         unit_norm_row = (value.norm(dim=-1).sub(1) < 1e-4).min(-1)[0]
         return constraints.lower_cholesky.check(value) & unit_norm_row
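
A quick sketch of what the constraint accepts, assuming corr_cholesky_constraint from this module is in scope; the identity matrix qualifies trivially, being lower triangular with unit-norm rows:

import torch

L = torch.eye(3)  # lower triangular, every row has unit Euclidean norm
assert corr_cholesky_constraint.check(L)
assert torch.allclose(torch.mm(L, L.t()).diag(), torch.ones(3))
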
@@ -34,21 +35,21 @@ def check(self, value):
 ########################################
 
 def _vector_to_l_cholesky(z):
-    D = (1.0 + math.sqrt(1.0 + 8.0 * z.shape[-1]))/2.0
+    D = (1.0 + math.sqrt(1.0 + 8.0 * z.shape[-1])) / 2.0
     if D % 1 != 0:
         raise ValueError("Correlation matrix transformation requires d choose 2 inputs")
     D = int(D)
     x = z.new_zeros(list(z.shape[:-1]) + [D, D])
 
     x[..., 0, 0] = 1
-    x[..., 1:, 0] = z[..., :(D-1)]
+    x[..., 1:, 0] = z[..., :(D - 1)]
     i = D - 1
     last_squared_x = z.new_zeros(list(z.shape[:-1]) + [D])
     for j in range(1, D):
         distance_to_copy = D - 1 - j
-        last_squared_x = last_squared_x[..., 1:] + x[..., j:, (j-1)].clone()**2
+        last_squared_x = last_squared_x[..., 1:] + x[..., j:, (j - 1)].clone()**2
         x[..., j, j] = (1 - last_squared_x[..., 0]).sqrt()
-        x[..., (j+1):, j] = z[..., i:(i + distance_to_copy)] * (1 - last_squared_x[..., 1:]).sqrt()
+        x[..., (j + 1):, j] = z[..., i:(i + distance_to_copy)] * (1 - last_squared_x[..., 1:]).sqrt()
         i += distance_to_copy
     return x
 
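A sketch of the function's contract, assuming the input entries lie in (-1, 1); with 3 inputs, D resolves to 3 and the result is the 3x3 Cholesky factor of a correlation matrix:

import torch

z = torch.tensor([0.3, -0.2, 0.5])  # d*(d-1)/2 = 3 entries, so D resolves to 3
L = _vector_to_l_cholesky(z)
assert torch.allclose(torch.mm(L, L.t()).diag(), torch.ones(3))  # unit diagonal
assert torch.allclose(L.norm(dim=-1), torch.ones(3))             # unit-norm rows
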
@@ -88,12 +89,12 @@ def _inverse(self, y):
         ]
 
         for i in range(2, D):
-            z_tri[..., i - 2, 0:(i-1)] = y[..., i, 1:i] / (1-y[..., i, 0:(i-1)].pow(2).cumsum(-1)).sqrt()
+            z_tri[..., i - 2, 0:(i - 1)] = y[..., i, 1:i] / (1 - y[..., i, 0:(i - 1)].pow(2).cumsum(-1)).sqrt()
         for j in range(D - 2):
             z_stack.append(z_tri[..., j:, j])
 
         z = torch.cat(z_stack, -1)
-        return torch.log1p((2*z)/(1-z)) / 2
+        return torch.log1p((2 * z) / (1 - z)) / 2
 
     def log_abs_det_jacobian(self, x, y):
         # Note dependence on pytorch 1.0.1 for batched tril
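
The return value of _inverse is just atanh(z) written via log1p, since log1p(2z/(1-z))/2 == log((1+z)/(1-z))/2. A sketch of the identity:

import torch

z = torch.tensor([-0.9, 0.1, 0.5])
lhs = torch.log1p((2 * z) / (1 - z)) / 2
rhs = 0.5 * torch.log((1 + z) / (1 - z))
assert torch.allclose(lhs, rhs)
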
@@ -102,6 +103,8 @@ def log_abs_det_jacobian(self, x, y):
         return tanpart + matpart
 
 # register transform to global store
+
+
 @biject_to.register(corr_cholesky_constraint)
 @transform_to.register(corr_cholesky_constraint)
 def _transform_to_corr_cholesky(constraint):
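
With this registration, the generic registries resolve the constraint to the correlation-Cholesky transform; a sketch, assuming the registered factory returns that transform (its body is truncated in this hunk) and using d = 3:

import torch
from torch.distributions import transform_to

t = transform_to(corr_cholesky_constraint)
L = t(torch.randn(3))  # unconstrained vector of length d*(d-1)/2 -> 3x3 factor
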
@@ -148,9 +151,9 @@ def __init__(self, d, eta, validate_args=None):
 
         concentrations = eta.new_empty(vector_size,)
         i = 0
-        for k in range(d-1):
+        for k in range(d - 1):
             alpha -= .5
-            concentrations[..., i:(i + d - k-1)] = alpha
+            concentrations[..., i:(i + d - k - 1)] = alpha
             i += d - k - 1
         self._gen = Beta(concentrations, concentrations)
         self.eta = eta
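
The loop fills the d*(d-1)/2 Beta concentration parameters in blocks of length d-1, d-2, ..., 1, stepping alpha down by .5 per block. alpha's initial value is set earlier in __init__, outside this hunk, so alpha0 below is a placeholder. A sketch for d = 4:

import torch

d, alpha0 = 4, 2.0  # alpha0 is a stand-in, not the real initial value
vector_size = d * (d - 1) // 2
concentrations = torch.empty(vector_size)
alpha, i = alpha0, 0
for k in range(d - 1):
    alpha -= .5
    concentrations[i:(i + d - k - 1)] = alpha  # block of length d-k-1
    i += d - k - 1
print(concentrations)  # tensor([1.5000, 1.5000, 1.5000, 1.0000, 1.0000, 0.5000])
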
