diff --git a/RELEASES.md b/RELEASES.md index a695d6a70..3b8513dbd 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,6 +1,6 @@ # Releases -## 0.9.3 +## 0.9.3dev #### New features + `ot.gromov._gw.solve_gromov_linesearch` now has an argument to specifify if the matrices are symmetric in which case the computation can be done faster. @@ -9,6 +9,7 @@ - Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593) - Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596) - Fix doc and example for lowrank sinkhorn (PR #601) +- Fix issue with empty weights for `ot.emd2` (PR #606, Issue #534) ## 0.9.2 *December 2023* diff --git a/ot/__init__.py b/ot/__init__.py index db49d6c34..9a63b5f6f 100644 --- a/ot/__init__.py +++ b/ot/__init__.py @@ -58,7 +58,7 @@ # utils functions from .utils import dist, unif, tic, toc, toq -__version__ = "0.9.3" +__version__ = "0.9.3dev" __all__ = ['emd', 'emd2', 'emd_1d', 'sinkhorn', 'sinkhorn2', 'utils', 'datasets', 'bregman', 'lp', 'tic', 'toc', 'toq', 'gromov', diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 3d7a47480..281ed5f0b 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -703,8 +703,6 @@ def solve_gromov_linesearch(G, deltaG, cost_G, C1, C2, M, reg, """ if nx is None: - G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M) - if isinstance(M, int) or isinstance(M, float): nx = get_backend(G, deltaG, C1, C2) else: diff --git a/ot/gromov/_semirelaxed.py b/ot/gromov/_semirelaxed.py index fb9d2b3ca..c37ba2bf4 100644 --- a/ot/gromov/_semirelaxed.py +++ b/ot/gromov/_semirelaxed.py @@ -583,8 +583,6 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, Gromov-Wasserstein". NeurIPS 2023 Workshop OTML. """ if nx is None: - G, deltaG, C1, C2, M = list_to_array(G, deltaG, C1, C2, M) - if isinstance(M, int) or isinstance(M, float): nx = get_backend(G, deltaG, C1, C2) else: diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 545d1d8cd..93316a6c1 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -302,17 +302,24 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c ot.optim.cg : General regularized OT """ - # convert to numpy if list a, b, M = list_to_array(a, b, M) + nx = get_backend(M, a, b) - a0, b0, M0 = a, b, M - if len(a0) != 0: - type_as = a0 - elif len(b0) != 0: - type_as = b0 + if len(a) != 0: + type_as = a + elif len(b) != 0: + type_as = b else: - type_as = M0 - nx = get_backend(M0, a0, b0) + type_as = M + + # if empty array given then use uniform distributions + if len(a) == 0: + a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + if len(b) == 0: + b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + + # store original tensors + a0, b0, M0 = a, b, M # convert to numpy M, a, b = nx.to_numpy(M, a, b) @@ -474,15 +481,23 @@ def emd2(a, b, M, processes=1, """ a, b, M = list_to_array(a, b, M) + nx = get_backend(M, a, b) - a0, b0, M0 = a, b, M - if len(a0) != 0: - type_as = a0 - elif len(b0) != 0: - type_as = b0 + if len(a) != 0: + type_as = a + elif len(b) != 0: + type_as = b else: - type_as = M0 - nx = get_backend(M0, a0, b0) + type_as = M + + # if empty array given then use uniform distributions + if len(a) == 0: + a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + if len(b) == 0: + b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + + # store original tensors + a0, b0, M0 = a, b, M # convert to numpy M, a, b = nx.to_numpy(M, a, b) @@ -491,11 +506,6 @@ def emd2(a, b, M, processes=1, b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64, order='C') - # if empty array given then use uniform distributions - if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] - if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] assert (a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1]), \ "Dimension mismatch, check dimensions of M with a and b" diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py index e792db904..d9395c8d4 100644 --- a/ot/lp/solver_1d.py +++ b/ot/lp/solver_1d.py @@ -223,8 +223,12 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True, ot.lp.emd2_1d : EMD for 1d distributions (returns cost instead of the transportation matrix) """ - a, b, x_a, x_b = list_to_array(a, b, x_a, x_b) + x_a, x_b = list_to_array(x_a, x_b) nx = get_backend(x_a, x_b) + if a is not None: + a = list_to_array(a, nx=nx) + if b is not None: + b = list_to_array(b, nx=nx) assert (x_a.ndim == 1 or x_a.ndim == 2 and x_a.shape[1] == 1), \ "emd_1d should only be used with monodimensional data" diff --git a/ot/optim.py b/ot/optim.py index 8700f75d1..dcdef6a88 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -12,7 +12,6 @@ import warnings from .lp import emd from .bregman import sinkhorn -from .utils import list_to_array from .backend import get_backend with warnings.catch_warnings(): @@ -73,7 +72,6 @@ def line_search_armijo( """ if nx is None: - xk, pk, gfk = list_to_array(xk, pk, gfk) xk0, pk0 = xk, pk nx = get_backend(xk0, pk0) else: @@ -236,7 +234,7 @@ def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_sea ot.lp.emd : Unregularized optimal transport ot.bregman.sinkhorn : Entropic regularized optimal transport """ - a, b, M, G0 = list_to_array(a, b, M, G0) + if isinstance(M, int) or isinstance(M, float): nx = get_backend(a, b) else: diff --git a/ot/utils.py b/ot/utils.py index 19e61f1fe..404a9f2db 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -56,12 +56,30 @@ def laplacian(x): return L -def list_to_array(*lst): +def list_to_array(*lst, nx=None): r""" Convert a list if in numpy format """ + lst_not_empty = [a for a in lst if len(a) > 0 and not isinstance(a, list)] + if nx is None: # find backend + + if len(lst_not_empty) == 0: + type_as = np.zeros(0) + nx = get_backend(type_as) + else: + nx = get_backend(*lst_not_empty) + type_as = lst_not_empty[0] + else: + if len(lst_not_empty) == 0: + type_as = None + else: + type_as = lst_not_empty[0] if len(lst) > 1: - return [np.array(a) if isinstance(a, list) else a for a in lst] + return [nx.from_numpy(np.array(a), type_as=type_as) + if isinstance(a, list) else a for a in lst] else: - return np.array(lst[0]) if isinstance(lst[0], list) else lst[0] + if isinstance(lst[0], list): + return nx.from_numpy(np.array(lst[0]), type_as=type_as) + else: + return lst[0] def proj_simplex(v, z=1): diff --git a/test/conftest.py b/test/conftest.py index 0303ed9f2..043c8ca70 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -13,7 +13,7 @@ if jax: os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' - from jax.config import config + from jax import config config.update("jax_enable_x64", True) if tf: diff --git a/test/test_ot.py b/test/test_ot.py index 5c6e6732b..a90321d5f 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -74,7 +74,11 @@ def test_emd2_backends(nx): valb = ot.emd2(ab, ab, Mb) + # check with empty inputs + valb2 = ot.emd2([], [], Mb) + np.allclose(val, nx.to_numpy(valb)) + np.allclose(val, nx.to_numpy(valb2)) def test_emd_emd2_types_devices(nx): diff --git a/test/test_utils.py b/test/test_utils.py index 6cdb7ead7..966cef989 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -322,6 +322,18 @@ def test_cost_normalization(nx): ot.utils.cost_normalization(C1, 'error') +def test_list_to_array(nx): + + lst = [np.array([1, 2, 3]), np.array([4, 5, 6])] + + a1, a2 = ot.utils.list_to_array(*lst) + + assert a1.shape == (3,) + assert a2.shape == (3,) + + a, b, M = ot.utils.list_to_array([], [], [[1.0, 2.0], [3.0, 4.0]]) + + def test_check_params(): res1 = ot.utils.check_params(first='OK', second=20)