
[MRG] More general solvers for `ot.solve` and examples of different variants. (PythonOT#620)

* add example and allow for functional regularizers

* fix test since now all is implemented

* manual regularizer available for exact and unbalanced ot

* example with balanced manual regularizer

* update documentation

* pep8

* cleanup: envelope instead of implicit

* big release file update
rflamary authored Apr 26, 2024
1 parent e75c9af commit 81d7631
Showing 6 changed files with 260 additions and 31 deletions.
16 changes: 13 additions & 3 deletions RELEASES.md
@@ -1,19 +1,29 @@
# Releases

-## 0.9.3dev
+## 0.9.4dev

#### New features
-+ `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specifify if the matrices are symmetric in which case the computation can be done faster.
++ `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specify if the matrices are symmetric, in which case the computation can be done faster (PR #607).
++ Continuous entropic mapping (PR #613)
++ New general unbalanced solvers and a BFGS solver for `ot.solve`, with an illustrative example (PR #620)
++ Add gradient computation with the envelope theorem to the Sinkhorn solver of `ot.solve` via `grad='envelope'` (PR #605).
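A minimal sketch of how the two `ot.solve` additions above are called (the sizes and the quadratic regularizer are illustrative, not from the PR; `M_torch`, `a_torch`, `b_torch` are hypothetical torch tensors):

import numpy as np
import ot

rng = np.random.RandomState(0)
a, b = ot.unif(10), ot.unif(10)  # uniform histograms
M = ot.dist(rng.randn(10, 1), rng.randn(10, 1))  # cost matrix

# functional regularizer: pass a (value, gradient) pair as reg_type
def f(G):
    return 0.5 * np.sum(G**2)

def df(G):
    return G

res = ot.solve(M, a, b, reg=0.1, reg_type=(f, df))

# envelope-theorem gradients (torch tensors required); only res.value is
# differentiable in this mode:
# res = ot.solve(M_torch, a_torch, b_torch, reg=1.0, grad='envelope')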

#### Closed issues
- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced in PR #587 (PR #593)
- Fix gpu compatibility of sr(F)GW solvers when `G0 is not None` (PR #596)
- Fix doc and example for lowrank sinkhorn (PR #601)
- Fix issue with empty weights for `ot.emd2` (PR #606, Issue #534)
- Fix a sign error regarding the gradient of `ot.gromov._gw.fused_gromov_wasserstein2` and `ot.gromov._gw.gromov_wasserstein2` for the kl loss (PR #610)
- Fix same sign error for sr(F)GW conditional gradient solvers (PR #611)
- Split `test/test_gromov.py` into `test/gromov/` (PR #619)

## 0.9.3
*January 2024*


#### Closed issues
- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced in PR #587 (PR #593)


## 0.9.2
*December 2023*

150 changes: 150 additions & 0 deletions examples/plot_solve_variants.py
@@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
"""
======================================
Optimal Transport solvers comparison
======================================
This example illustrates the solutions returned by different variants of exact,
regularized and unbalanced OT solvers.
"""

# Author: Remi Flamary <[email protected]>
#
# License: MIT License
# sphinx_gallery_thumbnail_number = 3

#%%

import numpy as np
import matplotlib.pylab as pl
import ot
import ot.plot
from ot.datasets import make_1D_gauss as gauss

##############################################################################
# Generate data
# -------------


#%% parameters

n = 50 # nb bins

# bin positions
x = np.arange(n, dtype=np.float64)

# Gaussian distributions
a = 0.6 * gauss(n, m=15, s=5) + 0.4 * gauss(n, m=35, s=5) # m= mean, s= std
b = gauss(n, m=25, s=5)

# loss matrix
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
M /= M.max()


##############################################################################
# Plot distributions and loss matrix
# ----------------------------------

#%% plot the distributions

pl.figure(1, figsize=(6.4, 3))
pl.plot(x, a, 'b', label='Source distribution')
pl.plot(x, b, 'r', label='Target distribution')
pl.legend()

#%% plot distributions and loss matrix

pl.figure(2, figsize=(5, 5))
ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')

##############################################################################
# Define Group lasso regularization and gradient
# ------------------------------------------------
# The groups are the first and second half of the columns of G


def reg_gl(G):  # group lasso + small l2 reg
    G1 = G[:n // 2, :]**2
    G2 = G[n // 2:, :]**2
    gl1 = np.sum(np.sqrt(np.sum(G1, 0)))
    gl2 = np.sum(np.sqrt(np.sum(G2, 0)))
    return gl1 + gl2 + 0.1 * np.sum(G**2)


def grad_gl(G):  # gradient of group lasso + small l2 reg
    G1 = G[:n // 2, :]
    G2 = G[n // 2:, :]
    gl1 = G1 / np.sqrt(np.sum(G1**2, 0, keepdims=True) + 1e-8)
    gl2 = G2 / np.sqrt(np.sum(G2**2, 0, keepdims=True) + 1e-8)
    return np.concatenate((gl1, gl2), axis=0) + 0.2 * G


reg_type_gl = (reg_gl, grad_gl)
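
# Sanity check (not in the original example): grad_gl should match a
# finite-difference approximation of reg_gl; the 1e-8 smoothing term in
# grad_gl makes the match approximate but tight away from zero.
rng = np.random.RandomState(0)
G0 = np.abs(rng.randn(n, n))
E = np.zeros((n, n))
E[3, 7] = 1.0
eps = 1e-6
fd = (reg_gl(G0 + eps * E) - reg_gl(G0)) / eps
assert abs(fd - grad_gl(G0)[3, 7]) < 1e-4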

# %%
# Set up parameters for solvers and solve
# ---------------------------------------

lst_regs = ["No Reg.", "Entropic", "L2", "Group Lasso + L2"]
lst_unbalanced = ["Balanced", "Unbalanced KL", "Unbalanced L2", "Unb. TV (Partial)"]

lst_solvers = [  # name, param for ot.solve function
    # balanced OT
    ('Exact OT', dict()),
    ('Entropic Reg. OT', dict(reg=0.005)),
    ('L2 Reg OT', dict(reg=1, reg_type='l2')),
    ('Group Lasso Reg. OT', dict(reg=0.1, reg_type=reg_type_gl)),

    # unbalanced OT KL
    ('Unbalanced KL No Reg.', dict(unbalanced=0.005)),
    ('Unbalanced KL with KL Reg.', dict(reg=0.0005, unbalanced=0.005, unbalanced_type='kl', reg_type='kl')),
    ('Unbalanced KL with L2 Reg.', dict(reg=0.5, reg_type='l2', unbalanced=0.005, unbalanced_type='kl')),
    ('Unbalanced KL with Group Lasso Reg.', dict(reg=0.1, reg_type=reg_type_gl, unbalanced=0.05, unbalanced_type='kl')),

    # unbalanced OT L2
    ('Unbalanced L2 No Reg.', dict(unbalanced=0.5, unbalanced_type='l2')),
    ('Unbalanced L2 with KL Reg.', dict(reg=0.001, unbalanced=0.2, unbalanced_type='l2')),
    ('Unbalanced L2 with L2 Reg.', dict(reg=0.1, reg_type='l2', unbalanced=0.2, unbalanced_type='l2')),
    ('Unbalanced L2 with Group Lasso Reg.', dict(reg=0.05, reg_type=reg_type_gl, unbalanced=0.7, unbalanced_type='l2')),

    # unbalanced OT TV
    ('Unbalanced TV No Reg.', dict(unbalanced=0.1, unbalanced_type='tv')),
    ('Unbalanced TV with KL Reg.', dict(reg=0.001, unbalanced=0.01, unbalanced_type='tv')),
    ('Unbalanced TV with L2 Reg.', dict(reg=0.1, reg_type='l2', unbalanced=0.01, unbalanced_type='tv')),
    ('Unbalanced TV with Group Lasso Reg.', dict(reg=0.02, reg_type=reg_type_gl, unbalanced=0.01, unbalanced_type='tv')),
]

lst_plans = []
for (name, param) in lst_solvers:
    G = ot.solve(M, a, b, **param).plan
    lst_plans.append(G)
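
# Each call returns an OTResult object; besides `.plan` it also exposes the
# objective values (a small addition, not in the original example):
res = ot.solve(M, a, b, reg=0.1, reg_type=reg_type_gl)
print('full objective:', res.value, '| linear part <M, G>:', res.value_linear)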

##############################################################################
# Plot plans
# ----------

pl.figure(3, figsize=(9, 9))

for i, bname in enumerate(lst_unbalanced):
    for j, rname in enumerate(lst_regs):
        pl.subplot(len(lst_unbalanced), len(lst_regs), i * len(lst_regs) + j + 1)

        plan = lst_plans[i * len(lst_regs) + j]
        m2 = plan.sum(0)
        m1 = plan.sum(1)
        m1, m2 = m1 / a.max(), m2 / b.max()
        pl.imshow(plan, cmap='Greys')
        pl.plot(x, m2 * 10, 'r')
        pl.plot(m1 * 10, x, 'b')
        pl.plot(x, b / b.max() * 10, 'r', alpha=0.3)
        pl.plot(a / a.max() * 10, x, 'b', alpha=0.3)
        # pl.axis('off')
        pl.tick_params(left=False, right=False, labelleft=False,
                       labelbottom=False, bottom=False)
        if i == 0:
            pl.title(rname)
        if j == 0:
            pl.ylabel(bname, fontsize=14)
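
# One more quick check (not in the original example): the total mass of each
# plan shows the effect of the marginal penalties; balanced solvers keep it
# at 1, unbalanced ones create or destroy mass depending on the penalty.
for (name, _), plan in zip(lst_solvers, lst_plans):
    print(f'{name}: transported mass = {plan.sum():.3f}')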
2 changes: 1 addition & 1 deletion ot/__init__.py
@@ -58,7 +58,7 @@
# utils functions
from .utils import dist, unif, tic, toc, toq

-__version__ = "0.9.3dev"
+__version__ = "0.9.4dev"

__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
54 changes: 36 additions & 18 deletions ot/solvers.py
@@ -23,6 +23,7 @@
from .gaussian import empirical_bures_wasserstein_distance
from .factored import factored_optimal_transport
from .lowrank import lowrank_sinkhorn
+from .optim import cg

lst_method_lazy = ['1d', 'gaussian', 'lowrank', 'factored', 'geomloss', 'geomloss_auto', 'geomloss_tensorized', 'geomloss_online', 'geomloss_multiscale']

@@ -57,13 +58,15 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
Regularization weight :math:`\lambda_r`, by default None (no reg., exact
OT)
reg_type : str, optional
-Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL"
+Type of regularization :math:`R` either "KL", "L2", "entropy",
+by default "KL". A tuple of functions can be provided for a general
+solver (see :any:`cg`). This is only used when ``reg!=None``.
unbalanced : float, optional
Unbalanced penalization weight :math:`\lambda_u`, by default None
(balanced OT)
unbalanced_type : str, optional
Type of unbalanced penalization function :math:`U` either "KL", "L2",
-"TV", by default "KL"
+"TV", by default "KL".
method : str, optional
Method for solving the problem when multiple algorithms are available,
default None for automatic selection.
@@ -80,10 +83,10 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
verbose : bool, optional
Print information in the solver, by default False
grad : str, optional
-Type of gradient computation, either or 'autodiff' or 'implicit' used only for
+Type of gradient computation, either 'autodiff' or 'envelope', used only for the
Sinkhorn solver. By default 'autodiff' provides gradients wrt all
outputs (`plan, value, value_linear`) but with important memory cost.
-'implicit' provides gradients only for `value` and and other outputs are
+'envelope' provides gradients only for `value`; other outputs are
detached. This is useful for memory saving when only the value is needed.
Returns
@@ -140,13 +143,13 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
# or for original Sinkhorn paper formulation [2]
res = ot.solve(M, a, b, reg=1.0, reg_type='entropy')
-# Use implicit differentiation for memory saving
-res = ot.solve(M, a, b, reg=1.0, grad='implicit') # M, a, b are torch tensors
+# Use envelope theorem differentiation for memory saving
+res = ot.solve(M, a, b, reg=1.0, grad='envelope') # M, a, b are torch tensors
res.value.backward() # only the value is differentiable
Note that by default the Sinkhorn solver uses automatic differentiation to
compute the gradients of the values and plan. This can be changed with the
-`grad` parameter. The `implicit` mode computes the implicit gradients only
+`grad` parameter. The `envelope` mode computes the gradients only
for the value and the other outputs are detached. This is useful for
memory saving when only the gradient of value is needed.
@@ -311,9 +314,22 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,

if unbalanced is None: # Balanced regularized OT

-if reg_type.lower() in ['entropy', 'kl']:
+if isinstance(reg_type, tuple):  # general solver
+
+    if max_iter is None:
+        max_iter = 1000
+    if tol is None:
+        tol = 1e-9
+
+    plan, log = cg(a, b, M, reg=reg, f=reg_type[0], df=reg_type[1], numItermax=max_iter, stopThr=tol, log=True, verbose=verbose, G0=plan_init)
+
+    value_linear = nx.sum(M * plan)
+    value = log['loss'][-1]
+    potentials = (log['u'], log['v'])
+
+elif reg_type.lower() in ['entropy', 'kl']:

-if grad == 'implicit': # if implicit then detach the input
+if grad == 'envelope': # if envelope then detach the input
M0, a0, b0 = M, a, b
M, a, b = nx.detach(M, a, b)

@@ -336,7 +352,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,

potentials = (log['log_u'], log['log_v'])

-if grad == 'implicit': # set the gradient at convergence
+if grad == 'envelope': # set the gradient at convergence

value = nx.set_gradients(value, (M0, a0, b0),
(plan, reg * (potentials[0] - potentials[0].mean()), reg * (potentials[1] - potentials[1].mean())))
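
For reference, this `envelope` branch is a direct application of the envelope (Danskin) theorem: at convergence the entropic OT value can be differentiated as if the optimal plan and dual potentials were constants, which is what the `nx.set_gradients` call above encodes. With `log_u`, `log_v` the log-scalings returned by Sinkhorn and $\varepsilon$ the regularization `reg`, the identities used are (a sketch of the standard result, not text from the PR)

$$\nabla_M \mathrm{OT}_\varepsilon = \gamma^\star, \qquad \nabla_a \mathrm{OT}_\varepsilon = \varepsilon\,(\log u - \overline{\log u}), \qquad \nabla_b \mathrm{OT}_\varepsilon = \varepsilon\,(\log v - \overline{\log v}),$$

where the means are subtracted to fix the additive indeterminacy of the dual potentials.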
@@ -359,7 +375,7 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,

else: # unbalanced AND regularized OT

-if reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl':
+if not isinstance(reg_type, tuple) and reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl':

if max_iter is None:
max_iter = 1000
@@ -374,14 +390,16 @@ def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,

potentials = (log['logu'], log['logv'])

-elif reg_type.lower() in ['kl', 'l2', 'entropy'] and unbalanced_type.lower() in ['kl', 'l2']:
+elif (isinstance(reg_type, tuple) or reg_type.lower() in ['kl', 'l2', 'entropy']) and unbalanced_type.lower() in ['kl', 'l2', 'tv']:

if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-12
+if isinstance(reg_type, str):
+    reg_type = reg_type.lower()

-plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, reg_div=reg_type.lower(), regm_div=unbalanced_type.lower(), numItermax=max_iter, stopThr=tol, verbose=verbose, log=True)
+plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, reg_div=reg_type, regm_div=unbalanced_type.lower(), numItermax=max_iter, stopThr=tol, verbose=verbose, log=True, G0=plan_init)

value_linear = nx.sum(M * plan)
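
With this change the same functional regularizer accepted by the balanced branch also reaches the L-BFGS-B unbalanced solver; a one-line usage sketch (parameter values illustrative, `f`/`df` as in the sketch near the top of this page):

res = ot.solve(M, a, b, reg=0.1, reg_type=(f, df), unbalanced=0.05, unbalanced_type='kl')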

@@ -962,10 +980,10 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
verbose : bool, optional
Print information in the solver, by default False
grad : str, optional
-Type of gradient computation, either or 'autodiff' or 'implicit' used only for
+Type of gradient computation, either 'autodiff' or 'envelope', used only for the
Sinkhorn solver. By default 'autodiff' provides gradients wrt all
outputs (`plan, value, value_linear`) but with important memory cost.
-'implicit' provides gradients only for `value` and and other outputs are
+'envelope' provides gradients only for `value`; other outputs are
detached. This is useful for memory saving when only the value is needed.
Returns
@@ -1034,13 +1052,13 @@ def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_t
# lazy OT plan
lazy_plan = res.lazy_plan
-# Use implicit differentiation for memory saving
-res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='implicit')
+# Use envelope theorem differentiation for memory saving
+res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='envelope')
res.value.backward() # only the value is differentiable
Note that by default the Sinkhorn solver uses automatic differentiation to
compute the gradients of the values and plan. This can be changed with the
-`grad` parameter. The `implicit` mode computes the implicit gradients only
+`grad` parameter. The `envelope` mode computes the gradients only
for the value and the other outputs are detached. This is useful for
memory saving when only the gradient of value is needed.
