forked from PythonOT/POT
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MRG] CO-Optimal Transport solver (PythonOT#447)
* Allow warmstart in sinkhorn and sinkhorn_log * Added argument for warmstart of dual vectors in Sinkhorn-based methods in * Add the number of the PR * [WIP] CO-Optimal Transport * Revert "[WIP] CO-Optimal Transport" This reverts commit f3d36b2. * reformat with PEP8 * Fix W291 trailing whitespace error in pep8 test * Rearange position of warmstart argument and edit its description * Implementation of CO-Optimal Transport * Optimize code and edit documentation * fix backend bug in test cases * fix backend bug * fix backend bug * Add examples on COOT * Modify API and edit example * Edit API * minor edit of examples and release * fix bug in coot * fix doc examples * more fix of doc * restart CI * reordering ref * add more tests * add more tests * add test verbose * fix PEP8 bug * fix PEP8 bug * fix PEP8 bug * fix pytest bug * edit doc for better display --------- Co-authored-by: Rémi Flamary <[email protected]> Co-authored-by: Alexandre Gramfort <[email protected]>
- Loading branch information
1 parent
b9ed7b1
commit 897026e
Showing
7 changed files
with
1,052 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ API and modules | |
|
||
backend | ||
bregman | ||
coot | ||
da | ||
datasets | ||
dr | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# -*- coding: utf-8 -*-
r"""
===================================================
Row and column alignments with CO-Optimal Transport
===================================================

This example shows how to use CO-Optimal Transport [49]_ in POT.
CO-Optimal Transport makes it possible to compute the distance between two
**arbitrary-size** matrices, and to align their rows and columns. In this
example, we consider two random matrices :math:`X_1` and :math:`X_2` defined by
:math:`(X_1)_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi) + \sigma \mathcal N(0,1)`
and :math:`(X_2)_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi) + \sigma \mathcal N(0,1)`.

.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
    `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
    Advances in Neural Information Processing Systems, 33.
"""

# Author: Remi Flamary <[email protected]>
#         Quang Huy Tran <[email protected]>
# License: MIT License

from matplotlib.patches import ConnectionPatch
import matplotlib.pylab as pl
import numpy as np
from ot.coot import co_optimal_transport as coot
from ot.coot import co_optimal_transport2 as coot2

# %%
# Generating two random matrices

# Seed NumPy's global RNG so the Gaussian noise below, and therefore the
# figures and the printed distance, are the same on every run.
np.random.seed(0)

n1 = 20  # number of rows (samples) of X1
n2 = 10  # number of rows (samples) of X2
d1 = 16  # number of columns (features) of X1
d2 = 8   # number of columns (features) of X2
sigma = 0.2  # standard deviation of the additive Gaussian noise

# Each matrix is a row cosine profile plus a column cosine profile
# (combined by broadcasting) plus Gaussian noise.
X1 = (
    np.cos(np.arange(n1) * np.pi / n1)[:, None] +
    np.cos(np.arange(d1) * np.pi / d1)[None, :] +
    sigma * np.random.randn(n1, d1)
)
X2 = (
    np.cos(np.arange(n2) * np.pi / n2)[:, None] +
    np.cos(np.arange(d2) * np.pi / d2)[None, :] +
    sigma * np.random.randn(n2, d2)
)

# %%
# Visualizing the matrices

pl.figure(1, (8, 5))
pl.subplot(1, 2, 1)
pl.imshow(X1)
pl.title('$X_1$')

pl.subplot(1, 2, 2)
pl.imshow(X2)
pl.title("$X_2$")

pl.tight_layout()

# %%
# Visualizing the alignments of rows and columns, and calculating the CO-Optimal Transport distance

# ``coot`` returns the sample (row) coupling, the feature (column) coupling
# and, because ``log=True``, a log object; ``coot2`` returns the distance.
pi_sample, pi_feature, log = coot(X1, X2, log=True, verbose=True)
coot_distance = coot2(X1, X2)
print('CO-Optimal Transport distance = {:.5f}'.format(coot_distance))

fig = pl.figure(4, (9, 7))
pl.clf()

# X1 goes in the bottom-left panel and the transpose of X2 in the top-right
# panel, so that row links and column links can be drawn between them.
ax1 = pl.subplot(2, 2, 3)
pl.imshow(X1)
pl.xlabel('$X_1$')

ax2 = pl.subplot(2, 2, 2)
ax2.yaxis.tick_right()
pl.imshow(np.transpose(X2))
pl.title("Transpose($X_2$)")
ax2.xaxis.tick_top()

# Sample alignment (black): link row i of X1, anchored on the right edge of
# its image, to the row j of X2 receiving the largest mass in pi_sample[i, :].
# Row j of X2 is column j of the transposed image.
for i in range(n1):
    j = np.argmax(pi_sample[i, :])
    xyA = (d1 - .5, i)
    xyB = (j, d2 - .5)
    con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
                          coordsB=ax2.transData, color="black")
    fig.add_artist(con)

