[MRG] Efficient Discrete Multi Marginal Optimal Transport (PythonOT#454)
* add demd.py to ot, add plot_demd_*.py to examples, update __init__.py in ot; build failed, needs fixing

* update README.md with citation to iclr23 paper and example link

* change directory of examples, build successful

* fix small latex bug

* update all.rst, examples and demd have passed pep8 and pyflake

* add more detailed comments for examples

* TODO: test module for demd, wrong demd index after build

* add test module

* add contributors

* pass pyflake checks, pass pep8

* added the PR to the RELEASES.md file

* temporal changes with logs

* init changes

* merge examples, demd -> lp.dmmot

* bug fix in plot_dmmot, some commenting/documenting edits

* dmmot example cleanup, some comments/plotting edits

* add dist_monge method

* all dmmot methods take an A of shape (n, d) as input (follows POT style)

* passed pep8 and pyflake checks

* resolve test fail issue

* fix pep8 error

* resolve issues from last review, pyflake and pep8 checked

* add lr decay

* add more examples, ground cost options, test for uniqueness

* remove additional experiment setting, not needed in this PR

* fixed line 14 1 blank line

* fix gradient computation link

* Update ot/lp/dmmot.py

Store input variable instead of copying it
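
As an illustration of the `(n, d)` input convention mentioned in the list above, here is a minimal NumPy sketch. It assumes `n` is the number of bins and `d` the number of histograms, which matches how the example script builds `A` with `np.vstack(...).T`. The sizes and variable names are illustrative and not taken from the PR.

```python
import numpy as np

n, d = 5, 3  # illustrative sizes: n bins per histogram, d histograms

# Build d normalized histograms, each of length n.
rng = np.random.default_rng(0)
hists = [rng.random(n) for _ in range(d)]
hists = [h / h.sum() for h in hists]

# Stack the histograms as rows, then transpose so each COLUMN is one histogram.
A = np.vstack(hists).T

print(A.shape)        # shape is (n, d)
print(A.sum(axis=0))  # each column sums to 1 up to rounding
```

With this layout, `A[:, j]` is the j-th distribution, so column-wise operations such as `A.dot(weights)` average over distributions rather than over bins.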

---------

Co-authored-by: Rémi Flamary <[email protected]>
Co-authored-by: Ronak <[email protected]>
3 people authored Aug 3, 2023
1 parent a879690 commit 5ead79b
Showing 7 changed files with 590 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CONTRIBUTORS.md
@@ -42,6 +42,8 @@ The contributors to this library are:
* [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter)
* [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions)
* [Clément Bonet](https://clbonet.github.io) (Wasserstein on circle, Spherical Sliced-Wasserstein)
* [Ronak Mehta](https://ronakrm.github.io) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
* [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
* [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers)

## Acknowledgments
5 changes: 5 additions & 0 deletions README.md
@@ -43,6 +43,7 @@ POT provides the following generic OT solvers (links to examples):
* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38].
* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) (exact and regularized [48]).
* [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [55].
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.

POT provides the following Machine Learning related solvers:
@@ -319,3 +320,7 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
[53] C. Vincent-Cuaz, R. Flamary, M. Corneli, T. Vayer, N. Courty (2022). [Template based graph neural network with optimal transport distances](https://papers.nips.cc/paper_files/paper/2022/file/4d3525bc60ba1adc72336c0392d3d902-Paper-Conference.pdf). Advances in Neural Information Processing Systems, 35.

[54] Bécigneul, G., Ganea, O. E., Chen, B., Barzilay, R., & Jaakkola, T. S. (2020). [Optimal transport graph neural networks](https://arxiv.org/pdf/2006.04804).

[55] Ronak Mehta, Jeffery Kline, Vishnu Suresh Lokhande, Glenn Fung, & Vikas Singh (2023). [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://openreview.net/forum?id=R98ZfMt-jE). In The Eleventh International Conference on Learning Representations (ICLR).

[56] Jeffery Kline. [Properties of the d-dimensional earth mover’s problem](https://www.sciencedirect.com/science/article/pii/S0166218X19301441). Discrete Applied Mathematics, 265: 128–141, 2019.
2 changes: 2 additions & 0 deletions RELEASES.md
@@ -18,6 +18,8 @@
- Make marginal parameters optional for (F)GW solvers in `._gw`, `._bregman` and `._semirelaxed` (PR #455)
- Add Entropic Wasserstein Component Analysis (ECWA) in ot.dr (PR #486)

- Added feature Efficient Discrete Multi Marginal Optimal Transport Regularization + examples (PR #454)

#### Closed issues

- Fix change in scipy API for `cdist` (PR #487)
158 changes: 158 additions & 0 deletions examples/others/plot_dmmot.py
@@ -0,0 +1,158 @@
# -*- coding: utf-8 -*-
r"""
===============================================================================
Computing d-dimensional Barycenters via d-MMOT
===============================================================================
When the cost is discretized (Monge), the d-MMOT solver can more quickly
compute and minimize the distance between many distributions without the need
for intermediate barycenter computations. This example compares the time to
identify, and the quality of, solutions for the d-MMOT problem using a
primal/dual algorithm and classical LP barycenter approaches.
"""

# Author: Ronak Mehta <[email protected]>
#         Xizheng Yu <[email protected]>
#
# License: MIT License

# %%
# Generating 2 distributions
# -----
import numpy as np
import matplotlib.pyplot as pl
import ot

np.random.seed(0)

n = 100
d = 2
# Gaussian distributions
a1 = ot.datasets.make_1D_gauss(n, m=20, s=5) # m=mean, s=std
a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)
A = np.vstack((a1, a2)).T
x = np.arange(n, dtype=np.float64)
M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski')

pl.figure(1, figsize=(6.4, 3))
pl.plot(x, a1, 'b', label='Source distribution')
pl.plot(x, a2, 'r', label='Target distribution')
pl.legend()

# %%
# Minimize the distances among distributions, identify the Barycenter
# -----
# The two methods minimize different objectives, so their objective values
# cannot be compared directly.

# L2 barycenter: the weighted arithmetic mean of the histograms
weights = np.ones(d) / d
l2_bary = A.dot(weights)

print('LP Iterations:')
ot.tic()  # start the timer that ot.toc() reads below
lp_bary, lp_log = ot.lp.barycenter(
    A, M, weights, solver='interior-point', verbose=False, log=True)
print('Time\t: ', ot.toc(''))
print('Obj\t: ', lp_log['fun'])

print('')
print('Discrete MMOT Algorithm:')
ot.tic()
barys, log = ot.lp.dmmot_monge_1dgrid_optimize(
    A, niters=4000, lr_init=1e-5, lr_decay=0.997, log=True)
dmmot_obj = log['primal objective']
print('Time\t: ', ot.toc(''))
print('Obj\t: ', dmmot_obj)

# %%
# Compare Barycenters in both methods
# -----
pl.figure(1, figsize=(6.4, 3))
# Plot only the first optimized distribution as the d-MMOT estimate.
pl.plot(x, barys[0], 'g-*', label='Discrete MMOT')
pl.plot(x, lp_bary, label='LP Barycenter')
pl.plot(x, l2_bary, label='L2 Barycenter')
pl.plot(x, a1, 'b', label='Source distribution')
pl.plot(x, a2, 'r', label='Target distribution')
pl.title('Monge Cost: Barycenters from LP Solver and dmmot solver')
pl.legend()


# %%
# More than 2 distributions
# --------------------------------------------------
# Generate 7 pseudorandom Gaussian distributions with 50 bins.
n = 50  # number of bins
d = 7
vecsize = n * d

data = []
for i in range(d):
    m = n * (0.5 * np.random.rand(1)) * float(np.random.randint(2) + 1)
    a = ot.datasets.make_1D_gauss(n, m=m, s=5)
    data.append(a)

x = np.arange(n, dtype=np.float64)
M = ot.utils.dist(x.reshape((n, 1)), metric='minkowski')
A = np.vstack(data).T

pl.figure(1, figsize=(6.4, 3))
for i in range(len(data)):
    pl.plot(x, data[i])

pl.title('Distributions')
pl.legend()

# %%
# Minimizing Distances Among Many Distributions
# ---------------
# The two methods minimize different objectives, so their objective values
# cannot be compared directly.

# Perform gradient descent optimization using the d-MMOT method.
barys = ot.lp.dmmot_monge_1dgrid_optimize(
A, niters=3000, lr_init=1e-4, lr_decay=0.997)

# After minimization, any of the distributions can be used as an estimate of
# the barycenter.
bary = barys[0]

# Compute 1D Wasserstein barycenter using the L2/LP method
weights = ot.unif(d)
l2_bary = A.dot(weights)
lp_bary, bary_log = ot.lp.barycenter(A, M, weights, solver='interior-point',
                                     verbose=False, log=True)

# %%
# Compare Barycenters in both methods
# ---------
pl.figure(1, figsize=(6.4, 3))
pl.plot(x, bary, 'g-*', label='Discrete MMOT')
pl.plot(x, l2_bary, 'k', label='L2 Barycenter')
pl.plot(x, lp_bary, 'k-', label='LP Wasserstein')
pl.title('Barycenters')
pl.legend()

# %%
# Compare with original distributions
# ---------
pl.figure(1, figsize=(6.4, 3))
for i in range(len(data)):
    pl.plot(x, data[i])
# Plot only the first optimized distribution as the d-MMOT estimate.
pl.plot(x, barys[0], 'g-*', label='Discrete MMOT')
pl.plot(x, l2_bary, 'k^', label='L2')
pl.plot(x, lp_bary, 'o', color='grey', label='LP')
pl.title('Barycenters')
pl.legend()
pl.show()

# %%
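
The docstring of this example refers to a primal/dual algorithm for the Monge cost. Below is a minimal sketch of a greedy primal pass. It assumes the cost of an index tuple is its spread, `max(idx) - min(idx)`, and that the greedy rule follows the approach studied in [56]. The names `monge_cost` and `greedy_dmmot_cost` are hypothetical, and this is not the code added in `ot/lp/dmmot.py`, which is not shown in this view.

```python
import numpy as np


def monge_cost(idx):
    """Assumed Monge-type cost of an index tuple: largest minus smallest bin."""
    return max(idx) - min(idx)


def greedy_dmmot_cost(A):
    """Greedy transport cost for d histograms stored as the columns of A (n, d)."""
    A = np.asarray(A, dtype=float)
    n, d = A.shape
    remaining = [A[:, j].copy() for j in range(d)]  # work on copies of the columns
    idx = [0] * d  # current bin in each histogram
    cost = 0.0
    while all(i < n for i in idx):
        vals = [r[i] for r, i in zip(remaining, idx)]
        k = int(np.argmin(vals))  # histogram whose current bin empties first
        mass = vals[k]
        cost += monge_cost(idx) * mass  # ship `mass` along the tuple idx
        for r, i in zip(remaining, idx):
            r[i] -= mass
        idx[k] += 1  # move past the exhausted bin
    return cost


a = np.array([0.5, 0.5, 0.0])
b = np.array([0.0, 0.5, 0.5])
# For d = 2 the spread cost is |i - j|, so the result should agree with the
# 1D Wasserstein-1 distance: the summed absolute difference of the two CDFs.
w1 = np.abs(np.cumsum(a) - np.cumsum(b)).sum()
print(greedy_dmmot_cost(np.vstack((a, b)).T), w1)
```

Each pass of the loop exhausts at least one bin, so the loop runs at most `n * d` times and never builds the `n**d` cost tensor.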
4 changes: 3 additions & 1 deletion ot/lp/__init__.py
@@ -17,6 +17,7 @@

from . import cvx
from .cvx import barycenter
from .dmmot import *

# import compiled emd
from .emd_wrap import emd_c, check_result, emd_1d_sorted
@@ -30,7 +31,8 @@

__all__ = ['emd', 'emd2', 'barycenter', 'free_support_barycenter', 'cvx', 'emd_1d_sorted',
'emd_1d', 'emd2_1d', 'wasserstein_1d', 'generalized_free_support_barycenter',
'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle']
'binary_search_circle', 'wasserstein_circle', 'semidiscrete_wasserstein2_unif_circle',
'discrete_mmot', 'discrete_mmot_converge']


def check_number_threads(numThreads):
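
Because `ot/lp/__init__.py` combines `from .dmmot import *` with an explicit `__all__`, it may help to recall how Python resolves star-imports: every string in `__all__` has to match a defined attribute exactly. The sketch below uses a throwaway module; the name `fake_lp` is hypothetical and unrelated to POT.

```python
import sys
import types

# Build a throwaway module instead of touching any real package.
mod = types.ModuleType("fake_lp")
mod.emd_1d_sorted = lambda: "ok"
mod.__all__ = [" emd_1d_sorted"]  # entry does not match: stray leading space
sys.modules["fake_lp"] = mod

try:
    exec("from fake_lp import *", {})
    star_import_ok = True
except AttributeError:
    # The import machinery looks up each __all__ entry as an attribute.
    star_import_ok = False

mod.__all__ = ["emd_1d_sorted"]  # exact name: the star-import now succeeds
ns = {}
exec("from fake_lp import *", ns)
print(star_import_ok, "emd_1d_sorted" in ns)
```

The same rule applies to the submodule: `from .dmmot import *` only re-exports the names listed in that module's own `__all__`, or, if none is defined, its public names.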