Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lazily-constructed BayesianNet to simplify BNN variational #112

Open
wants to merge 5 commits into
base: v4
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Initial version; the BNN works but needs more tests
The result isn't exactly the same as before, but that could be because the computation order has changed.
  • Loading branch information
meta-inf committed Oct 11, 2018
commit a69b570f2fd47477d5d25f5ac0163ef8dcd2fd64
31 changes: 6 additions & 25 deletions examples/bayesian_neural_nets/bayesian_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,9 @@ def build_bnn(layer_sizes, n_particles):
bn = zs.BayesianNet()
x = bn.input("x")
h = tf.tile(x[None, ...], [n_particles, 1, 1])
for i, (n_in, n_out) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
w = bn.normal("w" + str(i), tf.zeros([n_out, n_in + 1]), std=1.,
group_ndims=2, n_samples=n_particles)
h = tf.concat([h, tf.ones(tf.shape(h)[:-1])[..., None]], -1)
h = tf.einsum("imk,ijk->ijm", w, h) / tf.sqrt(
tf.to_float(tf.shape(h)[2]))
if i < len(layer_sizes) - 2:
h = tf.nn.relu(h)
for i, n_out in enumerate(layer_sizes[1:]):
activation = tf.nn.relu if i < len(layer_sizes) - 2 else None
h = zs.nn.dense(bn, h, n_out, name=str(i), activation=activation)

y_mean = bn.output("y_mean", tf.squeeze(h, 2))
y_logstd = tf.get_variable("y_logstd", shape=[],
Expand All @@ -36,21 +31,6 @@ def build_bnn(layer_sizes, n_particles):
return bn


@zs.reuse_variables(scope="variational")
def build_mean_field_variational(layer_sizes, n_particles):
bn = zs.BayesianNet()
for i, (n_in, n_out) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
w_mean = tf.get_variable(
"w_mean_" + str(i), shape=[n_out, n_in + 1],
initializer=tf.constant_initializer(0.))
w_logstd = tf.get_variable(
"w_logstd_" + str(i), shape=[n_out, n_in + 1],
initializer=tf.constant_initializer(0.))
bn.normal("w" + str(i), w_mean, logstd=w_logstd,
n_samples=n_particles, group_ndims=2)
return bn


def main():
tf.set_random_seed(1234)
np.random.seed(1234)
Expand Down Expand Up @@ -79,12 +59,13 @@ def main():
w_names = ["w" + str(i) for i in range(len(layer_sizes) - 1)]

meta_model = build_bnn(layer_sizes, n_particles)
variational = build_mean_field_variational(layer_sizes, n_particles)
variational = zs.nn.mean_field_for_dense_weights()

def log_joint(bn):
log_pws = bn.cond_log_prob(w_names)
log_py_xw = bn.cond_log_prob('y')
return tf.add_n(log_pws) + tf.reduce_mean(log_py_xw, 1) * n_train
ret = tf.add_n(log_pws) + tf.reduce_mean(log_py_xw, 1) * n_train
return ret

meta_model.log_joint = log_joint

Expand Down
1 change: 1 addition & 0 deletions zhusuan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from . import distributions
from . import variational
from . import nn
from .framework import *
from .hmc import *
from .evaluation import *
Expand Down
1 change: 1 addition & 0 deletions zhusuan/framework/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
from .bn import *
from .meta_bn import *
from .utils import *
from .node_storage import *
106 changes: 92 additions & 14 deletions zhusuan/framework/bn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from __future__ import print_function
from __future__ import division

import re
from collections import namedtuple
import tensorflow as tf
from tensorflow.python.client.session import (
register_session_run_conversion_functions)
Expand All @@ -18,16 +20,21 @@

__all__ = [
'StochasticTensor',
'BaseBayesianNet',
'ExplicitBayesianNet',
'BayesianNet',
'LazyBayesianNet'
]


# TODO: __str__, __repr__ for StochasticTensor

class StochasticTensor(TensorArithmeticMixin):
def __init__(self, bn, name, dist, observation=None, **kwargs):
def __init__(self, bn, name, dist, observation=None, tag=None,
tensor_shape=None, **kwargs):
if bn is None:
try:
# TODO make sure this is a v3 BN
bn = BayesianNet.get_context()
except RuntimeError:
pass
Expand All @@ -39,10 +46,15 @@ def __init__(self, bn, name, dist, observation=None, **kwargs):
self._dist = dist
self._dtype = dist.dtype
self._n_samples = kwargs.get("n_samples", None)
self._observation = observation
self._tag = tag
self._tensor_shape = tensor_shape
if observation is not None:
print(name, "set obs: {}".format(observation))
self._observation = self._check_observation(observation)
elif (self._bn is not None) and (self._name in self._bn._observed):
elif (self._bn is not None) and hasattr(self._bn, '_observed') and\
(self._name in self._bn._observed):
# deprecated v3 feature
self._observation = self._check_observation(
self._bn._observed[name])
else:
Expand Down Expand Up @@ -154,9 +166,15 @@ def prob(self, given):
)


