Skip to content

Commit

Permalink
[MRG] Free support Sinkhorn barycenters (PythonOT#387)
Browse files Browse the repository at this point in the history
* Adding function for computing Sinkhorn Free Support barycenters

* Adding exampel on Free Support Sinkhorn Barycenter

* Fixing typo on free support sinkhorn barycenter example

* Adding info on new Free Support Barycenter solver

* Removing extra line so that code follows pep8

* Fixing issues with pep8 in example

* Correcting issues with pep8 standards

* Adding tests for free support sinkhorn barycenter

* Adding section on Sinkhorn barycenter to the example

* Changing distributions for the Sinkhorn barycenter example

* Removing file that should not be on the last commit

* Adding PR number to REALEASES.md

* Adding new contributors

* Update CONTRIBUTORS.md

Co-authored-by: Rémi Flamary <[email protected]>
  • Loading branch information
eddardd and rflamary authored Jul 27, 2022
1 parent 7c2a952 commit 818c7ac
Show file tree
Hide file tree
Showing 6 changed files with 324 additions and 3 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ The contributors to this library are:
* [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning)
* [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein Barycenters)
* [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug)
* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)

## Acknowledgments

Expand Down
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#### New features

- Added Generalized Wasserstein Barycenter solver + example (PR #372), fixed graphical details on the example (PR #376)
- Added Free Support Sinkhorn Barycenter + example (PR #387)

#### Closed issues

Expand Down
28 changes: 25 additions & 3 deletions examples/barycenters/plot_free_support_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
2D free support Wasserstein barycenters of distributions
========================================================
Illustration of 2D Wasserstein barycenters if distributions are weighted
Illustration of 2D Wasserstein and Sinkhorn barycenters if distributions are weighted
sum of diracs.
"""

# Authors: Vivien Seguy <[email protected]>
# Rémi Flamary <[email protected]>
# Eduardo Fernandes Montesuma <[email protected]>
#
# License: MIT License

Expand Down Expand Up @@ -48,7 +49,7 @@


# %%
# Compute free support barycenter
# Compute free support Wasserstein barycenter
# -------------------------------

k = 200 # number of Diracs of the barycenter
Expand All @@ -58,7 +59,28 @@
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b)

# %%
# Plot the barycenter
# Plot the Wasserstein barycenter
# ---------

pl.figure(2, (8, 3))
pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5)
pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5)
pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter')
pl.title('Data measures and their barycenter')
pl.legend(loc="lower right")
pl.show()

# %%
# Compute free support Sinkhorn barycenter

k = 200 # number of Diracs of the barycenter
X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations
b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized)

X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, 20, b, numItermax=15)

# %%
# Plot the Wasserstein barycenter
# ---------

pl.figure(2, (8, 3))
Expand Down
151 changes: 151 additions & 0 deletions examples/barycenters/plot_free_support_sinkhorn_barycenter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# -*- coding: utf-8 -*-
"""
========================================================
2D free support Sinkhorn barycenters of distributions
========================================================
Illustration of Sinkhorn barycenter calculation between empirical distributions understood as point clouds
"""

# Authors: Eduardo Fernandes Montesuma <[email protected]>
#
# License: MIT License

import numpy as np
import matplotlib.pyplot as plt
import ot

# %%
# General Parameters
# ------------------
reg = 1e-2 # Entropic Regularization
numItermax = 20 # Maximum number of iterations for the Barycenter algorithm
numInnerItermax = 50 # Maximum number of sinkhorn iterations
n_samples = 200

# %%
# Generate Data
# -------------

X1 = np.random.randn(200, 2)
X2 = 2 * np.concatenate([
np.concatenate([- np.ones([50, 1]), np.linspace(-1, 1, 50)[:, None]], axis=1),
np.concatenate([np.linspace(-1, 1, 50)[:, None], np.ones([50, 1])], axis=1),
np.concatenate([np.ones([50, 1]), np.linspace(1, -1, 50)[:, None]], axis=1),
np.concatenate([np.linspace(1, -1, 50)[:, None], - np.ones([50, 1])], axis=1),
], axis=0)
X3 = np.random.randn(200, 2)
X3 = 2 * (X3 / np.linalg.norm(X3, axis=1)[:, None])
X4 = np.random.multivariate_normal(np.array([0, 0]), np.array([[1., 0.5], [0.5, 1.]]), size=200)

a1, a2, a3, a4 = ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1))

# %%
# Inspect generated distributions
# -------------------------------

fig, axes = plt.subplots(1, 4, figsize=(16, 4))

axes[0].scatter(x=X1[:, 0], y=X1[:, 1], c='steelblue', edgecolor='k')
axes[1].scatter(x=X2[:, 0], y=X2[:, 1], c='steelblue', edgecolor='k')
axes[2].scatter(x=X3[:, 0], y=X3[:, 1], c='steelblue', edgecolor='k')
axes[3].scatter(x=X4[:, 0], y=X4[:, 1], c='steelblue', edgecolor='k')

axes[0].set_xlim([-3, 3])
axes[0].set_ylim([-3, 3])
axes[0].set_title('Distribution 1')

axes[1].set_xlim([-3, 3])
axes[1].set_ylim([-3, 3])
axes[1].set_title('Distribution 2')

axes[2].set_xlim([-3, 3])
axes[2].set_ylim([-3, 3])
axes[2].set_title('Distribution 3')

axes[3].set_xlim([-3, 3])
axes[3].set_ylim([-3, 3])
axes[3].set_title('Distribution 4')

plt.tight_layout()
plt.show()

# %%
# Interpolating Empirical Distributions
# -------------------------------------

fig = plt.figure(figsize=(10, 10))

weights = np.array([
[3 / 3, 0 / 3],
[2 / 3, 1 / 3],
[1 / 3, 2 / 3],
[0 / 3, 3 / 3],
]).astype(np.float32)

for k in range(4):
XB_init = np.random.randn(n_samples, 2)
XB = ot.bregman.free_support_sinkhorn_barycenter(
measures_locations=[X1, X2],
measures_weights=[a1, a2],
weights=weights[k],
X_init=XB_init,
reg=reg,
numItermax=numItermax,
numInnerItermax=numInnerItermax
)
ax = plt.subplot2grid((4, 4), (0, k))
ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
ax.set_xlim([-3, 3])
ax.set_ylim([-3, 3])

for k in range(1, 4, 1):
XB_init = np.random.randn(n_samples, 2)
XB = ot.bregman.free_support_sinkhorn_barycenter(
measures_locations=[X1, X3],
measures_weights=[a1, a2],
weights=weights[k],
X_init=XB_init,
reg=reg,
numItermax=numItermax,
numInnerItermax=numInnerItermax
)
ax = plt.subplot2grid((4, 4), (k, 0))
ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
ax.set_xlim([-3, 3])
ax.set_ylim([-3, 3])

for k in range(1, 4, 1):
XB_init = np.random.randn(n_samples, 2)
XB = ot.bregman.free_support_sinkhorn_barycenter(
measures_locations=[X3, X4],
measures_weights=[a1, a2],
weights=weights[k],
X_init=XB_init,
reg=reg,
numItermax=numItermax,
numInnerItermax=numInnerItermax
)
ax = plt.subplot2grid((4, 4), (3, k))
ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
ax.set_xlim([-3, 3])
ax.set_ylim([-3, 3])

for k in range(1, 3, 1):
XB_init = np.random.randn(n_samples, 2)
XB = ot.bregman.free_support_sinkhorn_barycenter(
measures_locations=[X2, X4],
measures_weights=[a1, a2],
weights=weights[k],
X_init=XB_init,
reg=reg,
numItermax=numItermax,
numInnerItermax=numInnerItermax
)
ax = plt.subplot2grid((4, 4), (k, 3))
ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k')
ax.set_xlim([-3, 3])
ax.set_ylim([-3, 3])

plt.show()
120 changes: 120 additions & 0 deletions ot/bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -1540,6 +1540,126 @@ def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
return geometricBar(weights, UKv)


def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg, b=None, weights=None,
numItermax=100, numInnerItermax=1000, stopThr=1e-7, verbose=False, log=None,
**kwargs):
r"""
Solves the free support (locations of the barycenters are optimized, not the weights) regularized Wasserstein barycenter problem (i.e. the weighted Frechet mean for the 2-Sinkhorn divergence), formally:
.. math::
\min_\mathbf{X} \quad \sum_{i=1}^N w_i W_{reg}^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i)
where :
- :math:`w \in \mathbb{(0, 1)}^{N}`'s are the barycenter weights and sum to one
- `measure_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights (on simplex)
- `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms locations
- :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter
This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
There are two differences with the following codes:
- we do not optimize over the weights
- we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in
:ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
implementation of the fixed-point algorithm of
:ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting.
- at each iteration, instead of solving an exact OT problem, we use the Sinkhorn algorithm for calculating the
transport plan in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
Parameters
----------
measures_locations : list of N (k_i,d) array-like
The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
(:math:`k_i` can be different for each element of the list)
measures_weights : list of N (k_i,) array-like
Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one
representing the weights of each discrete input measure
X_init : (k,d) array-like
Initialization of the support locations (on `k` atoms) of the barycenter
reg : float
Regularization term >0
b : (k,) array-like
Initialization of the weights of the barycenter (non-negatives, sum to 1)
weights : (N,) array-like
Initialization of the coefficients of the barycenter (non-negatives, sum to 1)
numItermax : int, optional
Max number of iterations
numInnerItermax : int, optional
Max number of iterations when calculating the transport plans with Sinkhorn
stopThr : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
Returns
-------
X : (k,d) array-like
Support locations (on k atoms) of the barycenter
See Also
--------
ot.bregman.sinkhorn : Entropic regularized OT solver
ot.lp.free_support_barycenter : Barycenter solver based on Linear Programming
.. _references-free-support-barycenter:
References
----------
.. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters." International Conference on Machine Learning. 2014.
.. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space." Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
"""
nx = get_backend(*measures_locations, *measures_weights, X_init)

iter_count = 0

N = len(measures_locations)
k = X_init.shape[0]
d = X_init.shape[1]
if b is None:
b = nx.ones((k,), type_as=X_init) / k
if weights is None:
weights = nx.ones((N,), type_as=X_init) / N

X = X_init

log_dict = {}
displacement_square_norms = []

displacement_square_norm = stopThr + 1.

while (displacement_square_norm > stopThr and iter_count < numItermax):

T_sum = nx.zeros((k, d), type_as=X_init)

for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
M_i = dist(X, measure_locations_i)
T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg, numItermax=numInnerItermax, **kwargs)
T_sum = T_sum + weight_i * 1. / b[:, None] * nx.dot(T_i, measure_locations_i)

displacement_square_norm = nx.sum((T_sum - X) ** 2)
if log:
displacement_square_norms.append(displacement_square_norm)

X = T_sum

if verbose:
print('iteration %d, displacement_square_norm=%f\n', iter_count, displacement_square_norm)

iter_count += 1

if log:
log_dict['displacement_square_norms'] = displacement_square_norms
return X, log_dict
else:
return X


def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
stopThr=1e-4, verbose=False, log=False, warn=True):
r"""Compute the entropic wasserstein barycenter in log-domain
Expand Down
26 changes: 26 additions & 0 deletions test/test_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Author: Remi Flamary <[email protected]>
# Kilian Fatras <[email protected]>
# Quang Huy Tran <[email protected]>
# Eduardo Fernandes Montesuma <[email protected]>
#
# License: MIT License

Expand Down Expand Up @@ -490,6 +491,31 @@ def test_barycenter(nx, method, verbose, warn):
ot.bregman.barycenter(A_nx, M_nx, reg, log=True)


def test_free_support_sinkhorn_barycenter():
measures_locations = [
np.array([-1.]).reshape((1, 1)), # First dirac support
np.array([1.]).reshape((1, 1)) # Second dirac support
]

measures_weights = [
np.array([1.]), # First dirac sample weights
np.array([1.]) # Second dirac sample weights
]

# Barycenter initialization
X_init = np.array([-12.]).reshape((1, 1))

# Obvious barycenter locations. Take a look on test_ot.py, test_free_support_barycenter
bar_locations = np.array([0.]).reshape((1, 1))

# Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization
# term to 1, but this should be, in general, fine-tuned to the problem.
X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg=1)

# Verifies if calculated barycenter matches ground-truth
np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7)


@pytest.mark.parametrize("method, verbose, warn",
product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"],
[True, False], [True, False]))
Expand Down

0 comments on commit 818c7ac

Please sign in to comment.