Skip to content

Commit

Permalink
[MRG] Fix bug in emd2 with empty weighs on backends (PythonOT#606)
Browse files Browse the repository at this point in the history
* fix buf emd2 for empty inputs

* update release file

* debug problems in optimization hen using list_to_arry by removing it everywhere

* update jax config in tests

* hopefully final fix
  • Loading branch information
rflamary authored Mar 1, 2024
1 parent f1fe593 commit 0573eba
Show file tree
Hide file tree
Showing 11 changed files with 77 additions and 34 deletions.
3 changes: 2 additions & 1 deletion RELEASES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Releases

## 0.9.3
## 0.9.3dev

#### New features
+ `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specifify if the matrices are symmetric in which case the computation can be done faster.
Expand All @@ -9,6 +9,7 @@
- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593)
- Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596)
- Fix doc and example for lowrank sinkhorn (PR #601)
- Fix issue with empty weights for `ot.emd2` (PR #606, Issue #534)

## 0.9.2
*December 2023*
Expand Down
2 changes: 1 addition & 1 deletion ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
# utils functions
from .utils import dist, unif, tic, toc, toq

__version__ = "0.9.3"
__version__ = "0.9.3dev"

__all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils',
'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov',
Expand Down
2 changes: 0 additions & 2 deletions ot/gromov/_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,8 +703,6 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg,
"""
if nx is None:
G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M)

if isinstance(M, int) or isinstance(M, float):
nx = get_backend(G, deltaG, C1, C2)
else:
Expand Down
2 changes: 0 additions & 2 deletions ot/gromov/_semirelaxed.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,6 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
Gromov-Wasserstein". NeurIPS 2023 Workshop OTML.
"""
if nx is None:
G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M)

if isinstance(M, int) or isinstance(M, float):
nx = get_backend(G, deltaG, C1, C2)
else:
Expand Down
50 changes: 30 additions & 20 deletions ot/lp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,17 +302,24 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c
ot.optim.cg : General regularized OT
"""

# convert to numpy if list
a, b, M = list_to_array(a, b, M)
nx = get_backend(M, a, b)

a0, b0, M0 = a, b, M
if len(a0) != 0:
type_as = a0
elif len(b0) != 0:
type_as = b0
if len(a) != 0:
type_as = a
elif len(b) != 0:
type_as = b
else:
type_as = M0
nx = get_backend(M0, a0, b0)
type_as = M

# if empty array given then use uniform distributions
if len(a) == 0:
a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0]
if len(b) == 0:
b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1]

# store original tensors
a0, b0, M0 = a, b, M

# convert to numpy
M, a, b = nx.to_numpy(M, a, b)
Expand Down Expand Up @@ -474,15 +481,23 @@ def emd2(a, b, M, processes=1,
"""

a, b, M = list_to_array(a, b, M)
nx = get_backend(M, a, b)

a0, b0, M0 = a, b, M
if len(a0) != 0:
type_as = a0
elif len(b0) != 0:
type_as = b0
if len(a) != 0:
type_as = a
elif len(b) != 0:
type_as = b
else:
type_as = M0
nx = get_backend(M0, a0, b0)
type_as = M

# if empty array given then use uniform distributions
if len(a) == 0:
a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0]
if len(b) == 0:
b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1]

# store original tensors
a0, b0, M0 = a, b, M

# convert to numpy
M, a, b = nx.to_numpy(M, a, b)
Expand All @@ -491,11 +506,6 @@ def emd2(a, b, M, processes=1,
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64, order='C')

# if empty array given then use uniform distributions
if len(a) == 0:
a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0]
if len(b) == 0:
b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1]

assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \
"Dimension mismatch, check dimensions of M with a and b"
Expand Down
6 changes: 5 additions & 1 deletion ot/lp/solver_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,12 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the
transportation matrix)
"""
a, b, x_a, x_b = list_to_array(a, b, x_a, x_b)
x_a, x_b = list_to_array(x_a, x_b)
nx = get_backend(x_a, x_b)
if a is not None:
a = list_to_array(a, nx=nx)
if b is not None:
b = list_to_array(b, nx=nx)

assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \
"emd_1d should only be used with monodimensional data"
Expand Down
4 changes: 1 addition & 3 deletions ot/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import warnings
from .lp import emd
from .bregman import sinkhorn
from .utils import list_to_array
from .backend import get_backend

with warnings.catch_warnings():
Expand Down Expand Up @@ -73,7 +72,6 @@ def line_search_armijo(
"""
if nx is None:
xk, pk, gfk = list_to_array(xk, pk, gfk)
xk0, pk0 = xk, pk
nx = get_backend(xk0, pk0)
else:
Expand Down Expand Up @@ -236,7 +234,7 @@ def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_sea
ot.lp.emd : Unregularized optimal transport
ot.bregman.sinkhorn : Entropic regularized optimal transport
"""
a, b, M, G0 = list_to_array(a, b, M, G0)

if isinstance(M, int) or isinstance(M, float):
nx = get_backend(a, b)
else:
Expand Down
24 changes: 21 additions & 3 deletions ot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,30 @@ def laplacian(x):
return L


def list_to_array(*lst):
def list_to_array(*lst, nx=None):
r""" Convert a list if in numpy format """
lst_not_empty = [a for a in lst if len(a) > 0 and not isinstance(a, list)]
if nx is None: # find backend

if len(lst_not_empty) == 0:
type_as = np.zeros(0)
nx = get_backend(type_as)
else:
nx = get_backend(*lst_not_empty)
type_as = lst_not_empty[0]
else:
if len(lst_not_empty) == 0:
type_as = None
else:
type_as = lst_not_empty[0]
if len(lst) > 1:
return [np.array(a) if isinstance(a, list) else a for a in lst]
return [nx.from_numpy(np.array(a), type_as=type_as)
if isinstance(a, list) else a for a in lst]
else:
return np.array(lst[0]) if isinstance(lst[0], list) else lst[0]
if isinstance(lst[0], list):
return nx.from_numpy(np.array(lst[0]), type_as=type_as)
else:
return lst[0]


def proj_simplex(v, z=1):
Expand Down
2 changes: 1 addition & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

if jax:
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
from jax.config import config
from jax import config
config.update("jax_enable_x64", True)

if tf:
Expand Down
4 changes: 4 additions & 0 deletions test/test_ot.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ def test_emd2_backends(nx):

valb = ot.emd2(ab, ab, Mb)

# check with empty inputs
valb2 = ot.emd2([], [], Mb)

np.allclose(val, nx.to_numpy(valb))
np.allclose(val, nx.to_numpy(valb2))


def test_emd_emd2_types_devices(nx):
Expand Down
12 changes: 12 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,18 @@ def test_cost_normalization(nx):
ot.utils.cost_normalization(C1, 'error')


def test_list_to_array(nx):

lst = [np.array([1, 2, 3]), np.array([4, 5, 6])]

a1, a2 = ot.utils.list_to_array(*lst)

assert a1.shape == (3,)
assert a2.shape == (3,)

a, b, M = ot.utils.list_to_array([], [], [[1.0, 2.0], [3.0, 4.0]])


def test_check_params():

res1 = ot.utils.check_params(first='OK', second=20)
Expand Down

0 comments on commit 0573eba

Please sign in to comment.