class _BayesianNet(object):
class BaseBayesianNet(object):

"""
A BaseBayesianNet is a Name->StochasticTensor store.
Optionally, it maintains a pointer to the MetaBayesianNet and fetches
observations from it.
"""

def __init__(self):
self._nodes = {}
try:
self._local_cxt = Local.get_context()
except RuntimeError:
Expand All @@ -165,28 +183,55 @@ def __init__(self):
self._meta_bn = self._local_cxt.meta_bn
else:
self._meta_bn = None
super(_BayesianNet, self).__init__()

@property
def nodes(self):
return self._nodes
super(BaseBayesianNet, self).__init__()

def _get_observation(self, name):
def _get_observation(self, name, tag=None, tensor_shape=None, n_samples=None,
**kwargs):
if self._local_cxt:
ret = self._local_cxt.observations.get(name, None)
ret = self._local_cxt.observations.get_node(
name, tag, tensor_shape, n_samples)
print(name, "get obs: {}".format(ret))
return ret
return None

def stochastic(self, name, dist, **kwargs):
def get_node(self, name, tag=None, shape=None, n_samples=None):
raise NotImplementedError()

def has_node(self, name, tag=None):
raise NotImplementedError()


class ExplicitBayesianNet(BaseBayesianNet):

def __init__(self):
self._nodes = {}
super(ExplicitBayesianNet, self).__init__()

def get_node(self, name, tag=None, shape=None, n_samples=None):
# TODO: optionally, check if tag and shape agree with what we have.
return self._nodes.get(name)

def has_node(self, name, tag=None):
return name in self._nodes

@property
def nodes(self):
return self._nodes

def stochastic(self, name, dist, tag=None, **kwargs):
if name in self._nodes:
raise ValueError(
"There exists a node with name '{}' in the {}. Names should "
"be unique.".format(name, BayesianNet.__name__))
# TODO: check whether `self` is BayesianNet or _BayesianNet
print(name, "add stochastic node")
tensor_shape = dist.get_batch_shape().concatenate(
dist.get_value_shape()).as_list()
kwargs['tensor_shape'] = tensor_shape
kwargs['tag'] = tag
observation = self._get_observation(name, **kwargs)
node = StochasticTensor(
self, name, dist, observation=self._get_observation(name), **kwargs)
self, name, dist, observation=observation, **kwargs)
self._nodes[name] = node
return node

Expand Down Expand Up @@ -285,7 +330,7 @@ def __setitem__(self, name, node):
MetaBayesianNet.observe.__name__))


class BayesianNet(_BayesianNet, Context):
class BayesianNet(ExplicitBayesianNet, Context):
def __init__(self, observed=None):
# To support deprecated features
self._observed = observed if observed else {}
Expand Down Expand Up @@ -691,3 +736,36 @@ def query(self, name_or_names, outputs=False, local_log_prob=False):
return list(zip(*ret))
else:
return tuple(ret)


# A rule pairs `tag_pattern`, a regular expression matched against node tags,
# with `constructor`, the callable used to build a node whose tag matches.
Rule = namedtuple('Rule', ['tag_pattern', 'constructor'])


