Skip to content

Commit

Permalink
[MRG] Add Bures-Wasserstein barycenters example (and debug the solver) (
Browse files Browse the repository at this point in the history
…PythonOT#584)

* add example and debug barycenters

* debug barycenter again
  • Loading branch information
rflamary authored Dec 1, 2023
1 parent 55a851e commit 659cde8
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 6 deletions.
2 changes: 1 addition & 1 deletion RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
+ Add `stop_criterion` feature to (un)regularized (f)gw barycenter solvers (PR #578)
+ Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578)
+ Add new BAPG solvers with KL projections for GW and FGW (PR #581)
+ Add Bures-Wasserstein barycenter in `ot.gaussian` (PR #582)
+ Add Bures-Wasserstein barycenter in `ot.gaussian` and example (PR #582, PR #584)


#### Closed issues
Expand Down
127 changes: 127 additions & 0 deletions examples/barycenters/plot_gaussian_barycenter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-
"""
========================================================
Gaussian Bures-Wasserstein barycenters
========================================================
Illustration of Gaussian Bures-Wasserstein barycenters.
"""

# Authors: Rémi Flamary <[email protected]>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 2
# %%
from matplotlib import colors
from matplotlib.patches import Ellipse
import numpy as np
import matplotlib.pylab as pl
import ot


# %%
# Define Gaussian Covariances and distributions
# ---------------------------------------------

# Covariance matrices of the four Gaussians (one 2x2 matrix each).
C1, C2, C3, C4 = (
    np.array(entries)
    for entries in (
        [[0.5, -0.4], [-0.4, 0.5]],
        [[1, 0.3], [0.3, 1]],
        [[1.5, 0], [0, 0.5]],
        [[0.5, 0], [0, 1.5]],
    )
)

# Stacked covariances, shape (4, 2, 2).
C = np.stack([C1, C2, C3, C4])

# Means of the four Gaussians, placed on the corners of a square.
m1, m2, m3, m4 = (
    np.array(corner)
    for corner in ([0, 0], [0, 4], [4, 0], [4, 4])
)

# Stacked means, shape (4, 2).
m = np.stack([m1, m2, m3, m4])

# %%
# Plot the distributions
# ----------------------


def draw_cov(mu, C, color=None, label=None, nstd=1):
    """Draw a 2D Gaussian as a semi-transparent ellipse on the current axes.

    The ellipse is centered on ``mu``, its axes follow the eigenvectors of
    ``C`` and its full width/height are ``2 * nstd * sqrt(eigenvalue)``.

    Parameters
    ----------
    mu : array-like
        Center of the ellipse; only ``mu[0]`` and ``mu[1]`` are read.
    C : array-like
        Covariance matrix, passed to ``np.linalg.eigh``.
    color : optional
        Used for both the face and the edge of the ellipse.
    label : optional
        Label attached to the ellipse artist.
    nstd : float, optional
        Number of standard deviations covered along each half-axis.
    """
    # Eigen-decomposition, reordered so the largest eigenvalue comes first.
    eigvals, eigvecs = np.linalg.eigh(C)
    descending = eigvals.argsort()[::-1]
    eigvals = eigvals[descending]
    eigvecs = eigvecs[:, descending]

    # Orientation of the leading eigenvector, as arctan2(y, x) in degrees.
    leading = eigvecs[:, 0]
    angle_deg = np.degrees(np.arctan2(*leading[::-1]))

    width, height = 2 * nstd * np.sqrt(eigvals)

    patch = Ellipse(
        xy=(mu[0], mu[1]),
        width=width,
        height=height,
        alpha=0.5,
        angle=angle_deg,
        facecolor=color,
        edgecolor=color,
        label=label,
        fill=True,
    )
    pl.gca().add_artist(patch)


# Common axis limits [xmin, xmax, ymin, ymax] shared by every plot below.
axis = [-1.5, 5.5, -1.5, 5.5]

pl.figure(1, (8, 2))
pl.clf()

# One subplot per Gaussian, each drawn in its own color 'C0'..'C3'.
panels = ((m1, C1), (m2, C2), (m3, C3), (m4, C4))
for k, (mean_k, cov_k) in enumerate(panels, start=1):
    pl.subplot(1, 4, k)
    draw_cov(mean_k, cov_k, color='C%d' % (k - 1))
    pl.axis(axis)
    # Raw string: '\m' and '\S' are not valid escape sequences in a regular
    # string literal and trigger a SyntaxWarning on recent Python versions.
    pl.title(r'$\mathcal{N}(m_%d,\Sigma_%d)$' % (k, k))

# %%
# Compute Bures-Wasserstein barycenters and plot them
# ---------------------------------------------------

# Basis for the bilinear interpolation: row k of the identity puts all the
# weight on the k-th Gaussian.
v1, v2, v3, v4 = np.eye(4)

# RGB value of each Gaussian's color, one row per distribution. Kept under
# its own name so the imported ``colors`` module is not overwritten, which
# would make this cell fail when it is run a second time.
corner_rgb = np.stack([colors.to_rgb('C%d' % k) for k in range(4)])

pl.figure(2, (8, 8))

# Number of interpolation steps along each side of the grid.
nb_interp = 6

for i in range(nb_interp):
    tx = float(i) / (nb_interp - 1)
    for j in range(nb_interp):
        ty = float(j) / (nb_interp - 1)

        # weights are constructed by bilinear interpolation
        tmp1 = (1 - tx) * v1 + tx * v2
        tmp2 = (1 - tx) * v3 + tx * v4
        weights = (1 - ty) * tmp1 + ty * tmp2

        # blend the four corner colors with the same weights
        color = np.dot(corner_rgb.T, weights)

        mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, weights)

        draw_cov(mb, Cb, color=color, label=None, nstd=0.3)

pl.axis(axis)
pl.axis('off')
pl.tight_layout()
10 changes: 5 additions & 5 deletions ot/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,14 +399,14 @@ def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, lo
"""
nx = get_backend(*C, *m,)

if weights is None:
weights = nx.ones(C.shape[0], type_as=C[0]) / C.shape[0]

# Compute the mean barycenter
mb = nx.mean(m)
mb = nx.sum(m * weights[:, None], axis=0)

# Init the covariance barycenter
Cb = nx.mean(C, axis=0)

if weights is None:
weights = nx.ones(len(C), type_as=C[0]) / len(C)
Cb = nx.mean(C * weights[:, None, None], axis=0)

for it in range(num_iter):
# fixed point update
Expand Down

0 comments on commit 659cde8

Please sign in to comment.