Skip to content

Commit

Permalink
[MRG] Sliced wasserstein (PythonOT#203)
Browse files Browse the repository at this point in the history
* example for log treatment in bregman.py

* Improve doc

* Revert "example for log treatment in bregman.py"

This reverts commit 9f51c14

* Add comments by Flamary

* Delete repetitive description

* Added raw string to avoid pbs with backslashes

* Implements sliced wasserstein

* Changed formatting of string for py3.5 support

* Docstest, expected 0.0 and not 0.

* Adressed comments by @rflamary

* No 3d plot here

* add sliced to the docs

* Incorporate comments by @rflamary

* add link to pdf

Co-authored-by: Rémi Flamary <[email protected]>
  • Loading branch information
AdrienCorenflos and rflamary authored Oct 22, 2020
1 parent 7adc1b1 commit 78b44af
Show file tree
Hide file tree
Showing 7 changed files with 324 additions and 1 deletion.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ POT provides the following generic OT solvers (links to examples):
* [Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25].
* [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) (exact [29] and entropic [3]
formulations).
* [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32].

POT provides the following Machine Learning related solvers:

Expand Down Expand Up @@ -180,6 +181,7 @@ The contributors to this library are
* [Romain Tavenard](https://rtavenar.github.io/) (1d Wasserstein)
* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
* [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT)
* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)

This toolbox benefit a lot from open source research and we would like to thank the following persons for providing some code (in various languages):

Expand Down Expand Up @@ -263,3 +265,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[29] Chapel, L., Alaya, M., Gasso, G. (2019). [Partial Gromov-Wasserstein with Applications on Positive-Unlabeled Learning](https://arxiv.org/abs/2002.08276), arXiv preprint arXiv:2002.08276.

[30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.

[31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
1 change: 1 addition & 0 deletions docs/source/all.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ API and modules
stochastic
unbalanced
partial
sliced

.. autosummary::
:toctree: ../modules/generated/
Expand Down
4 changes: 4 additions & 0 deletions examples/sliced-wasserstein/README.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@


Sliced Wasserstein Distance
---------------------------
84 changes: 84 additions & 0 deletions examples/sliced-wasserstein/plot_variance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
"""
==============================
2D Sliced Wasserstein Distance
==============================
This example illustrates the computation of the sliced Wasserstein Distance as proposed in [31].
[31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
"""

# Author: Adrien Corenflos <[email protected]>
#
# License: MIT License

import matplotlib.pylab as pl
import numpy as np

import ot

##############################################################################
# Generate data
# -------------

# %% parameters and data generation

n = 500 # nb samples

mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])

mu_t = np.array([4, 4])
cov_t = np.array([[1, -.8], [-.8, 1]])

xs = ot.datasets.make_2D_samples_gauss(n, mu_s, cov_s)
xt = ot.datasets.make_2D_samples_gauss(n, mu_t, cov_t)

a, b = np.ones((n,)) / n, np.ones((n,)) / n # uniform distribution on samples

##############################################################################
# Plot data
# ---------

# %% plot samples

pl.figure(1)
pl.plot(xs[:, 0], xs[:, 1], '+b', label='Source samples')
pl.plot(xt[:, 0], xt[:, 1], 'xr', label='Target samples')
pl.legend(loc=0)
pl.title('Source and target distributions')

###################################################################################
# Compute Sliced Wasserstein distance for different seeds and number of projections
# -----------

n_seed = 50
n_projections_arr = np.logspace(0, 3, 25, dtype=int)
res = np.empty((n_seed, 25))

# %% Compute statistics
for seed in range(n_seed):
for i, n_projections in enumerate(n_projections_arr):
res[seed, i] = ot.sliced_wasserstein_distance(xs, xt, a, b, n_projections, seed)

res_mean = np.mean(res, axis=0)
res_std = np.std(res, axis=0)

###################################################################################
# Plot Sliced Wasserstein Distance
# -----------

pl.figure(2)
pl.plot(n_projections_arr, res_mean, label="SWD")
pl.fill_between(n_projections_arr, res_mean - 2 * res_std, res_mean + 2 * res_std, alpha=0.5)

pl.legend()
pl.xscale('log')

pl.xlabel("Number of projections")
pl.ylabel("Distance")
pl.title('Sliced Wasserstein Distance with 95% confidence inverval')

pl.show()
3 changes: 2 additions & 1 deletion ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .bregman import sinkhorn, sinkhorn2, barycenter
from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2
from .da import sinkhorn_lpl1_mm
from .sliced import sliced_wasserstein_distance

# utils functions
from .utils import dist, unif, tic, toc, toq
Expand All @@ -50,4 +51,4 @@
'emd_1d', 'emd2_1d', 'wasserstein_1d',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'sinkhorn_unbalanced', 'barycenter_unbalanced',
'sinkhorn_unbalanced2']
'sinkhorn_unbalanced2', 'sliced_wasserstein_distance']
144 changes: 144 additions & 0 deletions ot/sliced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""
Sliced Wasserstein Distance.
"""

# Author: Adrien Corenflos <[email protected]>
#
# License: MIT License


import numpy as np


def get_random_projections(n_projections, d, seed=None):
r"""
Generates n_projections samples from the uniform on the unit sphere of dimension d-1: :math:`\mathcal{U}(\mathcal{S}^{d-1})`
Parameters
----------
n_projections : int
number of samples requested
d : int
dimension of the space
seed: int or RandomState, optional
Seed used for numpy random number generator
Returns
-------
out: ndarray, shape (n_projections, d)
The uniform unit vectors on the sphere
Examples
--------
>>> n_projections = 100
>>> d = 5
>>> projs = get_random_projections(n_projections, d)
>>> np.allclose(np.sum(np.square(projs), 1), 1.) # doctest: +NORMALIZE_WHITESPACE
True
"""

if not isinstance(seed, np.random.RandomState):
random_state = np.random.RandomState(seed)
else:
random_state = seed

projections = random_state.normal(0., 1., [n_projections, d])
norm = np.linalg.norm(projections, ord=2, axis=1, keepdims=True)
projections = projections / norm
return projections


def sliced_wasserstein_distance(X_s, X_t, a=None, b=None, n_projections=50, seed=None, log=False):
r"""
Computes a Monte-Carlo approximation of the 2-Sliced Wasserstein distance
.. math::
\mathcal{SWD}_2(\mu, \nu) = \underset{\theta \sim \mathcal{U}(\mathbb{S}^{d-1})}{\mathbb{E}}[\mathcal{W}_2^2(\theta_\# \mu, \theta_\# \nu)]^{\frac{1}{2}}
where :
- :math:`\theta_\# \mu` stands for the pushforwars of the projection :math:`\mathbb{R}^d \ni X \mapsto \langle \theta, X \rangle`
Parameters
----------
X_s : ndarray, shape (n_samples_a, dim)
samples in the source domain
X_t : ndarray, shape (n_samples_b, dim)
samples in the target domain
a : ndarray, shape (n_samples_a,), optional
samples weights in the source domain
b : ndarray, shape (n_samples_b,), optional
samples weights in the target domain
n_projections : int, optional
Number of projections used for the Monte-Carlo approximation
seed: int or RandomState or None, optional
Seed used for numpy random number generator
log: bool, optional
if True, sliced_wasserstein_distance returns the projections used and their associated EMD.
Returns
-------
cost: float
Sliced Wasserstein Cost
log : dict, optional
log dictionary return only if log==True in parameters
Examples
--------
>>> n_samples_a = 20
>>> reg = 0.1
>>> X = np.random.normal(0., 1., (n_samples_a, 5))
>>> sliced_wasserstein_distance(X, X, seed=0) # doctest: +NORMALIZE_WHITESPACE
0.0
References
----------
.. [31] Bonneel, Nicolas, et al. "Sliced and radon wasserstein barycenters of measures." Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
"""
from .lp import emd2_1d

X_s = np.asanyarray(X_s)
X_t = np.asanyarray(X_t)

n = X_s.shape[0]
m = X_t.shape[0]

if X_s.shape[1] != X_t.shape[1]:
raise ValueError(
"X_s and X_t must have the same number of dimensions {} and {} respectively given".format(X_s.shape[1],
X_t.shape[1]))

if a is None:
a = np.full(n, 1 / n)
if b is None:
b = np.full(m, 1 / m)

d = X_s.shape[1]

projections = get_random_projections(n_projections, d, seed)

X_s_projections = np.dot(projections, X_s.T)
X_t_projections = np.dot(projections, X_t.T)

if log:
projected_emd = np.empty(n_projections)
else:
projected_emd = None

res = 0.

for i, (X_s_proj, X_t_proj) in enumerate(zip(X_s_projections, X_t_projections)):
emd = emd2_1d(X_s_proj, X_t_proj, a, b, log=False, dense=False)
if projected_emd is not None:
projected_emd[i] = emd
res += emd

res = (res / n_projections) ** 0.5
if log:
return res, {"projections": projections, "projected_emds": projected_emd}
return res
85 changes: 85 additions & 0 deletions test/test_sliced.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Tests for module sliced"""

# Author: Adrien Corenflos <[email protected]>
#
# License: MIT License

import numpy as np
import pytest

import ot
from ot.sliced import get_random_projections


def test_get_random_projections():
rng = np.random.RandomState(0)
projections = get_random_projections(1000, 50, rng)
np.testing.assert_almost_equal(np.sum(projections ** 2, 1), 1.)


def test_sliced_same_dist():
n = 100
rng = np.random.RandomState(0)

x = rng.randn(n, 2)
u = ot.utils.unif(n)

res = ot.sliced_wasserstein_distance(x, x, u, u, 10, seed=rng)
np.testing.assert_almost_equal(res, 0.)


def test_sliced_bad_shapes():
n = 100
rng = np.random.RandomState(0)

x = rng.randn(n, 2)
y = rng.randn(n, 4)
u = ot.utils.unif(n)

with pytest.raises(ValueError):
_ = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng)


def test_sliced_log():
n = 100
rng = np.random.RandomState(0)

x = rng.randn(n, 4)
y = rng.randn(n, 4)
u = ot.utils.unif(n)

res, log = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng, log=True)
assert len(log) == 2
projections = log["projections"]
projected_emds = log["projected_emds"]

assert len(projections) == len(projected_emds) == 10
for emd in projected_emds:
assert emd > 0


def test_sliced_different_dists():
n = 100
rng = np.random.RandomState(0)

x = rng.randn(n, 2)
u = ot.utils.unif(n)
y = rng.randn(n, 2)

res = ot.sliced_wasserstein_distance(x, y, u, u, 10, seed=rng)
assert res > 0.


def test_1d_sliced_equals_emd():
n = 100
m = 120
rng = np.random.RandomState(0)

x = rng.randn(n, 1)
a = rng.uniform(0, 1, n)
a /= a.sum()
y = rng.randn(m, 1)
u = ot.utils.unif(m)
res = ot.sliced_wasserstein_distance(x, y, a, u, 10, seed=42)
expected = ot.emd2_1d(x.squeeze(), y.squeeze(), a, u)
np.testing.assert_almost_equal(res ** 2, expected)

0 comments on commit 78b44af

Please sign in to comment.