Skip to content

Commit

Permalink
Merge pull request PythonOT#47 from rflamary/bary
Browse files Browse the repository at this point in the history
LP Wasserstein barycenter with scipy linear solver and/or cvxopt
  • Loading branch information
rflamary authored May 29, 2018
2 parents ec79b79 + 54f0b47 commit 90efa5a
Show file tree
Hide file tree
Showing 9 changed files with 488 additions and 6 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ notebook :
ipython notebook --matplotlib=inline --notebook-dir=notebooks/

autopep8 :
autopep8 -ir test ot examples
autopep8 -ir test ot examples --jobs -1

aautopep8 :
autopep8 -air test ot examples
autopep8 -air test ot examples --jobs -1

FORCE :
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ This open source Python library provide several solvers for optimization problem
It provides the following solvers:

* OT Network Flow solver for the linear program/ Earth Movers Distance [1].
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (required cudamat).
* Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (requires cudamat).
* Non regularized Wasserstein barycenters [16] with LP solver.
* Bregman projections for Wasserstein barycenter [3] and unmixing [4].
* Optimal transport for domain adaptation with group lasso regularization [5]
* Conditional gradient [6] and Generalized conditional gradient for regularized OT [7].
Expand Down Expand Up @@ -210,3 +211,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[14] Knott, M. and Smith, C. S. (1984).[On the optimal mapping of distributions](https://link.springer.com/article/10.1007/BF00934745), Journal of Optimization Theory and Applications Vol 43.

[15] Peyré, G., & Cuturi, M. (2018). [Computational Optimal Transport](https://arxiv.org/pdf/1803.00567.pdf) .

[16] Agueh, M., & Carlier, G. (2011). [Barycenters in the Wasserstein space](https://hal.archives-ouvertes.fr/hal-00637399/document). SIAM Journal on Mathematical Analysis, 43(2), 904-924.
292 changes: 292 additions & 0 deletions examples/plot_barycenter_lp_vs_entropic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
# -*- coding: utf-8 -*-
"""
=================================================================================
1D Wasserstein barycenter comparison between exact LP and entropic regularization
=================================================================================
This example illustrates the computation of regularized Wasserstein barycenters
as proposed in [3] and of exact LP barycenters using a standard LP solver.
It approximately reproduces Figures 3.1 and 3.2 of the following paper:
Cuturi, M., & Peyré, G. (2016). A smoothed dual approach for variational
Wasserstein problems. SIAM Journal on Imaging Sciences, 9(1), 320-343.
[3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
Iterative Bregman projections for regularized transportation problems
SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
"""

# Author: Remi Flamary <[email protected]>
#
# License: MIT License

import numpy as np
import matplotlib.pylab as pl
import ot
# necessary for 3d plot even if not used
from mpl_toolkits.mplot3d import Axes3D # noqa
from matplotlib.collections import PolyCollection # noqa

#import ot.lp.cvx as cvx

#
# Generate data
# -------------

#%% parameters

# Each solved problem is appended here as [A, [bary_l2, bary_wass, bary_wass2]]
# so that the final figure can show all of them side by side.
problems = []

n = 100  # nb bins

# bin positions
x = np.arange(n, dtype=np.float64)

# Gaussian distributions
a1 = ot.datasets.get_1D_gauss(n, m=20, s=5)  # m= mean, s= std
a2 = ot.datasets.get_1D_gauss(n, m=60, s=8)

# creating matrix A containing all distributions (one distribution per column)
A = np.vstack((a1, a2)).T
n_distributions = A.shape[1]

# loss matrix + normalization (rescaled in place so that the largest cost is 1)
M = ot.utils.dist0(n)
M /= M.max()

#
# Plot data
# ---------

#%% plot the distributions

# Figure 1: every input histogram (one column of A) drawn over the bin positions.
pl.figure(1, figsize=(6.4, 3))
for hist in A.T:
    pl.plot(x, hist)
pl.title('Distributions')
pl.tight_layout()

#
# Barycenter computation
# ----------------------

#%% barycenter computation

# Barycentric coordinates: weight (1 - alpha) on the first column of A and
# alpha on the second one.
alpha = 0.5  # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])

# l2bary: plain weighted average of the columns of A
bary_l2 = A.dot(weights)

# wasserstein: entropic-regularized barycenter, with reg as regularization term
reg = 1e-3
ot.tic()  # tic/toc bracket the solver call to time it
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
ot.toc()


# Non-regularized barycenter from the LP solver, timed the same way.
# NOTE(review): 'interior-point' is presumably forwarded to the underlying
# linear-programming routine -- confirm against ot.lp.barycenter.
ot.tic()
bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
ot.toc()

# Figure 2: input histograms on the top panel, the three barycenters below.
pl.figure(2)
pl.clf()

pl.subplot(2, 1, 1)
for hist in A.T:
    pl.plot(x, hist)
pl.title('Distributions')

pl.subplot(2, 1, 2)
curves = (
    (bary_l2, 'r', 'l2'),
    (bary_wass, 'g', 'Reg Wasserstein'),
    (bary_wass2, 'b', 'LP Wasserstein'),
)
for bary, fmt, name in curves:
    pl.plot(x, bary, fmt, label=name)
pl.legend()
pl.title('Barycenters')
pl.tight_layout()

# Keep this problem's inputs and results for the final summary figure.
problems.append([A, [bary_l2, bary_wass, bary_wass2]])

#%% parameters

# Two box-shaped histograms on disjoint ranges of bins, each normalized so
# that it sums to one.
a1 = np.logical_and(x > 10, x < 50).astype(np.float64)
a2 = np.logical_and(x > 60, x < 80).astype(np.float64)
a1 /= a1.sum()
a2 /= a2.sum()

# Stack the histograms as the columns of A.
A = np.vstack((a1, a2)).T
n_distributions = A.shape[1]

# Ground cost between bins, rescaled in place so that its maximum is 1.
M = ot.utils.dist0(n)
M /= M.max()


#%% plot the distributions

# Figure 1 is reused for every problem. Clear it first, otherwise the curves
# drawn for a previous problem stay on the axes underneath the new ones.
pl.figure(1, figsize=(6.4, 3))
pl.clf()
for i in range(n_distributions):
    pl.plot(x, A[:, i])
pl.title('Distributions')
pl.tight_layout()

#
# Barycenter computation
# ----------------------

#%% barycenter computation

# Barycentric coordinates: weight (1 - alpha) on the first column of A and
# alpha on the second one.
alpha = 0.5  # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])

# l2bary: plain weighted average of the columns of A
bary_l2 = A.dot(weights)

# wasserstein: entropic-regularized barycenter, with reg as regularization term
reg = 1e-3
ot.tic()  # tic/toc bracket the solver call to time it
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
ot.toc()


# Non-regularized barycenter from the LP solver, timed the same way.
# NOTE(review): 'interior-point' is presumably forwarded to the underlying
# linear-programming routine -- confirm against ot.lp.barycenter.
ot.tic()
bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
ot.toc()


# Keep this problem's inputs and results for the final summary figure.
problems.append([A, [bary_l2, bary_wass, bary_wass2]])

# Figure 2: input histograms on the top panel, the three barycenters below.
pl.figure(2)
pl.clf()

pl.subplot(2, 1, 1)
for hist in A.T:
    pl.plot(x, hist)
pl.title('Distributions')

pl.subplot(2, 1, 2)
curves = (
    (bary_l2, 'r', 'l2'),
    (bary_wass, 'g', 'Reg Wasserstein'),
    (bary_wass2, 'b', 'LP Wasserstein'),
)
for bary, fmt, name in curves:
    pl.plot(x, bary, fmt, label=name)
pl.legend()
pl.title('Barycenters')
pl.tight_layout()

#%% parameters

# First histogram: three spikes at bins 10, 20 and 30.
a1 = np.zeros(n)
a1[[10, 20, 30]] = [.25, .5, .25]

# Second histogram: a single spike at bin 80.
a2 = np.zeros(n)
a2[80] = 1

# Normalize both histograms so that each sums to one.
a1 /= a1.sum()
a2 /= a2.sum()

# Stack the histograms as the columns of A.
A = np.vstack((a1, a2)).T
n_distributions = A.shape[1]

# Ground cost between bins, rescaled in place so that its maximum is 1.
M = ot.utils.dist0(n)
M /= M.max()


#%% plot the distributions

# Figure 1 is reused for every problem. Clear it first, otherwise the curves
# drawn for a previous problem stay on the axes underneath the new ones.
pl.figure(1, figsize=(6.4, 3))
pl.clf()
for i in range(n_distributions):
    pl.plot(x, A[:, i])
pl.title('Distributions')
pl.tight_layout()

#
# Barycenter computation
# ----------------------

#%% barycenter computation

# Barycentric coordinates: weight (1 - alpha) on the first column of A and
# alpha on the second one.
alpha = 0.5  # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])

# l2bary: plain weighted average of the columns of A
bary_l2 = A.dot(weights)

# wasserstein: entropic-regularized barycenter, with reg as regularization term
reg = 1e-3
ot.tic()  # tic/toc bracket the solver call to time it
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
ot.toc()


# Non-regularized barycenter from the LP solver, timed the same way.
# NOTE(review): 'interior-point' is presumably forwarded to the underlying
# linear-programming routine -- confirm against ot.lp.barycenter.
ot.tic()
bary_wass2 = ot.lp.barycenter(A, M, weights, solver='interior-point', verbose=True)
ot.toc()


# Keep this problem's inputs and results for the final summary figure.
problems.append([A, [bary_l2, bary_wass, bary_wass2]])

# Figure 2: input histograms on the top panel, the three barycenters below.
pl.figure(2)
pl.clf()

pl.subplot(2, 1, 1)
for hist in A.T:
    pl.plot(x, hist)
pl.title('Distributions')

pl.subplot(2, 1, 2)
curves = (
    (bary_l2, 'r', 'l2'),
    (bary_wass, 'g', 'Reg Wasserstein'),
    (bary_wass2, 'b', 'LP Wasserstein'),
)
for bary, fmt, name in curves:
    pl.plot(x, bary, fmt, label=name)
pl.legend()
pl.title('Barycenters')
pl.tight_layout()


#
# Final figure
# ------------
#

#%% plot

# Summary figure: one column per stored problem, with the input distributions
# on the top row and the corresponding barycenters on the bottom row.
nbm = len(problems)
nbm2 = nbm // 2  # index of the middle column, which carries the row titles

# Figure 2 may already exist at this point, and pl.figure() does not apply a
# size to an existing figure, so the size is set on the figure object itself.
fig = pl.figure(2)
fig.set_size_inches(20, 6)
pl.clf()

for i, (A, (bary_l2, bary_wass, bary_wass2)) in enumerate(problems):

    # Top row: the input distributions of problem i. The number of columns is
    # read from the stored matrix rather than from whichever problem happened
    # to be set up last.
    pl.subplot(2, nbm, 1 + i)
    for j in range(A.shape[1]):
        pl.plot(x, A[:, j])
    if i == nbm2:
        pl.title('Distributions')
    pl.xticks(())
    pl.yticks(())

    # Bottom row: the three barycenters of problem i.
    pl.subplot(2, nbm, 1 + i + nbm)

    pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)')
    pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
    pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
    if i == nbm - 1:
        # a single legend, on the last column only
        pl.legend()
    if i == nbm2:
        pl.title('Barycenters')

    pl.xticks(())
    pl.yticks(())
4 changes: 3 additions & 1 deletion ot/bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,11 +839,13 @@ def barycenter(A, M, reg, weights=None, numItermax=1000,
Parameters
----------
A : np.ndarray (d,n)
n training distributions of size d
n training distributions a_i of size d
M : np.ndarray (d,d)
loss matrix for OT
reg : float
Regularization term >0
weights : np.ndarray (n,)
Weights of each histogram a_i on the simplex (barycentric coordinates)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
Expand Down
3 changes: 3 additions & 0 deletions ot/lp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@

import numpy as np

from .import cvx

# import compiled emd
from .emd_wrap import emd_c, check_result
from ..utils import parmap
from .cvx import barycenter


def emd(a, b, M, numItermax=100000, log=False):
Expand Down
Loading

0 comments on commit 90efa5a

Please sign in to comment.