This repository contains an implementation of the algorithm presented in the paper "Black-box density function estimation using recursive partitioning" (available on arXiv), referred to as DEFER (DEnsity Function Estimation using Recursive partitioning). The paper was presented at the International Conference on Machine Learning (ICML) 2021.
DEFER allows efficient Bayesian inference on general problems involving up to about ten random variables or dimensions, without the need to specify anything other than
- the unnormalised density function (which can be black-box),
- domain bounds (which must enclose the typical set, although they may do so with very large margins),
- a density function evaluation budget.
The code comes with a high-level interface intended for ease-of-use.
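For orientation, here is a minimal sketch of how the three ingredients above come together through that interface, using a toy log-density as a placeholder (and assuming the callback arguments demonstrated later in this document are optional):

import numpy as np
from defer.helpers import *
from defer.variables import Variable

# 1. An unnormalised density function (black-box); here a toy example
#    returning log-density values.
def log_density(theta):
    return -np.sum(theta ** 2)

# 2. Domain bounds enclosing the typical set.
theta = Variable(lower=[-5, -5], upper=[5, 5], name="theta")

# 3. An evaluation budget.
approx = construct(
    fn=log_density,
    variables=Variables([theta]),
    is_log_fn=True,  # the toy function above returns log-density values
    num_fn_calls=10000,
)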
Below we will show:
- How the DEFER algorithm can be used to approximate arbitrary distributions (of moderate dimension), such as posterior distributions that may, for example, be multimodal or discontinuous, exhibit complicated correlations, and have regions of zero density.
- Some of the operations the approximation provides, including:
- fast, constant-time sampling,
- computation of the normalisation constant,
- analytical expectations of functions with respect to the approximation,
- fast derivation of conditionals, and
- approximate marginalisation of a subset of variables or dimensions.
import corner
import matplotlib.pyplot as plt
import numpy as np
from defer.helpers import *
from defer.variables import Variable
import warnings
warnings.filterwarnings('ignore')
# Configure plot defaults
plt.rcParams['axes.facecolor'] = 'white'
plt.rcParams['grid.color'] = '#666666'
%config InlineBackend.figure_format = 'png'
The DEFER algorithm constructs a distribution approximation given a provided (arbitrary) unnormalised density function f.
In this example we provide the algorithm with such a function: one where the density value at a given location depends on the distance to the closest point of a shape embedded in the same space. The particular shape used is a spiral embedded in three dimensions. The distance (as a function of the query location) is computed numerically by densely sampling points along the curve.
Note that the DEFER algorithm does not need to know anything about the function f other than how to evaluate it; the function is treated as a black box.
def spiral(t):
    # Parametric spiral in 3D: t runs along the first axis while the other
    # two coordinates trace a circle whose radius grows linearly with t.
    w = 6 * np.pi  # angular frequency of the winding
    v = 0.9  # radial growth rate
    c = 1  # base radius
    return np.array([
        t,
        (v * t + c) * np.cos(w * t),
        (v * t + c) * np.sin(w * t),
    ]).transpose()
# Domain bounds
curve_lower_x = -1
curve_upper_x = 1
other_dims_lower = [-2, -2]
other_dims_upper = [2, 2]
lower = [curve_lower_x] + other_dims_lower
upper = [curve_upper_x] + other_dims_upper
def find_shortest_distance(
        x, curve, lower_x=curve_lower_x, upper_x=curve_upper_x):
    # Approximate the distance from x to the curve by densely sampling
    # points along it and taking the minimum Euclidean distance.
    trial_input_points = np.linspace(lower_x, upper_x, num=5000)
    curve_points = curve(trial_input_points)
    distances = np.linalg.norm(curve_points - x, axis=1)
    return float(np.min(distances))
# Unnormalised density function. This can be any continuous-input function
# returning a non-negative (unnormalised) density value.
def f(x):
    # The density decays exponentially with the distance to the spiral.
    distance_to_closest_point_on_curve = find_shortest_distance(x, curve=spiral)
    energy = distance_to_closest_point_on_curve
    return np.exp(-30 * energy)
Declare the random variables corresponding to the parameters of the density function f.
The Variables object produces the density function domain (variables.domain), i.e. the sample space. It also keeps track of the mappings from variables (or slices of them, using []) to the corresponding dimensions.
Note that the variables do not carry any state; they simply act as semantic collections of dimensions with defined boundaries. Keeping the variables together in a Variables object is particularly useful for keeping track of what remains when deriving conditionals or marginals (more on that later).
x = Variable(
lower=lower,
upper=upper,
name="x"
)
variables = Variables([x])
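As an aside, individual dimensions of a variable can be referred to by slicing it with []; these slices are used later in this document when deriving conditionals and marginals:

# Slices of a variable refer to its individual dimensions.
first_dim = x[0]  # the dimension along which the spiral extends
last_dim = x[-1]  # the last of the three dimensions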
Here we use a high-level wrapper around the DEFER algorithm which, given the density function f and the variables, runs the algorithm and returns the resulting approximation.
The budget of the algorithm, in terms of the number of evaluations of f, is specified with 'num_fn_calls'.
Note that 'is_log_fn' also needs to be specified: it should be set to True if the provided function returns log-density values rather than (unnormalised) density values. Our f returns density values, so we set it to False.
approx_joint: DensityFunctionApproximation = construct(
fn=f,
variables=variables,
is_log_fn=False,
num_fn_calls=30000,
callback=lambda i, density:
print("#Evals: %s. Log Z: %.2f" %
(density.num_partitions, np.log(density.z))),
callback_freq_fn_calls=3000,
)
#Evals: 7. Log Z: -4.71
#Evals: 3015. Log Z: -1.58
#Evals: 6019. Log Z: -1.48
#Evals: 9025. Log Z: -1.46
#Evals: 12027. Log Z: -1.41
#Evals: 15027. Log Z: -1.41
#Evals: 18039. Log Z: -1.41
#Evals: 21043. Log Z: -1.41
#Evals: 24063. Log Z: -1.40
#Evals: 27063. Log Z: -1.39
#Evals: 30001. Log Z: -1.39
The algorithm returns a 'DensityFunctionApproximation' object. This is a high-level wrapper around the approximation, which is a piece-wise constant function with as many elements/partitions as the number of density function evaluations spent.
Normalisation constant of the approximation. This can be used for, e.g., evidence estimation.
z = approx_joint.z
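To make this concrete, here is a conceptual sketch of how z arises. The partition list below is made up for illustration (and is reused by later sketches in this document); it is not the library's internal data structure:

# Conceptual sketch: a piece-wise constant approximation is a set of
# axis-aligned boxes, each carrying the density value evaluated at its
# centre. The normalisation constant is the volume-weighted sum.
partitions = [
    # (lower corner, upper corner, unnormalised density value)
    (np.array([0.0, 0.0]), np.array([0.5, 1.0]), 0.8),
    (np.array([0.5, 0.0]), np.array([1.0, 1.0]), 0.2),
]
z_sketch = sum(
    value * np.prod(upper - lower)  # density value times box volume
    for lower, upper, value in partitions
)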
Analytical expectations of functions can be computed with respect to the distribution approximation.
Because the approximation is piece-wise constant, this simply corresponds to a weighted summation. The 'expectation' method is a convenience for computing the expectation of a provided function, which takes the (unnormalised) density value f, the location x, and the normalisation constant z as arguments and returns the integrand value. For example, the differential entropy:
differential_entropy = approx_joint.expectation(lambda f, x, z: -np.log(f / z))
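Continuing the conceptual sketch from above (again, not the library's internals), the weighted summation looks like this:

def expectation_sketch(g, partitions, z):
    # E[g] = (1 / z) * sum over boxes of g(f_i, centre_i, z) * f_i * volume_i
    total = 0.0
    for lower, upper, value in partitions:
        centre = (lower + upper) / 2
        volume = np.prod(upper - lower)
        total += g(value, centre, z) * value * volume / z
    return total

# The same integrand as above, applied to the toy partitions:
entropy_sketch = expectation_sketch(
    lambda f, x, z: -np.log(f / z), partitions, z_sketch)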
A few very common expectations, such as the mean and variance, are provided as methods; internally these use the 'expectation' method.
mean = approx_joint.mean()
var = approx_joint.var()
The (estimated) mode is provided by the 'mode' method. This is simply the centroid of the partition of the approximation with the highest associated density value.
mode = approx_joint.mode()
The approximation of the density function can be queried just like the original function f:
x_test = np.array([0.5, 0.5, 0.5])
f_test = approx_joint(x_test)
p_test = approx_joint.prob(x_test)
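Presumably (an assumption about the API, not something stated by the library), the two queries differ only by the normalisation constant:

# Assumption, not verified against the library: 'prob' should be the
# z-normalised density, so the two queries are related through z.
assert np.isclose(p_test, f_test / approx_joint.z)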
Often what one is ultimately interested in is posterior samples. To obtain samples from the distribution approximation, we first construct a sampler function using the 'sampler' method. Internally, it performs the pre-processing required to be able to sample very quickly.
sampler = approx_joint.sampler()
The sampler function allows sampling from the distribution in constant time per sample.
Below we draw one million samples.
Returned is a list of numpy arrays, one array of samples per variable. In our example, however, we only have the single variable x.
x_samples, = sampler(num_samples=1000000)
x_samples.shape
(1000000, 3)
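The following is a conceptual sketch of how such fast sampling can be achieved for a piece-wise constant density (not necessarily the library's exact implementation), using the toy boxes from earlier: first pick a box with probability proportional to its probability mass, then sample uniformly within it.

rng = np.random.default_rng(0)

# Precompute a categorical distribution over the boxes. An alias table
# would make each draw O(1); rng.choice is used here for brevity.
masses = np.array([v * np.prod(u - l) for l, u, v in partitions])
probs = masses / masses.sum()

def sample_sketch(num_samples):
    idx = rng.choice(len(partitions), size=num_samples, p=probs)
    lowers = np.stack([partitions[i][0] for i in idx])
    uppers = np.stack([partitions[i][1] for i in idx])
    # Uniform within each chosen box.
    return rng.uniform(lowers, uppers)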
The DensityFunctionApproximation can easily be saved and loaded using the 'save' and 'load' methods. Internally, only the tree structure of partitions is saved/loaded, rather than the whole Python object.
approx_joint.save("/tmp/approx_joint.pickle")
approx_joint.load("/tmp/approx_joint.pickle")
We will now plot the samples in a corner plot in order to visualize the distribution.
def plot(density: DensityFunctionApproximation, samples=None):
if samples is None:
print("Preparing sampler..")
sampler = density.sampler()
print("Sampling..")
samples_per_variable = sampler(num_samples=10 ** 6)
samples = np.concatenate(samples_per_variable, axis=-1)
print("Plotting..")
figure = corner.corner(
samples,
range=density.variables.bounds,
labels=[
"%s: dim %s" % (var.name, index)
for var in density.variables.variable_slices
for index in var.indices
],
plot_contours=False,
no_fill_contours=True,
bins=100,
plot_datapoints=False,
)
plt.show()
plot(approx_joint, samples=x_samples)
Plotting..
We can easily derive a conditional using the 'conditional' method,
yielding a new approximation of the same type.
In this example we condition the last dimension of the x variable on the value 0.5.
approx_conditional: DensityFunctionApproximation = approx_joint.conditional({
x[-1]: np.array([0.5])
})
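Conceptually (a sketch of the idea, not necessarily the library's implementation), conditioning a piece-wise constant approximation amounts to slicing it with a hyperplane: keep every box whose bounds contain the conditioned value, and drop that dimension. On the toy 2D boxes from earlier:

# Condition the last dimension on 0.5 by keeping the boxes containing it
# and dropping that dimension from their bounds.
conditioned_sketch = [
    (lower[:-1], upper[:-1], value)
    for lower, upper, value in partitions
    if lower[-1] <= 0.5 <= upper[-1]
]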
Let us plot samples from the derived conditional in a corner plot to visualize it.
plot(approx_conditional)
Preparing sampler..
Sampling..
Plotting..
The conditional of the approximation, being of the same type, provides the same methods and properties, including its normalisation constant, its mode, and methods for returning density values and computing expectations.
z = approx_conditional.z
approx_conditional.mode()
approx_conditional.expectation(lambda f, x, z: -np.log(f / z))
approx_conditional.mean()
approx_conditional.var()
approx_conditional.prob(np.array([0.5, 0.5]))
7.837536348532179338e-08
We may also marginalise variables (or dimensions of variables) using the 'construct_marginal' helper function. Internally it treats marginalisation as just another problem for the DEFER algorithm: each evaluation of the marginal density is now itself an integral over the marginalised dimensions, estimated with an inner run of the algorithm (budget 'num_inner_fn_calls'), while an outer run (budget 'num_outer_fn_calls') approximates the marginal itself; a conceptual sketch follows after the example below.
Returned is a DensityFunctionApproximation, just as in the previously discussed cases.
Note that in this example the provided fn is the joint approximation itself (approx_joint), rather than the original function f; this works because the approximation is callable just like f.
In the example below, the first dimension of the x variable is marginalised out.
approx_marginal: DensityFunctionApproximation = construct_marginal(
fn=approx_joint,
variables=variables,
marginalize_variable_slices=[x[0]],
is_log_fn=False,
num_outer_fn_calls=5000,
num_inner_fn_calls=15,
callback=lambda i, density:
print("#Evals: %s. Log Z: %.2f" %
(density.num_partitions, np.log(density.z))),
callback_freq_fn_calls=500,
)
#Evals: 5. Log Z: -1.16
#Evals: 509. Log Z: -1.48
#Evals: 1009. Log Z: -1.48
#Evals: 1509. Log Z: -1.46
#Evals: 2011. Log Z: -1.45
#Evals: 2515. Log Z: -1.44
#Evals: 3019. Log Z: -1.44
#Evals: 3523. Log Z: -1.44
#Evals: 4027. Log Z: -1.43
#Evals: 4531. Log Z: -1.43
#Evals: 5007. Log Z: -1.43
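To make the nesting concrete, here is a conceptual sketch of one outer evaluation, substituting a simple trapezoidal quadrature for the inner estimation (the helper itself handles the inner integral differently, as described above):

def marginal_sketch(x_rest, num_inner=15):
    # One outer evaluation of the marginal: integrate the joint
    # approximation over the first (marginalised) dimension.
    grid = np.linspace(curve_lower_x, curve_upper_x, num=num_inner)
    values = [approx_joint(np.concatenate(([t], x_rest))) for t in grid]
    return np.trapz(values, grid)

# e.g. the unnormalised marginal density at (0.5, 0.5):
marginal_sketch(np.array([0.5, 0.5]))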
plot(approx_marginal)
Preparing sampler..
Sampling..
Plotting..
approx_marginal.z
approx_marginal.mode()
approx_marginal.expectation(lambda f, x, z: -np.log(f / z))
approx_marginal.mean()
approx_marginal.var()
approx_marginal.prob(np.array([0.5, 0.5]))
0.060272058916242459574
As the marginal approximation provides the same operations, we can, for example, derive a conditional of the marginal just as before, or apply any of the other operations, in any order.
approx_conditional_of_marginal = approx_marginal.conditional({
x[-1]: np.array([0.5])
})
plot(approx_conditional_of_marginal)
Preparing sampler..
Sampling..
Plotting..
approx_conditional_of_marginal.z
approx_conditional_of_marginal.mode()
approx_conditional_of_marginal.expectation(lambda f, x, z: -np.log(f / z))
approx_conditional_of_marginal.mean()
approx_conditional_of_marginal.var()
approx_conditional_of_marginal.prob(np.array([0.5]))
0.18142085889723261718