[MRG] CO-Optimal Transport solver (PythonOT#447)
* Allow warmstart in sinkhorn and sinkhorn_log

* Added argument for warmstart of dual vectors in Sinkhorn-based methods in `ot.bregman`

* Add the number of the PR

* [WIP] CO-Optimal Transport

* Revert "[WIP] CO-Optimal Transport"

This reverts commit f3d36b2.

* reformat with PEP8

* Fix W291 trailing whitespace error in pep8 test

* Rearrange position of warmstart argument and edit its description

* Implementation of CO-Optimal Transport

* Optimize code and edit documentation

* fix backend bug in test cases

* fix backend bug

* fix backend bug

* Add examples on COOT

* Modify API and edit example

* Edit API

* minor edit of examples and release

* fix bug in coot

* fix doc examples

* more fix of doc

* restart CI

* reordering ref

* add more tests

* add more tests

* add test verbose

* fix PEP8 bug

* fix PEP8 bug

* fix PEP8 bug

* fix pytest bug

* edit doc for better display

---------

Co-authored-by: Rémi Flamary <[email protected]>
Co-authored-by: Alexandre Gramfort <[email protected]>
3 people authored Mar 22, 2023
1 parent b9ed7b1 commit 897026e
Showing 7 changed files with 1,052 additions and 7 deletions.
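
Before the file-by-file diff, a minimal sketch of the new solver's API, based only on how the examples in this PR call it (sample and feature weights are optional arguments; the examples omit them, presumably defaulting to uniform):

import numpy as np
from ot.coot import co_optimal_transport as coot
from ot.coot import co_optimal_transport2 as coot2

X1 = np.random.randn(20, 16)  # n1 samples with d1 features
X2 = np.random.randn(10, 8)   # n2 samples with d2 features

# Couplings aligning rows (samples) and columns (features)
pi_sample, pi_feature = coot(X1, X2)

# Scalar CO-Optimal Transport distance
distance = coot2(X1, X2)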
12 changes: 7 additions & 5 deletions README.md
@@ -276,15 +276,15 @@ You can also post bug reports and feature requests in Github issues. Make sure t

[35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). [Max-sliced wasserstein distance and its use for gans](https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656).

[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R.
(2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling
via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on
Machine Learning (pp. 4104-4113). PMLR.

[37] Janati, H., Cuturi, M., Gramfort, A. [Debiased sinkhorn barycenters](http://proceedings.mlr.press/v119/janati20a/janati20a.pdf) Proceedings of the 37th International
Conference on Machine Learning, PMLR 119:4692-4701, 2020

[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021.

[39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). [Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405.
@@ -305,4 +305,6 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer

[47] Chowdhury, S., & Mémoli, F. (2019). [The gromov–wasserstein distance between networks and stable network invariants](https://academic.oup.com/imaiai/article/8/4/757/5627736). Information and Inference: A Journal of the IMA, 8(4), 757-787.

[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022.

[49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33.
6 changes: 4 additions & 2 deletions RELEASES.md
@@ -16,8 +16,10 @@
- New API for OT solver using function `ot.solve` (PR #388)
- Backend version of `ot.partial` and `ot.smooth` (PR #388 and #449)
- Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437)
- Added parameters method in `ot.da.SinkhornTransport` (PR #440)
- `ot.dr` now uses the new Pymanopt API and POT is compatible with current Pymanopt (PR #443)
- Added CO-Optimal Transport solver + examples (PR #447)
- Remove the redundant `nx.abs()` at the end of `wasserstein_1d()` (PR #448)

#### Closed issues
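The warmstart entry above (PR #437) also appears in this commit's history. A minimal sketch of its intended usage, assuming that `warmstart` takes a pair of dual potentials given as the logarithms of the Sinkhorn scaling vectors `u`, `v` returned in the log dictionary (the exact convention is an assumption here):

import numpy as np
import ot

rng = np.random.RandomState(0)
a, b = ot.unif(50), ot.unif(60)
M = ot.dist(rng.randn(50, 2), rng.randn(60, 2))

# First solve, keeping the log to recover the scaling vectors u, v
G, log = ot.sinkhorn(a, b, M, reg=1e-1, log=True)

# Warm-start a slightly perturbed problem from the previous duals
# (assumed convention: log-domain potentials, per the PR description)
warmstart = (np.log(log['u']), np.log(log['v']))
G2 = ot.sinkhorn(a, b, 1.05 * M, reg=1e-1, warmstart=warmstart)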
1 change: 1 addition & 0 deletions docs/source/all.rst
@@ -16,6 +16,7 @@ API and modules

backend
bregman
coot
da
datasets
dr
97 changes: 97 additions & 0 deletions examples/others/plot_COOT.py
@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-
r"""
===================================================
Row and column alignments with CO-Optimal Transport
===================================================

This example is designed to show how to use CO-Optimal Transport [49]_ in POT.
CO-Optimal Transport makes it possible to compute the distance between two **arbitrary-size**
matrices, and to align their rows and columns. In this example, we consider two
random matrices :math:`X_1` and :math:`X_2` defined by
:math:`(X_1)_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi) + \sigma \mathcal N(0,1)`
and :math:`(X_2)_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi) + \sigma \mathcal N(0,1)`.
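
For reference, with the squared loss and uniform marginals (the defaults assumed
here), the CO-Optimal Transport problem reads (notation adapted from [49]_)

.. math::
    \text{COOT}(X_1, X_2) = \min_{\pi^s, \pi^f} \sum_{i, j, k, l}
    \left| (X_1)_{i, j} - (X_2)_{k, l} \right|^2 \pi^s_{i, k} \pi^f_{j, l}

where :math:`\pi^s` couples the rows (samples) and :math:`\pi^f` couples
the columns (features) of the two matrices.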

.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
    `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
    Advances in Neural Information Processing Systems, 33.
"""

# Author: Remi Flamary <[email protected]>
#         Quang Huy Tran <[email protected]>
# License: MIT License

from matplotlib.patches import ConnectionPatch
import matplotlib.pylab as pl
import numpy as np
from ot.coot import co_optimal_transport as coot
from ot.coot import co_optimal_transport2 as coot2

# %%
# Generating two random matrices

n1 = 20
n2 = 10
d1 = 16
d2 = 8
sigma = 0.2

X1 = (
    np.cos(np.arange(n1) * np.pi / n1)[:, None] +
    np.cos(np.arange(d1) * np.pi / d1)[None, :] +
    sigma * np.random.randn(n1, d1)
)
X2 = (
    np.cos(np.arange(n2) * np.pi / n2)[:, None] +
    np.cos(np.arange(d2) * np.pi / d2)[None, :] +
    sigma * np.random.randn(n2, d2)
)

# %%
# Visualizing the matrices

pl.figure(1, (8, 5))
pl.subplot(1, 2, 1)
pl.imshow(X1)
pl.title('$X_1$')

pl.subplot(1, 2, 2)
pl.imshow(X2)
pl.title("$X_2$")

pl.tight_layout()

# %%
# Visualizing the alignments of rows and columns, and calculating the CO-Optimal Transport distance

pi_sample, pi_feature, log = coot(X1, X2, log=True, verbose=True)
coot_distance = coot2(X1, X2)
print('CO-Optimal Transport distance = {:.5f}'.format(coot_distance))

fig = pl.figure(4, (9, 7))
pl.clf()

ax1 = pl.subplot(2, 2, 3)
pl.imshow(X1)
pl.xlabel('$X_1$')

ax2 = pl.subplot(2, 2, 2)
ax2.yaxis.tick_right()
pl.imshow(np.transpose(X2))
pl.title("Transpose($X_2$)")
ax2.xaxis.tick_top()

for i in range(n1):
    j = np.argmax(pi_sample[i, :])
    xyA = (d1 - .5, i)
    xyB = (j, d2 - .5)
    con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
                          coordsB=ax2.transData, color="black")
    fig.add_artist(con)

for i in range(d1):
    j = np.argmax(pi_feature[i, :])
    xyA = (i, -.5)
    xyB = (-.5, j)
    con = ConnectionPatch(
        xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue")
    fig.add_artist(con)
150 changes: 150 additions & 0 deletions examples/others/plot_learning_weights_with_COOT.py
@@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
r"""
===============================================================
Learning sample marginal distribution with CO-Optimal Transport
===============================================================

In this example, we illustrate how to estimate the sample marginal distribution
that minimizes the CO-Optimal Transport distance [49]_ between two matrices.
More precisely, given a source dataset
:math:`(X, \mu_x^{(s)}, \mu_x^{(f)})` and a target matrix :math:`Y` associated with a fixed
histogram on features :math:`\mu_y^{(f)}`, we want to solve the following problem

.. math::
    \min_{\mu_y^{(s)} \in \Delta} \text{COOT}\left( (X, \mu_x^{(s)}, \mu_x^{(f)}), (Y, \mu_y^{(s)}, \mu_y^{(f)}) \right)

where :math:`\Delta` is the probability simplex. This minimization is done with a
simple projected gradient descent in PyTorch. We use the automatic backend of POT that
allows us to compute the CO-Optimal Transport distance with :func:`ot.coot.co_optimal_transport2`
with differentiable losses.

.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
    `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
    Advances in Neural Information Processing Systems, 33.
"""

# Author: Remi Flamary <[email protected]>
#         Quang Huy Tran <[email protected]>
# License: MIT License

from matplotlib.patches import ConnectionPatch
import torch
import numpy as np

import matplotlib.pyplot as pl
import ot

from ot.coot import co_optimal_transport as coot
from ot.coot import co_optimal_transport2 as coot2


# %%
# Generate data
# -------------
# The source and clean target matrices are generated by
# :math:`X_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi)` and
# :math:`Y_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi)`.
# The target matrix is then contaminated by adding 5 row outliers.
# Intuitively, we expect that the estimated sample distribution should ignore these outliers,
# i.e. their weights should be zero.

np.random.seed(182)

n1, d1 = 20, 16
n2, d2 = 10, 8
n = 15

X = (
    torch.cos(torch.arange(n1) * torch.pi / n1)[:, None] +
    torch.cos(torch.arange(d1) * torch.pi / d1)[None, :]
)

# Generate clean target data mixed with outliers
Y_noisy = torch.randn((n, d2)) * 10.0
Y_noisy[:n2, :] = (
    torch.cos(torch.arange(n2) * torch.pi / n2)[:, None] +
    torch.cos(torch.arange(d2) * torch.pi / d2)[None, :]
)
Y = Y_noisy[:n2, :]

X, Y_noisy, Y = X.double(), Y_noisy.double(), Y.double()

fig, axes = pl.subplots(nrows=1, ncols=3, figsize=(12, 5))
axes[0].imshow(X, vmin=-2, vmax=2)
axes[0].set_title('$X$')

axes[1].imshow(Y, vmin=-2, vmax=2)
axes[1].set_title('Clean $Y$')

axes[2].imshow(Y_noisy, vmin=-2, vmax=2)
axes[2].set_title('Noisy $Y$')

pl.tight_layout()

# %%
# Optimize the COOT distance with respect to the sample marginal distribution
# ---------------------------------------------------------------------------

losses = []
lr = 1e-3
niter = 1000

b = torch.tensor(ot.unif(n), requires_grad=True)

for i in range(niter):

    loss = coot2(X, Y_noisy, wy_samp=b, log=False, verbose=False)
    losses.append(float(loss))

    loss.backward()

    with torch.no_grad():
        b -= lr * b.grad  # gradient step
        b[:] = ot.utils.proj_simplex(b)  # projection on the simplex

    b.grad.zero_()

# Estimated sample marginal distribution and training loss curve
pl.plot(losses[10:])
pl.title('CO-Optimal Transport distance')

print(f"Marginal distribution = {b.detach().numpy()}")

# %%
# Visualizing the row and column alignments with the estimated sample marginal distribution
# -----------------------------------------------------------------------------------------
#
# As expected, the learned marginal distribution ignores the 5 outliers,
# assigning them (near-)zero weight.

X, Y_noisy = X.numpy(), Y_noisy.numpy()
b = b.detach().numpy()

pi_sample, pi_feature = coot(X, Y_noisy, wy_samp=b, log=False, verbose=True)

fig = pl.figure(4, (9, 7))
pl.clf()

ax1 = pl.subplot(2, 2, 3)
pl.imshow(X, vmin=-2, vmax=2)
pl.xlabel('$X$')

ax2 = pl.subplot(2, 2, 2)
ax2.yaxis.tick_right()
pl.imshow(np.transpose(Y_noisy), vmin=-2, vmax=2)
pl.title("Transpose(Noisy $Y$)")
ax2.xaxis.tick_top()

for i in range(n1):
    j = np.argmax(pi_sample[i, :])
    xyA = (d1 - .5, i)
    xyB = (j, d2 - .5)
    con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
                          coordsB=ax2.transData, color="black")
    fig.add_artist(con)

for i in range(d1):
    j = np.argmax(pi_feature[i, :])
    xyA = (i, -.5)
    xyB = (-.5, j)
    con = ConnectionPatch(
        xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue")
    fig.add_artist(con)
