Add log_prob_elemwise feature (pymc-devs#158)
* Added collect_log_prob_elemwise to executor

* Added tests for executor's collect_log_prob_elemwise

* Add high level API pm.model_log_prob_elemwise

* Black style

* Add missing model_log_prob_elemwise function

* [TST] Fixed broken model_with_plates test

* Removed debugging prints
lucianopaz authored and ColCarroll committed Nov 27, 2019
1 parent dc40e55 commit 014e59a
Showing 2 changed files with 96 additions and 4 deletions.
11 changes: 8 additions & 3 deletions pymc4/flow/executor.py
@@ -143,12 +143,17 @@ def __init__(
         self.deterministics = deterministics
         self.posterior_predictives = posterior_predictives
 
-    def collect_log_prob(self):
-        all_terms = itertools.chain(
+    def collect_log_prob_elemwise(self):
+        return itertools.chain(
             (dist.log_prob(self.all_values[name]) for name, dist in self.distributions.items()),
             (p.value for p in self.potentials),
         )
-        return sum(map(tf.reduce_sum, all_terms))
+
+    def collect_log_prob(self):
+        return sum(map(tf.reduce_sum, self.collect_log_prob_elemwise()))
+
+    def collect_unreduced_log_prob(self):
+        return sum(self.collect_log_prob_elemwise())
 
     def __repr__(self):
         # display keys only
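For context, a minimal usage sketch of the three methods added above, assuming pymc4 is imported as pm (as in the tests below); toy_model and its variables are illustrative only and not part of this commit:

import numpy as np
import pymc4 as pm  # assumed import alias, matching the usage in the tests below
import tensorflow as tf


@pm.model
def toy_model():
    # Hypothetical two-variable model, used only for illustration.
    loc = yield pm.Normal("loc", 0, 1)
    yield pm.Normal("obs", loc, 1, observed=np.zeros(3, dtype="float32"))


_, state = pm.evaluate_model(toy_model())

# One element-wise log-prob tensor per distribution; potentials are chained on
# at the end, so zipping with state.distributions keeps only the named variables.
log_prob_elemwise = dict(zip(state.distributions, state.collect_log_prob_elemwise()))

# Fully reduced scalar log-probability (behaviour unchanged by this commit).
log_prob = state.collect_log_prob()

# Summed across terms without reducing over batch/plate dimensions.
unreduced = state.collect_unreduced_log_prob()

assert np.isclose(log_prob.numpy(), sum(map(tf.reduce_sum, log_prob_elemwise.values())).numpy())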
89 changes: 88 additions & 1 deletion tests/test_executor.py
@@ -4,7 +4,32 @@
 import numpy as np
 from pymc4 import distributions as dist
 import tensorflow as tf
-from pymc4.flow.executor import get_observed_tensor_shape, EvaluationError
+from tensorflow_probability import distributions as tfd
+from pymc4.flow.executor import get_observed_tensor_shape, EvaluationError, SamplingState
+
+
+TEST_SHAPES = [(), (1,), (3,), (1, 1), (1, 3), (5, 3)]
+
+
+@pytest.fixture(scope="module", params=TEST_SHAPES, ids=str)
+def fixture_batch_shapes(request):
+    return request.param
+
+
+@pytest.fixture(scope="module", params=TEST_SHAPES, ids=str)
+def fixture_sample_shapes(request):
+    return request.param
+
+
+@pytest.fixture(scope="module")
+def fixture_distribution_parameters(fixture_batch_shapes, fixture_sample_shapes):
+    observed = np.random.randn(*(fixture_sample_shapes + fixture_batch_shapes))
+    return fixture_batch_shapes, observed
+
+
+@pytest.fixture(scope="module", params=["decorate_model", "use_plain_function"], ids=str)
+def fixture_pm_model_decorate(request):
+    return request.param == "decorate_model"
 
 
 @pytest.fixture("module")
@@ -79,6 +104,33 @@ def class_model_method(self):
     return ClassModel()
 
 
+@pytest.fixture(scope="module")
+def fixture_model_with_plates(fixture_distribution_parameters, fixture_pm_model_decorate):
+    batch_shape, observed = fixture_distribution_parameters
+    expected_obs_shape = (
+        ()
+        if isinstance(observed, float)
+        else observed.shape[: len(observed.shape) - len(batch_shape)]
+    )
+    if fixture_pm_model_decorate:
+        expected_rv_shapes = {
+            "model/loc": (),
+            "model/obs": expected_obs_shape,
+        }
+    else:
+        expected_rv_shapes = {"loc": (), "obs": expected_obs_shape}
+
+    def model():
+        loc = yield pm.Normal("loc", 0, 1)
+        obs = yield pm.Normal("obs", loc, 1, plate=batch_shape, observed=observed)
+        return obs
+
+    if fixture_pm_model_decorate:
+        model = pm.model(model)
+
+    return model, expected_rv_shapes
+
+
 @pytest.fixture("module")
 def model_with_deterministics():
     expected_deterministics = ["model/abs_norm", "model/sine_norm", "model/norm_copy"]
@@ -561,6 +613,19 @@ def model():
     state.collect_log_prob()  # this should work
 
 
+def test_log_prob_elemwise(fixture_model_with_plates):
+    model, expected_rv_shapes = fixture_model_with_plates
+    _, state = pm.evaluate_model(model())
+    log_prob_elemwise = dict(
+        zip(state.distributions, state.collect_log_prob_elemwise())
+    )  # This will discard potentials in log_prob_elemwise
+    log_prob = state.collect_log_prob()
+    assert len(log_prob_elemwise) == len(expected_rv_shapes)
+    assert all(rv in log_prob_elemwise for rv in expected_rv_shapes)
+    assert all(log_prob_elemwise[rv].shape == shape for rv, shape in expected_rv_shapes.items())
+    assert log_prob.numpy() == sum(map(tf.reduce_sum, log_prob_elemwise.values())).numpy()
+
+
 def test_deterministics(model_with_deterministics):
     model, expected_deterministics, expected_ops, expected_ops_inputs = model_with_deterministics
     _, state = pm.evaluate_model(model())
@@ -629,3 +694,25 @@ def model(observed):
 
     with pytest.raises(EvaluationError):
         pm.evaluate_model(model(observed_value))
+
+
+def test_unreduced_log_prob(fixture_batch_shapes):
+    observed_value = np.ones(10, dtype="float32")
+
+    @pm.model
+    def model():
+        a = yield pm.Normal("a", 0, 1)
+        b = yield pm.HalfNormal("b", 1)
+        c = yield pm.Normal("c", loc=a, scale=b, plate=len(observed_value))
+
+    values = {
+        "model/a": np.zeros(fixture_batch_shapes, dtype="float32"),
+        "model/b": np.ones(fixture_batch_shapes, dtype="float32"),
+    }
+    observed = {
+        "model/c": np.broadcast_to(observed_value, fixture_batch_shapes + observed_value.shape)
+    }
+    state = pm.evaluate_model(model(), values=values, observed=observed)[1]
+    unreduced_log_prob = state.collect_unreduced_log_prob()
+    assert unreduced_log_prob.numpy().shape == fixture_batch_shapes
+    np.testing.assert_allclose(tf.reduce_sum(unreduced_log_prob), state.collect_log_prob())