# Feature alignment (blue): link column i of X1, anchored on the top edge of
# its image, to the column j of X2 receiving the largest mass in
# pi_feature[i, :]. Column j of X2 is row j of the transposed image.
for i in range(d1):
    j = np.argmax(pi_feature[i, :])
    xyA = (i, -.5)
    xyB = (-.5, j)
    con = ConnectionPatch(
        xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue")
    fig.add_artist(con)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
# -*- coding: utf-8 -*-
r"""
===============================================================
Learning sample marginal distribution with CO-Optimal Transport
===============================================================

In this example, we illustrate how to estimate the sample marginal distribution which minimizes
the CO-Optimal Transport distance [49]_ between two matrices. More precisely, given source data
:math:`(X, \mu_x^{(s)}, \mu_x^{(f)})` and a target matrix :math:`Y` associated with a fixed
histogram on features :math:`\mu_y^{(f)}`, we want to solve the following problem

.. math::
    \min_{\mu_y^{(s)} \in \Delta} \text{COOT}\left( (X, \mu_x^{(s)}, \mu_x^{(f)}), (Y, \mu_y^{(s)}, \mu_y^{(f)}) \right)

where :math:`\Delta` is the probability simplex. This minimization is done with a
simple projected gradient descent in PyTorch. We use the automatic backend of POT that
allows us to compute the CO-Optimal Transport distance with :func:`ot.coot.co_optimal_transport2`
with differentiable losses.

.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
    `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
    Advances in Neural Information Processing Systems, 33.
"""

# Author: Remi Flamary <[email protected]>
#         Quang Huy Tran <[email protected]>
# License: MIT License

from matplotlib.patches import ConnectionPatch
import torch
import numpy as np

import matplotlib.pyplot as pl
import ot

from ot.coot import co_optimal_transport as coot
from ot.coot import co_optimal_transport2 as coot2


# %%
# Generate data
# -------------
# The source and clean target matrices are generated by
# :math:`X_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi)` and
# :math:`Y_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi)`.
# The target matrix is then contaminated by adding 5 row outliers.
# Intuitively, we expect that the estimated sample distribution should ignore these outliers,
# i.e. their weights should be zero.

np.random.seed(182)
# The outliers below are drawn with ``torch.randn``, which is not affected by
# NumPy's seed; seed PyTorch as well so the example is reproducible.
torch.manual_seed(182)

n1, d1 = 20, 16  # shape of the source matrix X
n2, d2 = 10, 8   # shape of the clean target matrix Y
n = 15           # rows of the noisy target: n2 clean rows + (n - n2) outliers

X = (
    torch.cos(torch.arange(n1) * torch.pi / n1)[:, None] +
    torch.cos(torch.arange(d1) * torch.pi / d1)[None, :]
)

# Generate clean target data mixed with outliers: start from large-variance
# Gaussian rows, then overwrite the first n2 rows with the clean signal.
Y_noisy = torch.randn((n, d2)) * 10.0
Y_noisy[:n2, :] = (
    torch.cos(torch.arange(n2) * torch.pi / n2)[:, None] +
    torch.cos(torch.arange(d2) * torch.pi / d2)[None, :]
)
Y = Y_noisy[:n2, :]

X, Y_noisy, Y = X.double(), Y_noisy.double(), Y.double()

fig, axes = pl.subplots(nrows=1, ncols=3, figsize=(12, 5))
axes[0].imshow(X, vmin=-2, vmax=2)
axes[0].set_title('$X$')

axes[1].imshow(Y, vmin=-2, vmax=2)
axes[1].set_title('Clean $Y$')

axes[2].imshow(Y_noisy, vmin=-2, vmax=2)
axes[2].set_title('Noisy $Y$')

pl.tight_layout()

# %%
# Optimize the COOT distance with respect to the sample marginal distribution
# ---------------------------------------------------------------------------

losses = []
lr = 1e-3
niter = 1000

# Sample weights of the noisy target, initialised to the uniform distribution
# and optimised by projected gradient descent.
b = torch.tensor(ot.unif(n), requires_grad=True)

for i in range(niter):

    loss = coot2(X, Y_noisy, wy_samp=b, log=False, verbose=False)
    losses.append(float(loss))

    loss.backward()

    with torch.no_grad():
        b -= lr * b.grad  # gradient step
        b[:] = ot.utils.proj_simplex(b)  # projection on the simplex

    # Reset the accumulated gradient before the next backward pass.
    b.grad.zero_()

# Estimated sample marginal distribution and training loss curve
# Open a dedicated figure so that, when run as a plain script, the curve is
# not drawn on top of the last image axes created above.
pl.figure()
pl.plot(losses[10:])
pl.title('CO-Optimal Transport distance')

print(f"Marginal distribution = {b.detach().numpy()}")

# %%
# Visualizing the row and column alignments with the estimated sample marginal distribution
# -----------------------------------------------------------------------------------------
#
# Clearly, the learned marginal distribution completely and successfully ignores the 5 outliers.

# Switch to NumPy arrays for the final solve and the plots.
X, Y_noisy = X.numpy(), Y_noisy.numpy()
b = b.detach().numpy()

pi_sample, pi_feature = coot(X, Y_noisy, wy_samp=b, log=False, verbose=True)

fig = pl.figure(4, (9, 7))
pl.clf()

# X goes in the bottom-left panel and the transpose of the noisy Y in the
# top-right panel, so that row links and column links can be drawn between them.
ax1 = pl.subplot(2, 2, 3)
pl.imshow(X, vmin=-2, vmax=2)
pl.xlabel('$X$')

ax2 = pl.subplot(2, 2, 2)
ax2.yaxis.tick_right()
pl.imshow(np.transpose(Y_noisy), vmin=-2, vmax=2)
pl.title("Transpose(Noisy $Y$)")
ax2.xaxis.tick_top()

# Sample alignment (black): link row i of X, anchored on the right edge of
# its image, to the row j of the noisy Y receiving the largest mass in
# pi_sample[i, :]. Row j of Y is column j of the transposed image.
for i in range(n1):
    j = np.argmax(pi_sample[i, :])
    xyA = (d1 - .5, i)
    xyB = (j, d2 - .5)
    con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData,
                          coordsB=ax2.transData, color="black")
    fig.add_artist(con)

# Feature alignment (blue): link column i of X, anchored on the top edge of
# its image, to the column j of Y receiving the largest mass in
# pi_feature[i, :]. Column j of Y is row j of the transposed image.
for i in range(d1):
    j = np.argmax(pi_feature[i, :])
    xyA = (i, -.5)
    xyB = (-.5, j)
    con = ConnectionPatch(
        xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue")
    fig.add_artist(con)
Oops, something went wrong.