forked from pymc-devs/pymc4
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Variational Inference Interface (pymc-devs#280)
* Started hacking around tfp and pymc4 * Completed build_logpfn function * Completed basic structure for vi * Updated functions with better names * Completed MeanField ADVI * Removed my print statement from sampling * Added docs for fit function * Completed docstrings and written comments * Added quickstart notebook * Resolved mypy issues * Updated return value from fit function to contain approximation. Added `sample` method for posterior distribution. Updated quickstart notebook. * Resolved pylint issues * Added inverse bijector to account samples of transformed variables Added a new axis to handle ArviZ shape issues Changed initialization of std to 1 Created updates module to account for optimizers Added test_variational.py Updated quick_start notebook * Polished tests and optimizers
- Loading branch information
Showing
8 changed files
with
962 additions
and
2 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
"""Tools for Variational Inference.""" | ||
from .approximations import * | ||
from .updates import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,244 @@ | ||
"""Implements ADVI approximations.""" | ||
from typing import Optional, Union | ||
from collections import namedtuple | ||
|
||
import arviz as az | ||
import numpy as np | ||
import tensorflow as tf | ||
import tensorflow_probability as tfp | ||
|
||
from pymc4 import flow | ||
from pymc4.coroutine_model import Model | ||
from pymc4.distributions.transforms import JacobianPreference | ||
from pymc4.inference.utils import initialize_sampling_state | ||
from pymc4.utils import NameParts | ||
from pymc4.variational import updates | ||
|
||
# Conventional short aliases for TensorFlow Probability namespaces.
tfd = tfp.distributions
tfb = tfp.bijectors
# Result container returned by `fit`: the trained approximation object plus
# the loss trace produced by the optimizer (content depends on `trace_fn`).
ADVIFit = namedtuple("ADVIFit", "approximation, losses")
|
||
|
||
class Approximation(tf.Module):
    """Base class for variational approximations.

    Validates the model, snapshots its initial sampling state, and builds
    both the vectorized target log-probability function and the surrogate
    posterior (the latter through the subclass hook ``_build_posterior``).
    """

    def __init__(self, model: Optional[Model] = None, random_seed: Optional[int] = None):
        if not isinstance(model, Model):
            raise TypeError(
                "`fit` function only supports `pymc4.Model` objects, but you've passed `{}`".format(
                    type(model)
                )
            )

        self.model = model
        self._seed = random_seed
        self.state, self.deterministic_names = initialize_sampling_state(model)
        if not self.state.all_unobserved_values:
            raise ValueError(
                f"Can not calculate a log probability: the model {model.name or ''} has no unobserved values."
            )

        self.unobserved_keys = self.state.all_unobserved_values.keys()
        self.target_log_prob = self._build_logfn()
        self.approx = self._build_posterior()

    def _build_logfn(self):
        """Build a vectorized log-probability function for the model."""
        keys = self.unobserved_keys
        model = self.model

        @tf.function(autograph=False)
        def logpfn(*values, **kwargs):
            # Accept either positional values (zipped against the unobserved
            # keys) or an explicit dict of values, but never both at once.
            if kwargs and values:
                raise TypeError("Either list state should be passed or a dict one")
            elif values:
                kwargs = dict(zip(keys, values))
            st = flow.SamplingState.from_values(kwargs)
            _, st = flow.evaluate_model_transformed(model, state=st)
            return st.collect_log_prob()

        def vectorized_logpfn(*q_samples):
            # Map the single-draw logp across a batch of posterior draws.
            return tf.vectorized_map(lambda draws: logpfn(*draws), q_samples)

        return vectorized_logpfn

    def _build_posterior(self):
        # Subclasses define the surrogate posterior family here.
        raise NotImplementedError

    def flatten_view(self):
        """Flattened view of the variational parameters."""
        pass

    def sample(self, n):
        """Draw ``n`` samples from the posterior and wrap them as ArviZ data."""
        posterior = dict(zip(self.unobserved_keys, self.approx.sample(n)))

        # TODO - Account for deterministics as well.
        # Map every transformed variable back onto its constrained support by
        # applying its bijector in the preferred direction.
        _, st = flow.evaluate_model(self.model)
        for name in self.state.transformed_values:
            plain_name = NameParts.from_name(name).full_untransformed_name
            transform = st.distributions[plain_name].transform
            if transform.JacobianPreference == JacobianPreference.Forward:
                posterior[plain_name] = transform.forward(posterior[name])
            else:
                posterior[plain_name] = transform.inverse(posterior[name])

        # Prepend a chain axis (n_chains=1) so ArviZ gets the shape it expects.
        data = {k: v.numpy()[np.newaxis] for k, v in posterior.items()}
        return az.from_dict(data, observed_data=self.state.observed_values)
|
||
|
||
class MeanField(Approximation):
    """
    Mean Field ADVI.

    Implements Mean Field Automatic Differentiation Variational Inference:
    the posterior is approximated with a fully factorized (spherical)
    Gaussian family, so every latent parameter gets an independent Normal
    with its own trainable location and positive scale.

    References
    ----------
    - Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A.,
      and Blei, D. M. (2016). Automatic Differentiation Variational
      Inference. arXiv preprint arXiv:1603.00788.
    """

    def _build_loc(self, shape, dtype, name):
        """Trainable location parameter, randomly initialized."""
        initial = tf.random.normal(shape, seed=self._seed)
        return tf.Variable(initial, name=f"{name}/mu", dtype=dtype)

    def _build_cov_matrix(self, shape, dtype, name):
        """Trainable positive scale parameter, initialized to 1."""
        # Per the `tfp.vi.fit_surrogate_posterior` docs, `TransformedVariable`
        # (or `DeferredTensor`) keeps the transform inside the gradient tape.
        ones = tf.fill(shape, value=tf.constant(1, dtype=dtype))
        return tfp.util.TransformedVariable(
            ones,
            tfb.Softplus(),  # keeps the scale strictly positive
            name=f"{name}/sigma",
        )

    def _build_posterior(self):
        """One independent Normal per unobserved variable, joined sequentially."""

        def make_normal(var_name):
            free_value = self.state.all_unobserved_values[var_name]
            return tfd.Normal(
                self._build_loc(free_value.shape, free_value.dtype, var_name),
                self._build_cov_matrix(free_value.shape, free_value.dtype, var_name),
            )

        # Should we use `tf.nest.map_structure` or `pm.utils.map_structure`?
        factors = tf.nest.map_structure(make_normal, self.unobserved_keys)
        return tfd.JointDistributionSequential(factors)
|
||
|
||
class FullRank(Approximation):
    """Full Rank Automatic Differentiation Variational Inference (Full Rank ADVI).

    Placeholder: a full-covariance Gaussian surrogate is not implemented yet.
    """

    def _build_loc(self):
        raise NotImplementedError

    def _build_cov_matrix(self):
        raise NotImplementedError

    def _build_posterior(self):
        raise NotImplementedError
|
||
|
||
class LowRank(Approximation):
    """Low Rank Automatic Differentiation Variational Inference (Low Rank ADVI).

    Placeholder: a low-rank-plus-diagonal surrogate is not implemented yet.
    """

    def _build_loc(self):
        raise NotImplementedError

    def _build_cov_matrix(self):
        raise NotImplementedError

    def _build_posterior(self):
        raise NotImplementedError
|
||
|
||
def fit(
    model: Optional[Model] = None,
    method: Union[str, MeanField] = "advi",
    num_steps: int = 10000,
    sample_size: int = 1,
    random_seed: Optional[int] = None,
    optimizer=None,
    **kwargs,
):
    """
    Fit an approximating distribution to log_prob of the model.

    Parameters
    ----------
    model : Optional[:class:`Model`]
        Model to fit posterior against. Required when ``method`` is a
        string; ignored when ``method`` is already an
        :class:`Approximation` instance (which carries its own model).
    method : Union[str, :class:`Approximation`]
        Method to fit model using VI

        - 'advi' for :class:`MeanField`
        - or directly pass in :class:`Approximation` instance

        (:class:`FullRank` and :class:`LowRank` are placeholders and have
        no string selector yet.)
    num_steps : int
        Number of iterations to run the optimizer
    sample_size : int
        Number of Monte Carlo samples used for approximation
    random_seed : Optional[int]
        Seed for tensorflow random number generator
    optimizer : TF1-style | TF2-style | from pymc4/variational/updates
        Tensorflow optimizer to use; defaults to ``updates.adam()``
    kwargs : Optional[Dict[str, Any]]
        Pass extra non-default arguments to
        ``tensorflow_probability.vi.fit_surrogate_posterior``

    Returns
    -------
    ADVIFit : collections.namedtuple
        Named tuple, including approximation, ELBO losses depending on the `trace_fn`

    Raises
    ------
    KeyError
        If ``method`` is a string not registered as a selector.
    TypeError
        If ``method`` is neither a string nor an :class:`Approximation`.
    """
    _select = {"advi": MeanField}

    if isinstance(method, str):
        # A string selector means the caller must also supply `model`.
        try:
            inference = _select[method.lower()](model, random_seed)
        except KeyError:
            # `from None`: the swallowed lookup KeyError adds no information.
            raise KeyError(
                "method should be one of %s or Approximation instance" % set(_select.keys())
            ) from None

    elif isinstance(method, Approximation):
        # The Approximation already holds its Model instance.
        inference = method

    else:
        raise TypeError(
            "method should be one of %s or Approximation instance" % set(_select.keys())
        )

    # Defining `opt = optimizer or updates.adam()`
    # leads to optimizer initialization issues from tf, so keep the explicit
    # branch; the identity check also honors a falsy-but-valid optimizer.
    if optimizer is not None:
        opt = optimizer
    else:
        opt = updates.adam()

    @tf.function(autograph=False)
    def run_approximation():
        # Minimize the negative ELBO; the returned trace's content depends
        # on `trace_fn` if one is forwarded through `kwargs`.
        losses = tfp.vi.fit_surrogate_posterior(
            target_log_prob_fn=inference.target_log_prob,
            surrogate_posterior=inference.approx,
            num_steps=num_steps,
            sample_size=sample_size,
            seed=random_seed,
            optimizer=opt,
            **kwargs,
        )
        return losses

    return ADVIFit(inference, run_approximation())
Oops, something went wrong.