class LazyBayesianNet(BaseBayesianNet):
    """
    A Bayesian net whose nodes are constructed on first request.

    Nodes are not registered up front. Each :class:`Rule` pairs a regular
    expression over node tags with a constructor; when a node is requested,
    the single rule whose pattern matches the tag builds it, and the result
    is cached under the node's name.

    :param rules: A collection of :class:`Rule` instances. It is iterated on
        every lookup, so it must support repeated iteration.
    """

    def __init__(self, rules):
        self._rules = rules
        # Maps node name -> node already built by a rule's constructor.
        self._node_cache = {}
        super(LazyBayesianNet, self).__init__()

    def _find_rule(self, tag):
        """
        Return the rules whose ``tag_pattern`` matches ``tag``.

        Patterns are tested with ``re.match``, i.e. anchored at the start of
        the tag.

        :param tag: The tag string to match against every rule.
        :return: A list holding zero or one :class:`Rule`.
        :raises ValueError: If more than one rule matches ``tag``.
        """
        matched = [rule for rule in self._rules
                   if re.match(rule.tag_pattern, tag) is not None]
        # Raised explicitly instead of asserted: an assert disappears under
        # `python -O`, and an ambiguous tag would then silently resolve to
        # whichever rule happens to come first.
        if len(matched) > 1:
            raise ValueError(
                "Tag '{}' is matched by {} rules; at most one rule may match "
                "a tag.".format(tag, len(matched)))
        return matched

    def get_node(self, name, tag=None, shape=None, n_samples=None):
        """
        Return the node called ``name``, building it on first access.

        The arguments are validated before the cache is consulted, so a
        cached node is still only returned for a fully specified request.
        When a node has to be built, the matching rule's constructor is
        called as ``constructor(self, name, tag, shape, n_samples)``.

        :param name: Name of the node; also the cache key.
        :param tag: Tag used to select the construction rule. Required.
        :param shape: Iterable of dimensions; must be given and must not
            contain ``None``.
        :param n_samples: Forwarded unchanged to the rule's constructor.
        :return: The cached or newly constructed node.
        :raises ValueError: If ``tag`` or ``shape`` is missing, if ``shape``
            has an unknown dimension, if no rule matches ``tag``, or if
            several rules match it.
        """
        if tag is None or shape is None:
            raise ValueError(
                "Both `tag` and `shape` are required to get node '{}' from "
                "a {}.".format(name, LazyBayesianNet.__name__))
        if any(s is None for s in shape):
            raise ValueError(
                "The shape of node '{}' must be fully defined, got "
                "{}.".format(name, shape))
        if name in self._node_cache:
            return self._node_cache[name]

        matched = self._find_rule(tag)
        if not matched:
            raise ValueError(
                "No rule matches tag '{}' of node '{}'.".format(tag, name))
        node = matched[0].constructor(self, name, tag, shape, n_samples)
        self._node_cache[name] = node
        return node

    def has_node(self, name, tag=None):
        """
        Tell whether a node with this tag can be provided.

        Only ``tag`` is inspected; ``name`` does not take part in the
        decision.

        :param name: Name of the node (unused).
        :param tag: Tag to look up among the rules; ``None`` never matches.
        :return: ``True`` if a rule matches ``tag``, ``False`` otherwise.
        :raises ValueError: If several rules match ``tag``.
        """
        return tag is not None and bool(self._find_rule(tag))

15 changes: 11 additions & 4 deletions zhusuan/framework/meta_bn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from functools import wraps

from zhusuan.framework.utils import Context
from zhusuan.framework.node_storage import FixedObservations, \
ObservationStorage


