From 897026ea1f5c35ba9e881433bc61490e70776b8c Mon Sep 17 00:00:00 2001 From: Huy Tran Date: Wed, 22 Mar 2023 08:13:53 +0100 Subject: [PATCH] [MRG] CO-Optimal Transport solver (#447) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Allow warmstart in sinkhorn and sinkhorn_log * Added argument for warmstart of dual vectors in Sinkhorn-based methods in * Add the number of the PR * [WIP] CO-Optimal Transport * Revert "[WIP] CO-Optimal Transport" This reverts commit f3d36b2705013409ac69b346585e311bc25fcfb7. * reformat with PEP8 * Fix W291 trailing whitespace error in pep8 test * Rearange position of warmstart argument and edit its description * Implementation of CO-Optimal Transport * Optimize code and edit documentation * fix backend bug in test cases * fix backend bug * fix backend bug * Add examples on COOT * Modify API and edit example * Edit API * minor edit of examples and release * fix bug in coot * fix doc examples * more fix of doc * restart CI * reordering ref * add more tests * add more tests * add test verbose * fix PEP8 bug * fix PEP8 bug * fix PEP8 bug * fix pytest bug * edit doc for better display --------- Co-authored-by: Rémi Flamary Co-authored-by: Alexandre Gramfort --- README.md | 12 +- RELEASES.md | 6 +- docs/source/all.rst | 1 + examples/others/plot_COOT.py | 97 ++++ .../others/plot_learning_weights_with_COOT.py | 150 ++++++ ot/coot.py | 434 ++++++++++++++++++ test/test_coot.py | 359 +++++++++++++++ 7 files changed, 1052 insertions(+), 7 deletions(-) create mode 100644 examples/others/plot_COOT.py create mode 100644 examples/others/plot_learning_weights_with_COOT.py create mode 100644 ot/coot.py create mode 100644 test/test_coot.py diff --git a/README.md b/README.md index e7241b806..9c5e07e11 100644 --- a/README.md +++ b/README.md @@ -276,15 +276,15 @@ You can also post bug reports and feature requests in Github issues. Make sure t [35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S., ... & Schwing, A. G. (2019). [Max-sliced wasserstein distance and its use for gans](https://openaccess.thecvf.com/content_CVPR_2019/papers/Deshpande_Max-Sliced_Wasserstein_Distance_and_Its_Use_for_GANs_CVPR_2019_paper.pdf). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656). -[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R. -(2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling -via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on +[36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R. +(2019, May). [Sliced-Wasserstein flows: Nonparametric generative modeling +via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on Machine Learning (pp. 4104-4113). PMLR. [37] Janati, H., Cuturi, M., Gramfort, A. [Debiased sinkhorn barycenters](http://proceedings.mlr.press/v119/janati20a/janati20a.pdf) Proceedings of the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020 -[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph +[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021. [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017). 
[Kantorovich duality for general transport costs and applications](https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.712.1825&rep=rep1&type=pdf). Journal of Functional Analysis, 273(11), 3327-3405.

@@ -305,4 +305,6 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
 [47] Chowdhury, S., & Mémoli, F. (2019). [The gromov–wasserstein distance between networks and stable network invariants](https://academic.oup.com/imaiai/article/8/4/757/5627736). Information and Inference: A Journal of the IMA, 8(4), 757-787.
 
-[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022.
\ No newline at end of file
+[48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty (2022). [Semi-relaxed Gromov-Wasserstein divergence and applications on graphs](https://openreview.net/pdf?id=RShaMexjc-x). International Conference on Learning Representations (ICLR), 2022.
+
+[49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). [CO-Optimal Transport](https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf). Advances in Neural Information Processing Systems, 33.
diff --git a/RELEASES.md b/RELEASES.md
index e4c6e1591..bc0b189b5 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -16,8 +16,10 @@
 - New API for OT solver using function `ot.solve` (PR #388)
 - Backend version of `ot.partial` and `ot.smooth` (PR #388 and #449)
 - Added argument for warmstart of dual potentials in Sinkhorn-based methods in `ot.bregman` (PR #437)
-- Add parameters method in `ot.da.SinkhornTransport` (PR #440)
-- `ot.dr` now uses the new Pymanopt API and POT is compatible with current Pymanopt (PR #443)
+- Added parameters method in `ot.da.SinkhornTransport` (PR #440)
+- `ot.dr` now uses the new Pymanopt API and POT is compatible with current
+  Pymanopt (PR #443)
+- Added CO-Optimal Transport solver + examples (PR #447)
 - Remove the redundant `nx.abs()` at the end of `wasserstein_1d()` (PR #448)
 
 #### Closed issues
diff --git a/docs/source/all.rst b/docs/source/all.rst
index 41d8e0676..1b8d13c29 100644
--- a/docs/source/all.rst
+++ b/docs/source/all.rst
@@ -16,6 +16,7 @@ API and modules
 
    backend
    bregman
+   coot
    da
    datasets
    dr
diff --git a/examples/others/plot_COOT.py b/examples/others/plot_COOT.py
new file mode 100644
index 000000000..98c1ce146
--- /dev/null
+++ b/examples/others/plot_COOT.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+r"""
+===================================================
+Row and column alignments with CO-Optimal Transport
+===================================================
+
+This example shows how to use CO-Optimal Transport [49]_ in POT.
+CO-Optimal Transport computes the distance between two **arbitrary-size**
+matrices and aligns their rows and columns. In this example, we consider two
+random matrices :math:`X_1` and :math:`X_2` defined by
+:math:`(X_1)_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi) + \sigma \mathcal N(0,1)`
+and :math:`(X_2)_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi) + \sigma \mathcal N(0,1)`.
+
+.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020).
+    `CO-Optimal Transport <https://proceedings.neurips.cc/paper/2020/file/cc384c68ad503482fb24e6d1e3b512ae-Paper.pdf>`_.
+    Advances in Neural Information Processing Systems, 33.
+""" + +# Author: Remi Flamary +# Quang Huy Tran +# License: MIT License + +from matplotlib.patches import ConnectionPatch +import matplotlib.pylab as pl +import numpy as np +from ot.coot import co_optimal_transport as coot +from ot.coot import co_optimal_transport2 as coot2 + +# %% +# Generating two random matrices + +n1 = 20 +n2 = 10 +d1 = 16 +d2 = 8 +sigma = 0.2 + +X1 = ( + np.cos(np.arange(n1) * np.pi / n1)[:, None] + + np.cos(np.arange(d1) * np.pi / d1)[None, :] + + sigma * np.random.randn(n1, d1) +) +X2 = ( + np.cos(np.arange(n2) * np.pi / n2)[:, None] + + np.cos(np.arange(d2) * np.pi / d2)[None, :] + + sigma * np.random.randn(n2, d2) +) + +# %% +# Visualizing the matrices + +pl.figure(1, (8, 5)) +pl.subplot(1, 2, 1) +pl.imshow(X1) +pl.title('$X_1$') + +pl.subplot(1, 2, 2) +pl.imshow(X2) +pl.title("$X_2$") + +pl.tight_layout() + +# %% +# Visualizing the alignments of rows and columns, and calculating the CO-Optimal Transport distance + +pi_sample, pi_feature, log = coot(X1, X2, log=True, verbose=True) +coot_distance = coot2(X1, X2) +print('CO-Optimal Transport distance = {:.5f}'.format(coot_distance)) + +fig = pl.figure(4, (9, 7)) +pl.clf() + +ax1 = pl.subplot(2, 2, 3) +pl.imshow(X1) +pl.xlabel('$X_1$') + +ax2 = pl.subplot(2, 2, 2) +ax2.yaxis.tick_right() +pl.imshow(np.transpose(X2)) +pl.title("Transpose($X_2$)") +ax2.xaxis.tick_top() + +for i in range(n1): + j = np.argmax(pi_sample[i, :]) + xyA = (d1 - .5, i) + xyB = (j, d2 - .5) + con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData, + coordsB=ax2.transData, color="black") + fig.add_artist(con) + +for i in range(d1): + j = np.argmax(pi_feature[i, :]) + xyA = (i, -.5) + xyB = (-.5, j) + con = ConnectionPatch( + xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue") + fig.add_artist(con) diff --git a/examples/others/plot_learning_weights_with_COOT.py b/examples/others/plot_learning_weights_with_COOT.py new file mode 100644 index 000000000..cb115c306 --- /dev/null +++ b/examples/others/plot_learning_weights_with_COOT.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- +r""" +=============================================================== +Learning sample marginal distribution with CO-Optimal Transport +=============================================================== + +In this example, we illustrate how to estimate the sample marginal distribution which minimizes +the CO-Optimal Transport distance [47]_ between two matrices. More precisely, given a source data +:math:`(X, \mu_x^{(s)}, \mu_x^{(f)})` and a target matrix :math:`Y` associated with a fixed +histogram on features :math:`\mu_y^{(f)}`, we want to solve the following problem + +.. math:: + \min_{\mu_y^{(s)} \in \Delta} \text{COOT}\left( (X, \mu_x^{(s)}, \mu_x^{(f)}), (Y, \mu_y^{(s)}, \mu_y^{(f)}) \right) + +where :math:`\Delta` is the probability simplex. This minimization is done with a +simple projected gradient descent in PyTorch. We use the automatic backend of POT that +allows us to compute the CO-Optimal Transport distance with :func:`ot.coot.co_optimal_transport2` +with differentiable losses. + +.. [49] Redko, I., Vayer, T., Flamary, R., and Courty, N. (2020). + `CO-Optimal Transport `_. + Advances in Neural Information Processing Systems, 33. 
+""" + +# Author: Remi Flamary +# Quang Huy Tran +# License: MIT License + +from matplotlib.patches import ConnectionPatch +import torch +import numpy as np + +import matplotlib.pyplot as pl +import ot + +from ot.coot import co_optimal_transport as coot +from ot.coot import co_optimal_transport2 as coot2 + + +# %% +# Generate data +# ------------- +# The source and clean target matrices are generated by +# :math:`X_{i,j} = \cos(\frac{i}{n_1} \pi) + \cos(\frac{j}{d_1} \pi)` and +# :math:`Y_{i,j} = \cos(\frac{i}{n_2} \pi) + \cos(\frac{j}{d_2} \pi)`. +# The target matrix is then contaminated by adding 5 row outliers. +# Intuitively, we expect that the estimated sample distribution should ignore these outliers, +# i.e. their weights should be zero. + +np.random.seed(182) + +n1, d1 = 20, 16 +n2, d2 = 10, 8 +n = 15 + +X = ( + torch.cos(torch.arange(n1) * torch.pi / n1)[:, None] + + torch.cos(torch.arange(d1) * torch.pi / d1)[None, :] +) + +# Generate clean target data mixed with outliers +Y_noisy = torch.randn((n, d2)) * 10.0 +Y_noisy[:n2, :] = ( + torch.cos(torch.arange(n2) * torch.pi / n2)[:, None] + + torch.cos(torch.arange(d2) * torch.pi / d2)[None, :] +) +Y = Y_noisy[:n2, :] + +X, Y_noisy, Y = X.double(), Y_noisy.double(), Y.double() + +fig, axes = pl.subplots(nrows=1, ncols=3, figsize=(12, 5)) +axes[0].imshow(X, vmin=-2, vmax=2) +axes[0].set_title('$X$') + +axes[1].imshow(Y, vmin=-2, vmax=2) +axes[1].set_title('Clean $Y$') + +axes[2].imshow(Y_noisy, vmin=-2, vmax=2) +axes[2].set_title('Noisy $Y$') + +pl.tight_layout() + +# %% +# Optimize the COOT distance with respect to the sample marginal distribution +# --------------------------------------------------------------------------- + +losses = [] +lr = 1e-3 +niter = 1000 + +b = torch.tensor(ot.unif(n), requires_grad=True) + +for i in range(niter): + + loss = coot2(X, Y_noisy, wy_samp=b, log=False, verbose=False) + losses.append(float(loss)) + + loss.backward() + + with torch.no_grad(): + b -= lr * b.grad # gradient step + b[:] = ot.utils.proj_simplex(b) # projection on the simplex + + b.grad.zero_() + +# Estimated sample marginal distribution and training loss curve +pl.plot(losses[10:]) +pl.title('CO-Optimal Transport distance') + +print(f"Marginal distribution = {b.detach().numpy()}") + +# %% +# Visualizing the row and column alignments with the estimated sample marginal distribution +# ----------------------------------------------------------------------------------------- +# +# Clearly, the learned marginal distribution completely and successfully ignores the 5 outliers. 
+ +X, Y_noisy = X.numpy(), Y_noisy.numpy() +b = b.detach().numpy() + +pi_sample, pi_feature = coot(X, Y_noisy, wy_samp=b, log=False, verbose=True) + +fig = pl.figure(4, (9, 7)) +pl.clf() + +ax1 = pl.subplot(2, 2, 3) +pl.imshow(X, vmin=-2, vmax=2) +pl.xlabel('$X$') + +ax2 = pl.subplot(2, 2, 2) +ax2.yaxis.tick_right() +pl.imshow(np.transpose(Y_noisy), vmin=-2, vmax=2) +pl.title("Transpose(Noisy $Y$)") +ax2.xaxis.tick_top() + +for i in range(n1): + j = np.argmax(pi_sample[i, :]) + xyA = (d1 - .5, i) + xyB = (j, d2 - .5) + con = ConnectionPatch(xyA=xyA, xyB=xyB, coordsA=ax1.transData, + coordsB=ax2.transData, color="black") + fig.add_artist(con) + +for i in range(d1): + j = np.argmax(pi_feature[i, :]) + xyA = (i, -.5) + xyB = (-.5, j) + con = ConnectionPatch( + xyA=xyA, xyB=xyB, coordsA=ax1.transData, coordsB=ax2.transData, color="blue") + fig.add_artist(con) diff --git a/ot/coot.py b/ot/coot.py new file mode 100644 index 000000000..66dd2c84a --- /dev/null +++ b/ot/coot.py @@ -0,0 +1,434 @@ +# -*- coding: utf-8 -*- +""" +CO-Optimal Transport solver +""" + +# Author: Quang Huy Tran +# +# License: MIT License + +import warnings +from .lp import emd +from .utils import list_to_array +from .backend import get_backend +from .bregman import sinkhorn + + +def co_optimal_transport(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None, + epsilon=0, alpha=0, M_samp=None, M_feat=None, + warmstart=None, nits_bcd=100, tol_bcd=1e-7, eval_bcd=1, + nits_ot=500, tol_sinkhorn=1e-7, method_sinkhorn="sinkhorn", + early_stopping_tol=1e-6, log=False, verbose=False): + r"""Compute the CO-Optimal Transport between two matrices. + + Return the sample and feature transport plans between + :math:`(\mathbf{X}, \mathbf{w}_{xs}, \mathbf{w}_{xf})` and + :math:`(\mathbf{Y}, \mathbf{w}_{ys}, \mathbf{w}_{yf})`. + + The function solves the following CO-Optimal Transport (COOT) problem: + + .. math:: + \mathbf{COOT}_{\alpha, \varepsilon} = \mathop{\arg \min}_{\mathbf{P}, \mathbf{Q}} + &\quad \sum_{i,j,k,l} + (\mathbf{X}_{i,k} - \mathbf{Y}_{j,l})^2 \mathbf{P}_{i,j} \mathbf{Q}_{k,l} + + \alpha_s \sum_{i,j} \mathbf{P}_{i,j} \mathbf{M^{(s)}}_{i, j} \\ + &+ \alpha_f \sum_{k, l} \mathbf{Q}_{k,l} \mathbf{M^{(f)}}_{k, l} + + \varepsilon_s \mathbf{KL}(\mathbf{P} | \mathbf{w}_{xs} \mathbf{w}_{ys}^T) + + \varepsilon_f \mathbf{KL}(\mathbf{Q} | \mathbf{w}_{xf} \mathbf{w}_{yf}^T) + + Where : + + - :math:`\mathbf{X}`: Data matrix in the source space + - :math:`\mathbf{Y}`: Data matrix in the target space + - :math:`\mathbf{M^{(s)}}`: Additional sample matrix + - :math:`\mathbf{M^{(f)}}`: Additional feature matrix + - :math:`\mathbf{w}_{xs}`: Distribution of the samples in the source space + - :math:`\mathbf{w}_{xf}`: Distribution of the features in the source space + - :math:`\mathbf{w}_{ys}`: Distribution of the samples in the target space + - :math:`\mathbf{w}_{yf}`: Distribution of the features in the target space + + .. note:: This function allows epsilon to be zero. + In that case, the :any:`ot.lp.emd` solver of POT will be used. + + Parameters + ---------- + X : (n_sample_x, n_feature_x) array-like, float + First input matrix. + Y : (n_sample_y, n_feature_y) array-like, float + Second input matrix. + wx_samp : (n_sample_x, ) array-like, float, optional (default = None) + Histogram assigned on rows (samples) of matrix X. + Uniform distribution by default. + wx_feat : (n_feature_x, ) array-like, float, optional (default = None) + Histogram assigned on columns (features) of matrix X. + Uniform distribution by default. 
+    wy_samp : (n_sample_y, ) array-like, float, optional (default = None)
+        Histogram assigned on rows (samples) of matrix Y.
+        Uniform distribution by default.
+    wy_feat : (n_feature_y, ) array-like, float, optional (default = None)
+        Histogram assigned on columns (features) of matrix Y.
+        Uniform distribution by default.
+    epsilon : scalar or indexable object of length 2, float or int, optional (default = 0)
+        Regularization parameters for the entropic approximation of the sample and
+        feature couplings. Entries equal to 0 are allowed: in that case, the EMD
+        solver is used instead of the Sinkhorn solver. If epsilon is a scalar, the
+        same value is used to regularize both the sample and feature couplings.
+    alpha : scalar or indexable object of length 2, float or int, optional (default = 0)
+        Coefficients of the linear terms on the sample and feature couplings.
+        If alpha is a scalar, the same value is used for both linear terms.
+    M_samp : (n_sample_x, n_sample_y) array-like, float, optional (default = None)
+        Cost matrix of the linear term on the sample coupling.
+    M_feat : (n_feature_x, n_feature_y) array-like, float, optional (default = None)
+        Cost matrix of the linear term on the feature coupling.
+    warmstart : dictionary, optional (default = None)
+        Contains 4 keys:
+            - "duals_sample" and "duals_feature", whose values are
+              tuples of 2 vectors of size (n_sample_x, ) and (n_sample_y, )
+              (resp. (n_feature_x, ) and (n_feature_y, )).
+              Initialization of the sample and feature dual vectors
+              when the Sinkhorn algorithm is used. Zero vectors by default.
+
+            - "pi_sample" and "pi_feature", whose values are matrices
+              of size (n_sample_x, n_sample_y) and (n_feature_x, n_feature_y).
+              Initialization of the sample and feature couplings.
+              Product of the corresponding marginals by default.
+    nits_bcd : int, optional (default = 100)
+        Number of Block Coordinate Descent (BCD) iterations used to solve COOT.
+    tol_bcd : float, optional (default = 1e-7)
+        Tolerance of the BCD scheme. If the L1 norm between the current and previous
+        sample couplings falls below this threshold, the BCD scheme stops.
+    eval_bcd : int, optional (default = 1)
+        Frequency at which the COOT cost is evaluated. For example,
+        if `eval_bcd = 8`, the cost is calculated every 8 BCD iterations.
+    nits_ot : int, optional (default = 500)
+        Number of iterations used to solve each of the
+        two optimal transport problems in each BCD iteration.
+    tol_sinkhorn : float, optional (default = 1e-7)
+        Tolerance used to stop the Sinkhorn scheme for the entropic optimal
+        transport problem (if any) in each BCD iteration.
+        Only used when the Sinkhorn solver is selected.
+    method_sinkhorn : string, optional (default = "sinkhorn")
+        Method used in POT's `ot.sinkhorn` solver.
+        Only "sinkhorn" and "sinkhorn_log" are supported.
+    early_stopping_tol : float, optional (default = 1e-6)
+        Tolerance for early stopping. If the absolute difference between
+        the last 2 recorded COOT distances falls below this tolerance, the BCD scheme stops.
+    log : bool, optional (default = False)
+        If True, the cost and the 4 dual vectors (2 for the sample coupling and
+        2 for the feature coupling) are recorded.
+    verbose : bool, optional (default = False)
+        If True, print the COOT cost at every `eval_bcd`-th iteration.
+
+    Returns
+    -------
+    pi_samp : (n_sample_x, n_sample_y) array-like, float
+        Sample coupling matrix.
+    pi_feat : (n_feature_x, n_feature_y) array-like, float
+        Feature coupling matrix.
+    log : dictionary, optional
+        Returned if `log` is True.
The keys are:
+            duals_sample : tuple of vectors of size (n_sample_x, ) and (n_sample_y, ), float
+                Pair of dual vectors obtained when solving the OT problem w.r.t. the sample coupling.
+            duals_feature : tuple of vectors of size (n_feature_x, ) and (n_feature_y, ), float
+                Pair of dual vectors obtained when solving the OT problem w.r.t. the feature coupling.
+            distances : list, float
+                List of COOT distances.
+
+    References
+    ----------
+    .. [49] I. Redko, T. Vayer, R. Flamary, and N. Courty, CO-Optimal Transport,
+        Advances in Neural Information Processing Systems, 33 (2020).
+    """
+
+    def compute_kl(p, q):
+        # KL divergence between couplings, with the convention 0 * log(0) = 0
+        kl = nx.sum(p * nx.log(p + 1.0 * (p == 0))) - nx.sum(p * nx.log(q))
+        return kl
+
+    # Main function
+
+    if method_sinkhorn not in ["sinkhorn", "sinkhorn_log"]:
+        raise ValueError(
+            "Method {} is not supported in CO-Optimal Transport.".format(method_sinkhorn))
+
+    X, Y = list_to_array(X, Y)
+    nx = get_backend(X, Y)
+
+    if isinstance(epsilon, float) or isinstance(epsilon, int):
+        eps_samp, eps_feat = epsilon, epsilon
+    else:
+        if len(epsilon) != 2:
+            raise ValueError("Epsilon must be either a scalar or an indexable object of length 2.")
+        else:
+            eps_samp, eps_feat = epsilon[0], epsilon[1]
+
+    if isinstance(alpha, float) or isinstance(alpha, int):
+        alpha_samp, alpha_feat = alpha, alpha
+    else:
+        if len(alpha) != 2:
+            raise ValueError("Alpha must be either a scalar or an indexable object of length 2.")
+        else:
+            alpha_samp, alpha_feat = alpha[0], alpha[1]
+
+    # constant input variables
+    if M_samp is None or alpha_samp == 0:
+        M_samp, alpha_samp = 0, 0
+    if M_feat is None or alpha_feat == 0:
+        M_feat, alpha_feat = 0, 0
+
+    nx_samp, nx_feat = X.shape
+    ny_samp, ny_feat = Y.shape
+
+    # measures on rows and columns
+    if wx_samp is None:
+        wx_samp = nx.ones(nx_samp, type_as=X) / nx_samp
+    if wx_feat is None:
+        wx_feat = nx.ones(nx_feat, type_as=X) / nx_feat
+    if wy_samp is None:
+        wy_samp = nx.ones(ny_samp, type_as=Y) / ny_samp
+    if wy_feat is None:
+        wy_feat = nx.ones(ny_feat, type_as=Y) / ny_feat
+
+    wxy_samp = wx_samp[:, None] * wy_samp[None, :]
+    wxy_feat = wx_feat[:, None] * wy_feat[None, :]
+
+    # pre-calculate cost constants
+    XY_sqr = ((X ** 2 @ wx_feat)[:, None] + (Y ** 2 @ wy_feat)[None, :]
+              + alpha_samp * M_samp)
+    XY_sqr_T = ((X.T ** 2 @ wx_samp)[:, None] + (Y.T ** 2 @ wy_samp)[None, :]
+                + alpha_feat * M_feat)
+
+    # initialize couplings and dual vectors
+    if warmstart is None:
+        pi_samp, pi_feat = wxy_samp, wxy_feat  # shapes (nx_samp, ny_samp) and (nx_feat, ny_feat)
+        duals_samp = (nx.zeros(nx_samp, type_as=X),
+                      nx.zeros(ny_samp, type_as=Y))  # shapes (nx_samp, ) and (ny_samp, )
+        duals_feat = (nx.zeros(nx_feat, type_as=X),
+                      nx.zeros(ny_feat, type_as=Y))  # shapes (nx_feat, ) and (ny_feat, )
+    else:
+        pi_samp, pi_feat = warmstart["pi_sample"], warmstart["pi_feature"]
+        duals_samp, duals_feat = warmstart["duals_sample"], warmstart["duals_feature"]
+
+    # initialize log
+    list_coot = [float("inf")]
+    err = tol_bcd + 1e-3
+
+    for idx in range(nits_bcd):
+        pi_samp_prev = nx.copy(pi_samp)
+
+        # update sample coupling
+        ot_cost = XY_sqr - 2 * X @ pi_feat @ Y.T  # size nx_samp x ny_samp
+        if eps_samp > 0:
+            pi_samp, dict_log = sinkhorn(a=wx_samp, b=wy_samp, M=ot_cost, reg=eps_samp, method=method_sinkhorn,
+                                         numItermax=nits_ot, stopThr=tol_sinkhorn, log=True, warmstart=duals_samp)
+            duals_samp = (nx.log(dict_log["u"]), nx.log(dict_log["v"]))
+        elif eps_samp == 0:
+            pi_samp, dict_log = emd(
+                a=wx_samp, b=wy_samp, M=ot_cost, numItermax=nits_ot, log=True)
+            duals_samp = (dict_log["u"], dict_log["v"])
+        # update feature coupling
+        ot_cost = XY_sqr_T - 2 * X.T @ pi_samp @ Y  # size nx_feat x
ny_feat
+        if eps_feat > 0:
+            pi_feat, dict_log = sinkhorn(a=wx_feat, b=wy_feat, M=ot_cost, reg=eps_feat, method=method_sinkhorn,
+                                         numItermax=nits_ot, stopThr=tol_sinkhorn, log=True, warmstart=duals_feat)
+            duals_feat = (nx.log(dict_log["u"]), nx.log(dict_log["v"]))
+        elif eps_feat == 0:
+            pi_feat, dict_log = emd(
+                a=wx_feat, b=wy_feat, M=ot_cost, numItermax=nits_ot, log=True)
+            duals_feat = (dict_log["u"], dict_log["v"])
+
+        if idx % eval_bcd == 0:
+            # update error
+            err = nx.sum(nx.abs(pi_samp - pi_samp_prev))
+
+            # COOT part (the feature linear term is already contained in ot_cost)
+            coot = nx.sum(ot_cost * pi_feat)
+            if alpha_samp != 0:
+                coot = coot + alpha_samp * nx.sum(M_samp * pi_samp)
+            # Entropic part
+            if eps_samp != 0:
+                coot = coot + eps_samp * compute_kl(pi_samp, wxy_samp)
+            if eps_feat != 0:
+                coot = coot + eps_feat * compute_kl(pi_feat, wxy_feat)
+            list_coot.append(coot)
+
+            if err < tol_bcd or abs(list_coot[-2] - list_coot[-1]) < early_stopping_tol:
+                break
+
+            if verbose:
+                print(
+                    "CO-Optimal Transport cost at iteration {}: {}".format(idx + 1, coot))
+
+    # sanity check
+    if nx.sum(nx.isnan(pi_samp)) > 0 or nx.sum(nx.isnan(pi_feat)) > 0:
+        warnings.warn("NaN detected in the couplings.")
+
+    if log:
+        dict_log = {"duals_sample": duals_samp,
+                    "duals_feature": duals_feat,
+                    "distances": list_coot[1:]}
+
+        return pi_samp, pi_feat, dict_log
+
+    else:
+        return pi_samp, pi_feat
+
+
+def co_optimal_transport2(X, Y, wx_samp=None, wx_feat=None, wy_samp=None, wy_feat=None,
+                          epsilon=0, alpha=0, M_samp=None, M_feat=None,
+                          warmstart=None, log=False, verbose=False, early_stopping_tol=1e-6,
+                          nits_bcd=100, tol_bcd=1e-7, eval_bcd=1,
+                          nits_ot=500, tol_sinkhorn=1e-7,
+                          method_sinkhorn="sinkhorn"):
+    r"""Compute the CO-Optimal Transport distance between two measures.
+
+    Returns the CO-Optimal Transport distance between
+    :math:`(\mathbf{X}, \mathbf{w}_{xs}, \mathbf{w}_{xf})` and
+    :math:`(\mathbf{Y}, \mathbf{w}_{ys}, \mathbf{w}_{yf})`.
+
+    The function solves the following CO-Optimal Transport (COOT) problem:
+
+    .. math::
+        \mathbf{COOT}_{\alpha, \varepsilon} = \mathop{\min}_{\mathbf{P}, \mathbf{Q}}
+        &\quad \sum_{i,j,k,l}
+        (\mathbf{X}_{i,k} - \mathbf{Y}_{j,l})^2 \mathbf{P}_{i,j} \mathbf{Q}_{k,l}
+        + \alpha_s \sum_{i,j} \mathbf{P}_{i,j} \mathbf{M^{(s)}}_{i, j} \\
+        &+ \alpha_f \sum_{k, l} \mathbf{Q}_{k,l} \mathbf{M^{(f)}}_{k, l}
+        + \varepsilon_s \mathbf{KL}(\mathbf{P} | \mathbf{w}_{xs} \mathbf{w}_{ys}^T)
+        + \varepsilon_f \mathbf{KL}(\mathbf{Q} | \mathbf{w}_{xf} \mathbf{w}_{yf}^T)
+
+    Where :
+
+    - :math:`\mathbf{X}`: Data matrix in the source space
+    - :math:`\mathbf{Y}`: Data matrix in the target space
+    - :math:`\mathbf{M^{(s)}}`: Additional sample matrix
+    - :math:`\mathbf{M^{(f)}}`: Additional feature matrix
+    - :math:`\mathbf{w}_{xs}`: Distribution of the samples in the source space
+    - :math:`\mathbf{w}_{xf}`: Distribution of the features in the source space
+    - :math:`\mathbf{w}_{ys}`: Distribution of the samples in the target space
+    - :math:`\mathbf{w}_{yf}`: Distribution of the features in the target space
+
+    .. note:: This function allows epsilon to be zero.
+        In that case, the :any:`ot.lp.emd` solver of POT will be used.
+
+    Parameters
+    ----------
+    X : (n_sample_x, n_feature_x) array-like, float
+        First input matrix.
+    Y : (n_sample_y, n_feature_y) array-like, float
+        Second input matrix.
+    wx_samp : (n_sample_x, ) array-like, float, optional (default = None)
+        Histogram assigned on rows (samples) of matrix X.
+        Uniform distribution by default.
+    wx_feat : (n_feature_x, ) array-like, float, optional (default = None)
+        Histogram assigned on columns (features) of matrix X.
+        Uniform distribution by default.
+    wy_samp : (n_sample_y, ) array-like, float, optional (default = None)
+        Histogram assigned on rows (samples) of matrix Y.
+        Uniform distribution by default.
+    wy_feat : (n_feature_y, ) array-like, float, optional (default = None)
+        Histogram assigned on columns (features) of matrix Y.
+        Uniform distribution by default.
+    epsilon : scalar or indexable object of length 2, float or int, optional (default = 0)
+        Regularization parameters for the entropic approximation of the sample and
+        feature couplings. Entries equal to 0 are allowed: in that case, the EMD
+        solver is used instead of the Sinkhorn solver. If epsilon is a scalar, the
+        same value is used to regularize both the sample and feature couplings.
+    alpha : scalar or indexable object of length 2, float or int, optional (default = 0)
+        Coefficients of the linear terms on the sample and feature couplings.
+        If alpha is a scalar, the same value is used for both linear terms.
+    M_samp : (n_sample_x, n_sample_y) array-like, float, optional (default = None)
+        Cost matrix of the linear term on the sample coupling.
+    M_feat : (n_feature_x, n_feature_y) array-like, float, optional (default = None)
+        Cost matrix of the linear term on the feature coupling.
+    warmstart : dictionary, optional (default = None)
+        Contains 4 keys:
+            - "duals_sample" and "duals_feature", whose values are
+              tuples of 2 vectors of size (n_sample_x, ) and (n_sample_y, )
+              (resp. (n_feature_x, ) and (n_feature_y, )).
+              Initialization of the sample and feature dual vectors
+              when the Sinkhorn algorithm is used. Zero vectors by default.
+
+            - "pi_sample" and "pi_feature", whose values are matrices
+              of size (n_sample_x, n_sample_y) and (n_feature_x, n_feature_y).
+              Initialization of the sample and feature couplings.
+              Product of the corresponding marginals by default.
+    nits_bcd : int, optional (default = 100)
+        Number of Block Coordinate Descent (BCD) iterations used to solve COOT.
+    tol_bcd : float, optional (default = 1e-7)
+        Tolerance of the BCD scheme. If the L1 norm between the current and previous
+        sample couplings falls below this threshold, the BCD scheme stops.
+    eval_bcd : int, optional (default = 1)
+        Frequency at which the COOT cost is evaluated. For example,
+        if `eval_bcd = 8`, the cost is calculated every 8 BCD iterations.
+    nits_ot : int, optional (default = 500)
+        Number of iterations used to solve each of the
+        two optimal transport problems in each BCD iteration.
+    tol_sinkhorn : float, optional (default = 1e-7)
+        Tolerance used to stop the Sinkhorn scheme for the entropic optimal
+        transport problem (if any) in each BCD iteration.
+        Only used when the Sinkhorn solver is selected.
+    method_sinkhorn : string, optional (default = "sinkhorn")
+        Method used in POT's `ot.sinkhorn` solver.
+        Only "sinkhorn" and "sinkhorn_log" are supported.
+    early_stopping_tol : float, optional (default = 1e-6)
+        Tolerance for early stopping. If the absolute difference between
+        the last 2 recorded COOT distances falls below this tolerance, the BCD scheme stops.
+    log : bool, optional (default = False)
+        If True, the cost and the 4 dual vectors (2 for the sample coupling and
+        2 for the feature coupling) are recorded.
+    verbose : bool, optional (default = False)
+        If True, print the COOT cost at every `eval_bcd`-th iteration.
+
+    Returns
+    -------
+    float
+        CO-Optimal Transport distance.
+    dict
+        Contains logged information from the :any:`co_optimal_transport` solver.
+        Only returned if the `log` parameter is True.
+
+    References
+    ----------
+    .. [49] I. Redko, T. Vayer, R. Flamary, and N. Courty, CO-Optimal Transport,
+        Advances in Neural Information Processing Systems, 33 (2020).
+    """
+
+    pi_samp, pi_feat, dict_log = co_optimal_transport(X=X, Y=Y, wx_samp=wx_samp, wx_feat=wx_feat, wy_samp=wy_samp,
+                                                      wy_feat=wy_feat, epsilon=epsilon, alpha=alpha, M_samp=M_samp,
+                                                      M_feat=M_feat, warmstart=warmstart, nits_bcd=nits_bcd,
+                                                      tol_bcd=tol_bcd, eval_bcd=eval_bcd, nits_ot=nits_ot,
+                                                      tol_sinkhorn=tol_sinkhorn, method_sinkhorn=method_sinkhorn,
+                                                      early_stopping_tol=early_stopping_tol,
+                                                      log=True, verbose=verbose)
+
+    X, Y = list_to_array(X, Y)
+    nx = get_backend(X, Y)
+
+    nx_samp, nx_feat = X.shape
+    ny_samp, ny_feat = Y.shape
+
+    # measures on rows and columns
+    if wx_samp is None:
+        wx_samp = nx.ones(nx_samp, type_as=X) / nx_samp
+    if wx_feat is None:
+        wx_feat = nx.ones(nx_feat, type_as=X) / nx_feat
+    if wy_samp is None:
+        wy_samp = nx.ones(ny_samp, type_as=Y) / ny_samp
+    if wy_feat is None:
+        wy_feat = nx.ones(ny_feat, type_as=Y) / ny_feat
+
+    vx_samp, vy_samp = dict_log["duals_sample"]
+    vx_feat, vy_feat = dict_log["duals_feature"]
+
+    gradX = 2 * X * (wx_samp[:, None] * wx_feat[None, :]) - \
+        2 * pi_samp @ Y @ pi_feat.T  # shape (nx_samp, nx_feat)
+    gradY = 2 * Y * (wy_samp[:, None] * wy_feat[None, :]) - \
+        2 * pi_samp.T @ X @ pi_feat  # shape (ny_samp, ny_feat)
+
+    coot = dict_log["distances"][-1]
+    coot = nx.set_gradients(coot, (wx_samp, wx_feat, wy_samp, wy_feat, X, Y),
+                            (vx_samp, vx_feat, vy_samp, vy_feat, gradX, gradY))
+
+    if log:
+        return coot, dict_log
+
+    else:
+        return coot
diff --git a/test/test_coot.py b/test/test_coot.py
new file mode 100644
index 000000000..ef68a9bcd
--- /dev/null
+++ b/test/test_coot.py
@@ -0,0 +1,359 @@
+"""Tests for module COOT on OT """
+
+# Author: Quang Huy Tran
+#
+# License: MIT License
+
+import numpy as np
+import ot
+from ot.coot import co_optimal_transport as coot
+from ot.coot import co_optimal_transport2 as coot2
+import pytest
+
+
+@pytest.mark.parametrize("verbose", [False, True, 1, 0])
+def test_coot(nx, verbose):
+    n_samples = 60  # nb samples
+
+    mu_s = np.array([0, 0])
+    cov_s = np.array([[1, 0], [0, 1]])
+
+    xs = ot.datasets.make_2D_samples_gauss(
+        n_samples, mu_s, cov_s, random_state=4)
+    xt = xs[::-1].copy()
+    xs_nx = nx.from_numpy(xs)
+    xt_nx = nx.from_numpy(xt)
+
+    # test couplings
+    pi_sample, pi_feature = coot(X=xs, Y=xt, verbose=verbose)
+    pi_sample_nx, pi_feature_nx = coot(X=xs_nx, Y=xt_nx, verbose=verbose)
+    pi_sample_nx = nx.to_numpy(pi_sample_nx)
+    pi_feature_nx = nx.to_numpy(pi_feature_nx)
+
+    anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples
+    id_feature = np.eye(2, 2) / 2
+
+    np.testing.assert_allclose(pi_sample, anti_id_sample, atol=1e-04)
+    np.testing.assert_allclose(pi_sample_nx, anti_id_sample, atol=1e-04)
+    np.testing.assert_allclose(pi_feature, id_feature, atol=1e-04)
+    np.testing.assert_allclose(pi_feature_nx, id_feature, atol=1e-04)
+
+    # test marginal distributions
+    px_s, px_f = ot.unif(n_samples), ot.unif(2)
+    py_s, py_f = ot.unif(n_samples), ot.unif(2)
+
+    np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04)
+    np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04)
+    np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04)
+    np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04)
+
+    np.testing.assert_allclose(px_s, pi_sample.sum(0),
atol=1e-04) + np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04) + np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04) + np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04) + + # test COOT distance + + coot_np = coot2(X=xs, Y=xt, verbose=verbose) + coot_nx = nx.to_numpy(coot2(X=xs_nx, Y=xt_nx, verbose=verbose)) + np.testing.assert_allclose(coot_np, 0, atol=1e-08) + np.testing.assert_allclose(coot_nx, 0, atol=1e-08) + + +def test_entropic_coot(nx): + n_samples = 60 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + xs_nx = nx.from_numpy(xs) + xt_nx = nx.from_numpy(xt) + + epsilon = (1, 1e-1) + nits_ot = 2000 + + # test couplings + pi_sample, pi_feature = coot(X=xs, Y=xt, epsilon=epsilon, nits_ot=nits_ot) + pi_sample_nx, pi_feature_nx = coot( + X=xs_nx, Y=xt_nx, epsilon=epsilon, nits_ot=nits_ot) + pi_sample_nx = nx.to_numpy(pi_sample_nx) + pi_feature_nx = nx.to_numpy(pi_feature_nx) + + np.testing.assert_allclose(pi_sample, pi_sample_nx, atol=1e-04) + np.testing.assert_allclose(pi_feature, pi_feature_nx, atol=1e-04) + + # test marginal distributions + px_s, px_f = ot.unif(n_samples), ot.unif(2) + py_s, py_f = ot.unif(n_samples), ot.unif(2) + + np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04) + np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04) + np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04) + np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04) + + np.testing.assert_allclose(px_s, pi_sample.sum(0), atol=1e-04) + np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04) + np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04) + np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04) + + # test entropic COOT distance + + coot_np = coot2(X=xs, Y=xt, epsilon=epsilon, nits_ot=nits_ot) + coot_nx = nx.to_numpy( + coot2(X=xs_nx, Y=xt_nx, epsilon=epsilon, nits_ot=nits_ot)) + + np.testing.assert_allclose(coot_np, coot_nx, atol=1e-08) + + +def test_coot_with_linear_terms(nx): + n_samples = 60 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=4) + xt = xs[::-1].copy() + xs_nx = nx.from_numpy(xs) + xt_nx = nx.from_numpy(xt) + + M_samp = np.ones((n_samples, n_samples)) + np.fill_diagonal(np.fliplr(M_samp), 0) + M_feat = np.ones((2, 2)) + np.fill_diagonal(M_feat, 0) + M_samp_nx, M_feat_nx = nx.from_numpy(M_samp), nx.from_numpy(M_feat) + + alpha = (1, 2) + + # test couplings + anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples + id_feature = np.eye(2, 2) / 2 + + pi_sample, pi_feature = coot( + X=xs, Y=xt, alpha=alpha, M_samp=M_samp, M_feat=M_feat) + pi_sample_nx, pi_feature_nx = coot( + X=xs_nx, Y=xt_nx, alpha=alpha, M_samp=M_samp_nx, M_feat=M_feat_nx) + pi_sample_nx = nx.to_numpy(pi_sample_nx) + pi_feature_nx = nx.to_numpy(pi_feature_nx) + + np.testing.assert_allclose(pi_sample, anti_id_sample, atol=1e-04) + np.testing.assert_allclose(pi_sample_nx, anti_id_sample, atol=1e-04) + np.testing.assert_allclose(pi_feature, id_feature, atol=1e-04) + np.testing.assert_allclose(pi_feature_nx, id_feature, atol=1e-04) + + # test marginal distributions + px_s, px_f = ot.unif(n_samples), ot.unif(2) + py_s, py_f = ot.unif(n_samples), ot.unif(2) + + np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04) + 
np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04)
+    np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04)
+    np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04)
+
+    np.testing.assert_allclose(px_s, pi_sample.sum(0), atol=1e-04)
+    np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04)
+    np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04)
+    np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04)
+
+    # test COOT distance
+
+    coot_np = coot2(X=xs, Y=xt, alpha=alpha, M_samp=M_samp, M_feat=M_feat)
+    coot_nx = nx.to_numpy(
+        coot2(X=xs_nx, Y=xt_nx, alpha=alpha, M_samp=M_samp_nx, M_feat=M_feat_nx))
+    np.testing.assert_allclose(coot_np, 0, atol=1e-08)
+    np.testing.assert_allclose(coot_nx, 0, atol=1e-08)
+
+
+def test_coot_raise_value_error(nx):
+    n_samples = 80  # nb samples
+
+    mu_s = np.array([2, 4])
+    cov_s = np.array([[1, 0], [0, 1]])
+
+    xs = ot.datasets.make_2D_samples_gauss(
+        n_samples, mu_s, cov_s, random_state=43)
+    xt = xs[::-1].copy()
+    xs_nx = nx.from_numpy(xs)
+    xt_nx = nx.from_numpy(xt)
+
+    # raise value error of method sinkhorn
+    def coot_sh(method_sinkhorn):
+        return coot(X=xs, Y=xt, method_sinkhorn=method_sinkhorn)
+
+    def coot_sh_nx(method_sinkhorn):
+        return coot(X=xs_nx, Y=xt_nx, method_sinkhorn=method_sinkhorn)
+
+    np.testing.assert_raises(ValueError, coot_sh, "not_sinkhorn")
+    np.testing.assert_raises(ValueError, coot_sh_nx, "not_sinkhorn")
+
+    # raise value error for epsilon
+    def coot_eps(epsilon):
+        return coot(X=xs, Y=xt, epsilon=epsilon)
+
+    def coot_eps_nx(epsilon):
+        return coot(X=xs_nx, Y=xt_nx, epsilon=epsilon)
+
+    np.testing.assert_raises(ValueError, coot_eps, (1, 2, 3))
+    np.testing.assert_raises(ValueError, coot_eps_nx, [1, 2, 3, 4])
+
+    # raise value error for alpha
+    def coot_alpha(alpha):
+        return coot(X=xs, Y=xt, alpha=alpha)
+
+    def coot_alpha_nx(alpha):
+        return coot(X=xs_nx, Y=xt_nx, alpha=alpha)
+
+    np.testing.assert_raises(ValueError, coot_alpha, [1])
+    np.testing.assert_raises(ValueError, coot_alpha_nx, np.arange(4))
+
+
+def test_coot_warmstart(nx):
+    n_samples = 80  # nb samples
+
+    mu_s = np.array([2, 3])
+    cov_s = np.array([[1, 0], [0, 1]])
+
+    xs = ot.datasets.make_2D_samples_gauss(
+        n_samples, mu_s, cov_s, random_state=125)
+    xt = xs[::-1].copy()
+    xs_nx = nx.from_numpy(xs)
+    xt_nx = nx.from_numpy(xt)
+
+    # initialize warmstart
+    init_pi_sample = np.random.rand(n_samples, n_samples)
+    init_pi_sample = init_pi_sample / np.sum(init_pi_sample)
+    init_pi_sample_nx = nx.from_numpy(init_pi_sample)
+
+    init_pi_feature = np.random.rand(2, 2)
+    init_pi_feature = init_pi_feature / np.sum(init_pi_feature)
+    init_pi_feature_nx = nx.from_numpy(init_pi_feature)
+
+    init_duals_sample = (np.random.random(n_samples) * 2 - 1,
+                         np.random.random(n_samples) * 2 - 1)
+    init_duals_sample_nx = (nx.from_numpy(init_duals_sample[0]),
+                            nx.from_numpy(init_duals_sample[1]))
+
+    init_duals_feature = (np.random.random(2) * 2 - 1,
+                          np.random.random(2) * 2 - 1)
+    init_duals_feature_nx = (nx.from_numpy(init_duals_feature[0]),
+                             nx.from_numpy(init_duals_feature[1]))
+
+    warmstart = {
+        "pi_sample": init_pi_sample,
+        "pi_feature": init_pi_feature,
+        "duals_sample": init_duals_sample,
+        "duals_feature": init_duals_feature
+    }
+
+    warmstart_nx = {
+        "pi_sample": init_pi_sample_nx,
+        "pi_feature": init_pi_feature_nx,
+        "duals_sample": init_duals_sample_nx,
+        "duals_feature": init_duals_feature_nx
+    }
+
+    # test couplings
+    pi_sample, pi_feature = coot(X=xs, Y=xt, warmstart=warmstart)
+    pi_sample_nx, pi_feature_nx = 
coot( + X=xs_nx, Y=xt_nx, warmstart=warmstart_nx) + pi_sample_nx = nx.to_numpy(pi_sample_nx) + pi_feature_nx = nx.to_numpy(pi_feature_nx) + + anti_id_sample = np.flipud(np.eye(n_samples, n_samples)) / n_samples + id_feature = np.eye(2, 2) / 2 + + np.testing.assert_allclose(pi_sample, anti_id_sample, atol=1e-04) + np.testing.assert_allclose(pi_sample_nx, anti_id_sample, atol=1e-04) + np.testing.assert_allclose(pi_feature, id_feature, atol=1e-04) + np.testing.assert_allclose(pi_feature_nx, id_feature, atol=1e-04) + + # test marginal distributions + px_s, px_f = ot.unif(n_samples), ot.unif(2) + py_s, py_f = ot.unif(n_samples), ot.unif(2) + + np.testing.assert_allclose(px_s, pi_sample_nx.sum(0), atol=1e-04) + np.testing.assert_allclose(py_s, pi_sample_nx.sum(1), atol=1e-04) + np.testing.assert_allclose(px_f, pi_feature_nx.sum(0), atol=1e-04) + np.testing.assert_allclose(py_f, pi_feature_nx.sum(1), atol=1e-04) + + np.testing.assert_allclose(px_s, pi_sample.sum(0), atol=1e-04) + np.testing.assert_allclose(py_s, pi_sample.sum(1), atol=1e-04) + np.testing.assert_allclose(px_f, pi_feature.sum(0), atol=1e-04) + np.testing.assert_allclose(py_f, pi_feature.sum(1), atol=1e-04) + + # test COOT distance + coot_np = coot2(X=xs, Y=xt, warmstart=warmstart) + coot_nx = nx.to_numpy(coot2(X=xs_nx, Y=xt_nx, warmstart=warmstart_nx)) + np.testing.assert_allclose(coot_np, 0, atol=1e-08) + np.testing.assert_allclose(coot_nx, 0, atol=1e-08) + + +def test_coot_log(nx): + n_samples = 90 # nb samples + + mu_s = np.array([-2, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss( + n_samples, mu_s, cov_s, random_state=43) + xt = xs[::-1].copy() + xs_nx = nx.from_numpy(xs) + xt_nx = nx.from_numpy(xt) + + pi_sample, pi_feature, log = coot(X=xs, Y=xt, log=True) + pi_sample_nx, pi_feature_nx, log_nx = coot(X=xs_nx, Y=xt_nx, log=True) + + duals_sample, duals_feature = log["duals_sample"], log["duals_feature"] + assert len(duals_sample) == 2 + assert len(duals_feature) == 2 + assert len(duals_sample[0]) == n_samples + assert len(duals_sample[1]) == n_samples + assert len(duals_feature[0]) == 2 + assert len(duals_feature[1]) == 2 + + duals_sample_nx = log_nx["duals_sample"] + assert len(duals_sample_nx) == 2 + assert len(duals_sample_nx[0]) == n_samples + assert len(duals_sample_nx[1]) == n_samples + + duals_feature_nx = log_nx["duals_feature"] + assert len(duals_feature_nx) == 2 + assert len(duals_feature_nx[0]) == 2 + assert len(duals_feature_nx[1]) == 2 + + list_coot = log["distances"] + assert len(list_coot) >= 1 + + list_coot_nx = log_nx["distances"] + assert len(list_coot_nx) >= 1 + + # test with coot distance + coot_np, log = coot2(X=xs, Y=xt, log=True) + coot_nx, log_nx = coot2(X=xs_nx, Y=xt_nx, log=True) + + duals_sample, duals_feature = log["duals_sample"], log["duals_feature"] + assert len(duals_sample) == 2 + assert len(duals_feature) == 2 + assert len(duals_sample[0]) == n_samples + assert len(duals_sample[1]) == n_samples + assert len(duals_feature[0]) == 2 + assert len(duals_feature[1]) == 2 + + duals_sample_nx = log_nx["duals_sample"] + assert len(duals_sample_nx) == 2 + assert len(duals_sample_nx[0]) == n_samples + assert len(duals_sample_nx[1]) == n_samples + + duals_feature_nx = log_nx["duals_feature"] + assert len(duals_feature_nx) == 2 + assert len(duals_feature_nx[0]) == 2 + assert len(duals_feature_nx[1]) == 2 + + list_coot = log["distances"] + assert len(list_coot) >= 1 + + list_coot_nx = log_nx["distances"] + assert len(list_coot_nx) >= 1
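
A minimal usage sketch of the API added in this patch (synthetic data; the
epsilon value and the use of "sinkhorn_log" below are illustrative choices,
not recommendations from the patch itself):

    import numpy as np
    from ot.coot import co_optimal_transport as coot
    from ot.coot import co_optimal_transport2 as coot2

    rng = np.random.RandomState(0)
    X = rng.randn(20, 16)  # 20 samples, 16 features
    Y = rng.randn(10, 8)   # 10 samples, 8 features

    # Sample and feature couplings; with the default epsilon = 0,
    # the exact EMD solver is used for both inner OT problems.
    pi_sample, pi_feature = coot(X, Y)

    # Entropic variant returning the scalar COOT distance.
    cost = coot2(X, Y, epsilon=(1e-2, 1e-2), method_sinkhorn="sinkhorn_log")
    print(cost)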