forked from PythonOT/POT
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MRG] Free support Sinkhorn barycenters (PythonOT#387)
* Adding function for computing Sinkhorn Free Support barycenters * Adding example on Free Support Sinkhorn Barycenter * Fixing typo on free support sinkhorn barycenter example * Adding info on new Free Support Barycenter solver * Removing extra line so that code follows pep8 * Fixing issues with pep8 in example * Correcting issues with pep8 standards * Adding tests for free support sinkhorn barycenter * Adding section on Sinkhorn barycenter to the example * Changing distributions for the Sinkhorn barycenter example * Removing file that should not be on the last commit * Adding PR number to RELEASES.md * Adding new contributors * Update CONTRIBUTORS.md Co-authored-by: Rémi Flamary <[email protected]>
- Loading branch information
Showing
6 changed files
with
324 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,13 +4,14 @@ | |
2D free support Wasserstein barycenters of distributions | ||
======================================================== | ||
Illustration of 2D Wasserstein barycenters if distributions are weighted | ||
Illustration of 2D Wasserstein and Sinkhorn barycenters when distributions are weighted | ||
sums of Diracs. | ||
""" | ||
|
||
# Authors: Vivien Seguy <[email protected]> | ||
# Rémi Flamary <[email protected]> | ||
# Eduardo Fernandes Montesuma <[email protected]> | ||
# | ||
# License: MIT License | ||
|
||
|
@@ -48,7 +49,7 @@ | |
|
||
|
||
# %% | ||
# Compute free support barycenter | ||
# Compute free support Wasserstein barycenter | ||
# ------------------------------- | ||
|
||
k = 200 # number of Diracs of the barycenter | ||
|
@@ -58,7 +59,28 @@ | |
X = ot.lp.free_support_barycenter(measures_locations, measures_weights, X_init, b) | ||
|
||
# %% | ||
# Plot the barycenter | ||
# Plot the Wasserstein barycenter | ||
# --------- | ||
|
||
pl.figure(2, (8, 3)) | ||
pl.scatter(x1[:, 0], x1[:, 1], alpha=0.5) | ||
pl.scatter(x2[:, 0], x2[:, 1], alpha=0.5) | ||
pl.scatter(X[:, 0], X[:, 1], s=b * 1000, marker='s', label='2-Wasserstein barycenter') | ||
pl.title('Data measures and their barycenter') | ||
pl.legend(loc="lower right") | ||
pl.show() | ||
|
||
# %% | ||
# Compute free support Sinkhorn barycenter | ||
|
||
k = 200 # number of Diracs of the barycenter | ||
X_init = np.random.normal(0., 1., (k, d)) # initial Dirac locations | ||
b = np.ones((k,)) / k # weights of the barycenter (it will not be optimized, only the locations are optimized) | ||
|
||
X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, 20, b, numItermax=15) | ||
|
||
# %% | ||
# Plot the Sinkhorn barycenter | ||
# --------- | ||
|
||
pl.figure(2, (8, 3)) | ||
|
151 changes: 151 additions & 0 deletions
151
examples/barycenters/plot_free_support_sinkhorn_barycenter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
======================================================== | ||
2D free support Sinkhorn barycenters of distributions | ||
======================================================== | ||
Illustration of Sinkhorn barycenter calculation between empirical distributions understood as point clouds | ||
""" | ||
|
||
# Authors: Eduardo Fernandes Montesuma <[email protected]> | ||
# | ||
# License: MIT License | ||
|
||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import ot | ||
|
||
# %% | ||
# General Parameters | ||
# ------------------ | ||
reg = 1e-2 # Entropic Regularization | ||
numItermax = 20 # Maximum number of iterations for the Barycenter algorithm | ||
numInnerItermax = 50 # Maximum number of sinkhorn iterations | ||
n_samples = 200 | ||
|
||
# %% | ||
# Generate Data | ||
# ------------- | ||
|
||
X1 = np.random.randn(200, 2) | ||
X2 = 2 * np.concatenate([ | ||
np.concatenate([- np.ones([50, 1]), np.linspace(-1, 1, 50)[:, None]], axis=1), | ||
np.concatenate([np.linspace(-1, 1, 50)[:, None], np.ones([50, 1])], axis=1), | ||
np.concatenate([np.ones([50, 1]), np.linspace(1, -1, 50)[:, None]], axis=1), | ||
np.concatenate([np.linspace(1, -1, 50)[:, None], - np.ones([50, 1])], axis=1), | ||
], axis=0) | ||
X3 = np.random.randn(200, 2) | ||
X3 = 2 * (X3 / np.linalg.norm(X3, axis=1)[:, None]) | ||
X4 = np.random.multivariate_normal(np.array([0, 0]), np.array([[1., 0.5], [0.5, 1.]]), size=200) | ||
|
||
a1, a2, a3, a4 = ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)), ot.unif(len(X1)) | ||
|
||
# %% | ||
# Inspect generated distributions | ||
# ------------------------------- | ||
|
||
fig, axes = plt.subplots(1, 4, figsize=(16, 4)) | ||
|
||
axes[0].scatter(x=X1[:, 0], y=X1[:, 1], c='steelblue', edgecolor='k') | ||
axes[1].scatter(x=X2[:, 0], y=X2[:, 1], c='steelblue', edgecolor='k') | ||
axes[2].scatter(x=X3[:, 0], y=X3[:, 1], c='steelblue', edgecolor='k') | ||
axes[3].scatter(x=X4[:, 0], y=X4[:, 1], c='steelblue', edgecolor='k') | ||
|
||
axes[0].set_xlim([-3, 3]) | ||
axes[0].set_ylim([-3, 3]) | ||
axes[0].set_title('Distribution 1') | ||
|
||
axes[1].set_xlim([-3, 3]) | ||
axes[1].set_ylim([-3, 3]) | ||
axes[1].set_title('Distribution 2') | ||
|
||
axes[2].set_xlim([-3, 3]) | ||
axes[2].set_ylim([-3, 3]) | ||
axes[2].set_title('Distribution 3') | ||
|
||
axes[3].set_xlim([-3, 3]) | ||
axes[3].set_ylim([-3, 3]) | ||
axes[3].set_title('Distribution 4') | ||
|
||
plt.tight_layout() | ||
plt.show() | ||
|
||
# %% | ||
# Interpolating Empirical Distributions | ||
# ------------------------------------- | ||
|
||
fig = plt.figure(figsize=(10, 10)) | ||
|
||
weights = np.array([ | ||
[3 / 3, 0 / 3], | ||
[2 / 3, 1 / 3], | ||
[1 / 3, 2 / 3], | ||
[0 / 3, 3 / 3], | ||
]).astype(np.float32) | ||
|
||
for k in range(4): | ||
XB_init = np.random.randn(n_samples, 2) | ||
XB = ot.bregman.free_support_sinkhorn_barycenter( | ||
measures_locations=[X1, X2], | ||
measures_weights=[a1, a2], | ||
weights=weights[k], | ||
X_init=XB_init, | ||
reg=reg, | ||
numItermax=numItermax, | ||
numInnerItermax=numInnerItermax | ||
) | ||
ax = plt.subplot2grid((4, 4), (0, k)) | ||
ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') | ||
ax.set_xlim([-3, 3]) | ||
ax.set_ylim([-3, 3]) | ||
|
||
for k in range(1, 4, 1): | ||
XB_init = np.random.randn(n_samples, 2) | ||
XB = ot.bregman.free_support_sinkhorn_barycenter( | ||
measures_locations=[X1, X3], | ||
measures_weights=[a1, a2], | ||
weights=weights[k], | ||
X_init=XB_init, | ||
reg=reg, | ||
numItermax=numItermax, | ||
numInnerItermax=numInnerItermax | ||
) | ||
ax = plt.subplot2grid((4, 4), (k, 0)) | ||
ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') | ||
ax.set_xlim([-3, 3]) | ||
ax.set_ylim([-3, 3]) | ||
|
||
for k in range(1, 4, 1): | ||
XB_init = np.random.randn(n_samples, 2) | ||
XB = ot.bregman.free_support_sinkhorn_barycenter( | ||
measures_locations=[X3, X4], | ||
measures_weights=[a1, a2], | ||
weights=weights[k], | ||
X_init=XB_init, | ||
reg=reg, | ||
numItermax=numItermax, | ||
numInnerItermax=numInnerItermax | ||
) | ||
ax = plt.subplot2grid((4, 4), (3, k)) | ||
ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') | ||
ax.set_xlim([-3, 3]) | ||
ax.set_ylim([-3, 3]) | ||
|
||
for k in range(1, 3, 1): | ||
XB_init = np.random.randn(n_samples, 2) | ||
XB = ot.bregman.free_support_sinkhorn_barycenter( | ||
measures_locations=[X2, X4], | ||
measures_weights=[a1, a2], | ||
weights=weights[k], | ||
X_init=XB_init, | ||
reg=reg, | ||
numItermax=numItermax, | ||
numInnerItermax=numInnerItermax | ||
) | ||
ax = plt.subplot2grid((4, 4), (k, 3)) | ||
ax.scatter(XB[:, 0], XB[:, 1], color='steelblue', edgecolor='k') | ||
ax.set_xlim([-3, 3]) | ||
ax.set_ylim([-3, 3]) | ||
|
||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
# Author: Remi Flamary <[email protected]> | ||
# Kilian Fatras <[email protected]> | ||
# Quang Huy Tran <[email protected]> | ||
# Eduardo Fernandes Montesuma <[email protected]> | ||
# | ||
# License: MIT License | ||
|
||
|
@@ -490,6 +491,31 @@ def test_barycenter(nx, method, verbose, warn): | |
ot.bregman.barycenter(A_nx, M_nx, reg, log=True) | ||
|
||
|
||
def test_free_support_sinkhorn_barycenter(): | ||
measures_locations = [ | ||
np.array([-1.]).reshape((1, 1)), # First dirac support | ||
np.array([1.]).reshape((1, 1)) # Second dirac support | ||
] | ||
|
||
measures_weights = [ | ||
np.array([1.]), # First dirac sample weights | ||
np.array([1.]) # Second dirac sample weights | ||
] | ||
|
||
# Barycenter initialization | ||
X_init = np.array([-12.]).reshape((1, 1)) | ||
|
||
# Obvious barycenter location. See test_free_support_barycenter in test_ot.py for the analogous exact-OT test. | ||
bar_locations = np.array([0.]).reshape((1, 1)) | ||
|
||
# Calculate free support barycenter w/ Sinkhorn algorithm. We set the entropic regularization | ||
# term to 1, but this should be, in general, fine-tuned to the problem. | ||
X = ot.bregman.free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg=1) | ||
|
||
# Verifies if calculated barycenter matches ground-truth | ||
np.testing.assert_allclose(X, bar_locations, rtol=1e-5, atol=1e-7) | ||
|
||
|
||
@pytest.mark.parametrize("method, verbose, warn", | ||
product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], | ||
[True, False], [True, False])) | ||
|