__all__ = [
Expand Down Expand Up @@ -38,6 +40,7 @@ def __init__(self, f, args=None, kwargs=None, scope=None,
self._f = f
# TODO: Whether to copy?
# TODO: make args and kwargs changeable after construction.
# TODO: figure out why the above comment is made.
self._args = args
self._kwargs = kwargs
self._scope = scope
Expand All @@ -53,18 +56,22 @@ def log_joint(self, value):
self._log_joint = value

def _run_with_observations(self, func, observations):
    """
    Call ``func`` with the stored positional and keyword arguments while a
    ``Local`` context is active.

    Before the call, the context is given ``observations`` and a reference
    to this object (as ``meta_bn``).

    :param func: The callable to run; invoked as
        ``func(*self._args, **self._kwargs)``.
    :param observations: An ``ObservationStorage`` instance.
    :return: Whatever ``func`` returns.
    """
    assert isinstance(observations, ObservationStorage)
    with Local() as cxt:
        cxt.observations = observations
        cxt.meta_bn = self
        result = func(*self._args, **self._kwargs)
        return result

def observe(self, **kwargs):
print("observe:", kwargs)
def observe_storage(self, storage):
if (self._scope is not None) and (not self._reuse_variables):
with tf.variable_scope(self._scope):
return self._run_with_observations(self._f, kwargs)
return self._run_with_observations(self._f, storage)
else:
return self._run_with_observations(self._f, kwargs)
return self._run_with_observations(self._f, storage)

def observe(self, **kwargs):
print("observe:", kwargs)
return self.observe_storage(FixedObservations(kwargs))


def meta_bayesian_net(scope=None, reuse_variables=False):
Expand Down
57 changes: 57 additions & 0 deletions zhusuan/framework/node_storage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
class ObservationStorage(object):
    """
    Base class for objects that look up an observation by node name.

    Subclasses override :meth:`get_node`; the base implementation only
    raises ``NotImplementedError``.
    """

    def __init__(self):
        pass

    def get_node(self, name, tag=None, shape=None, n_samples=None):
        """
        Look up the observation for the node called ``name``.

        :param name: Name of the node.
        :param tag: Optional tag of the node.
        :param shape: Optional shape of the node.
        :param n_samples: Optional number of samples.
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def merge_with(self, another_storage):
        """
        Combine this storage with ``another_storage``.

        :param another_storage: The storage to combine with this one.
        :return: A ``StorageCollection`` holding ``self`` followed by
            ``another_storage``.
        """
        members = [self, another_storage]
        return StorageCollection(members)


class FixedObservations(ObservationStorage):
    """
    Observation storage backed by a name -> observation mapping.

    :param node_dict: The mapping to read observations from. It is kept by
        reference, not copied.
    """

    def __init__(self, node_dict):
        self._dict = node_dict

    def get_node(self, name, tag=None, shape=None, n_samples=None):
        """
        Return the observation stored under ``name``.

        ``tag``, ``shape`` and ``n_samples`` are accepted for interface
        compatibility and are not used.

        :return: The stored value, or ``None`` if ``name`` is not present.
        """
        # TODO optionally, validate the fetched node
        observations = self._dict
        return observations.get(name, None)


class BayesianNetStorage(ObservationStorage):
    """
    Observation storage that serves the nodes of a Bayesian net.

    :param bn: The net to read nodes from; it must provide ``has_node`` and
        ``get_node``.
    """

    def __init__(self, bn):
        self._bn = bn

    def get_node(self, name, tag=None, shape=None, n_samples=None):
        """
        Return the node ``name`` of the wrapped net, if it has one.

        The net is first asked via ``has_node(name, tag)``; only when that
        succeeds is ``get_node`` called on it with all arguments.

        :return: The node from the net, or ``None`` if the net reports that
            it does not have it.
        """
        net = self._bn
        if net.has_node(name, tag):
            return net.get_node(
                name, tag=tag, shape=shape, n_samples=n_samples)
        return None


class FilteredStorage(ObservationStorage):
    """
    Observation storage that hides the nodes rejected by a predicate.

    :param storage: The underlying storage to query.
    :param filter_fn: A callable taking a node; a falsy result hides the
        node.
    """

    def __init__(self, storage, filter_fn):
        self._storage = storage
        self._filter_fn = filter_fn

    def get_node(self, name, tag=None, shape=None, n_samples=None):
        """
        Return the node from the underlying storage if the filter accepts it.

        The filter is only called when the underlying storage returned a
        node.

        :return: The node, or ``None`` if the underlying storage has none or
            the filter rejects it.
        """
        node = self._storage.get_node(name, tag, shape, n_samples)
        if node is not None and self._filter_fn(node):
            return node
        return None


class StorageCollection(ObservationStorage):
    """
    Observation storage that combines several storages.

    A node may be provided by at most one member storage.

    :param storages: The member storages. The collection is kept by
        reference and iterated on every lookup.
    """

    def __init__(self, storages):
        self._storages = storages

    def get_node(self, name, tag=None, shape=None, n_samples=None):
        """
        Return the node ``name`` from whichever member storage provides it.

        Every member storage is queried, including those after the first
        hit, so that conflicts between members are detected.

        :return: The node, or ``None`` if no member storage provides it.
        :raises ValueError: If more than one member storage provides it.
        """
        found = [s.get_node(name, tag, shape, n_samples)
                 for s in self._storages]
        found = [nd for nd in found if nd is not None]
        # Raised explicitly instead of asserted: an assert disappears under
        # `python -O`, and conflicting observations would then silently
        # resolve to the first one.
        if len(found) > 1:
            raise ValueError(
                "Node '{}' is provided by {} storages in this collection; at "
                "most one storage may provide a given node.".format(
                    name, len(found)))
        return found[0] if found else None

Empty file added zhusuan/framework/node_view.py
Empty file.
4 changes: 4 additions & 0 deletions zhusuan/nn/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from .variationals import *
Loading