diff --git a/RELEASES.md b/RELEASES.md index 223eb0116..8cf5ae342 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -20,7 +20,7 @@ #### Closed issues - Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504) - Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520) - +- Handle documentation and warnings when integers are provided to (f)gw solvers based on cg (Issue #530, PR #559) ## 0.9.1 *August 2023* diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py index 88b1eb75f..8ee68e917 100644 --- a/ot/gromov/_gw.py +++ b/ot/gromov/_gw.py @@ -12,6 +12,7 @@ # License: MIT License import numpy as np +import warnings from ..utils import dist, UndefinedParameter, list_to_array @@ -53,6 +54,10 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric which can lead to copy overhead on GPU arrays. .. note:: All computations in the conjugate gradient solver are done with numpy to limit memory overhead. + .. note:: This function will cast the computed transport plan to the data + type of the provided input :math:`\mathbf{C}_1`. Casting to an integer + tensor might result in a loss of precision. If this behaviour is + unwanted, please make sure to provide a floating point input. Parameters ---------- @@ -122,7 +127,7 @@ def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric if q is not None: arr.append(list_to_array(q)) else: - q = unif(C2.shape[0], type_as=C2) + q = unif(C2.shape[0], type_as=C1) if G0 is not None: G0_ = G0 arr.append(G0) @@ -171,6 +176,16 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): else: def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=0., reg=1., nx=np_, **kwargs) + + if not nx.is_floating_point(C10): + warnings.warn( + "Input structure matrix consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "structure matrix consists of floating point elements.", + stacklevel=2 + ) + if log: res, log = cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) log['gw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10) @@ -216,6 +231,10 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri which can lead to copy overhead on GPU arrays. .. note:: All computations in the conjugate gradient solver are done with numpy to limit memory overhead. + .. note:: This function will cast the computed transport plan to the data + type of the provided input :math:`\mathbf{C}_1`. Casting to an integer + tensor might result in a loss of precision. If this behaviour is + unwanted, please make sure to provide a floating point input. Parameters ---------- @@ -286,7 +305,7 @@ def gromov_wasserstein2(C1, C2, p=None, q=None, loss_fun='square_loss', symmetri if p is None: p = unif(C1.shape[0], type_as=C1) if q is None: - q = unif(C2.shape[0], type_as=C2) + q = unif(C2.shape[0], type_as=C1) T, log_gw = gromov_wasserstein( C1, C2, p, q, loss_fun, symmetric, log=True, armijo=armijo, G0=G0, @@ -344,6 +363,10 @@ def fused_gromov_wasserstein(M, C1, C2, p=None, q=None, loss_fun='square_loss', which can lead to copy overhead on GPU arrays. .. note:: All computations in the conjugate gradient solver are done with numpy to limit memory overhead. + .. note:: This function will cast the computed transport plan to the data + type of the provided input :math:`\mathbf{M}`. Casting to an integer + tensor might result in a loss of precision. If this behaviour is + unwanted, please make sure to provide a floating point input. Parameters @@ -409,11 +432,11 @@ def fused_gromov_wasserstein(M, C1, C2, p=None, q=None, loss_fun='square_loss', if p is not None: arr.append(list_to_array(p)) else: - p = unif(C1.shape[0], type_as=C1) + p = unif(C1.shape[0], type_as=M) if q is not None: arr.append(list_to_array(q)) else: - q = unif(C2.shape[0], type_as=C2) + q = unif(C2.shape[0], type_as=M) if G0 is not None: G0_ = G0 arr.append(G0) @@ -465,14 +488,22 @@ def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): else: def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): return solve_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, M=(1 - alpha) * M, reg=alpha, nx=np_, **kwargs) + if not nx.is_floating_point(M0): + warnings.warn( + "Input feature matrix consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "feature matrix consists of floating point elements.", + stacklevel=2 + ) if log: res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) - log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=C10) - log['u'] = nx.from_numpy(log['u'], type_as=C10) - log['v'] = nx.from_numpy(log['v'], type_as=C10) - return nx.from_numpy(res, type_as=C10), log + log['fgw_dist'] = nx.from_numpy(log['loss'][-1], type_as=M0) + log['u'] = nx.from_numpy(log['u'], type_as=M0) + log['v'] = nx.from_numpy(log['v'], type_as=M0) + return nx.from_numpy(res, type_as=M0), log else: - return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=C10) + return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=False, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs), type_as=M0) def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, alpha=0.5, @@ -510,6 +541,10 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', which can lead to copy overhead on GPU arrays. .. note:: All computations in the conjugate gradient solver are done with numpy to limit memory overhead. + .. note:: This function will cast the computed transport plan to the data + type of the provided input :math:`\mathbf{M}`. Casting to an integer + tensor might result in a loss of precision. If this behaviour is + unwanted, please make sure to provide a floating point input. Parameters ---------- @@ -578,9 +613,9 @@ def fused_gromov_wasserstein2(M, C1, C2, p=None, q=None, loss_fun='square_loss', # init marginals if set as None if p is None: - p = unif(C1.shape[0], type_as=C1) + p = unif(C1.shape[0], type_as=M) if q is None: - q = unif(C2.shape[0], type_as=C2) + q = unif(C2.shape[0], type_as=M) T, log_fgw = fused_gromov_wasserstein( M, C1, C2, p, q, loss_fun, symmetric, alpha, armijo, G0, log=True, diff --git a/test/test_gromov.py b/test/test_gromov.py index a71433bb5..9e873d5a0 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -122,6 +122,37 @@ def test_asymmetric_gromov(nx): np.testing.assert_allclose(logb['gw_dist'], 0., atol=1e-04) +def test_gromov_integer_warnings(nx): + n_samples = 10 # nb samples + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=1) + xt = xs[::-1].copy() + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + C1 = C1.astype(np.int32) + C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0) + + G = ot.gromov.gromov_wasserstein( + C1, C2, None, q, 'square_loss', G0=G0, verbose=True, + alpha_min=0., alpha_max=1.) + Gb = nx.to_numpy(ot.gromov.gromov_wasserstein( + C1b, C2b, pb, None, 'square_loss', symmetric=True, G0=G0b, verbose=True)) + + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(G, 0., atol=1e-09) + + def test_gromov_dtype_device(nx): # setup n_samples = 20 # nb samples @@ -1145,7 +1176,7 @@ def test_fgw(nx): def test_asymmetric_fgw(nx): - n_samples = 50 # nb samples + n_samples = 20 # nb samples rng = np.random.RandomState(0) C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) idx = np.arange(n_samples) @@ -1221,6 +1252,32 @@ def test_asymmetric_fgw(nx): np.testing.assert_allclose(logb['fgw_dist'], 0., atol=1e-04) +def test_fgw_integer_warnings(nx): + n_samples = 20 # nb samples + rng = np.random.RandomState(0) + C1 = rng.uniform(low=0., high=10, size=(n_samples, n_samples)) + idx = np.arange(n_samples) + rng.shuffle(idx) + C2 = C1[idx, :][:, idx] + + # add features + F1 = rng.uniform(low=0., high=10, size=(n_samples, 1)) + F2 = F1[idx, :] + p = ot.unif(n_samples) + q = ot.unif(n_samples) + G0 = p[:, None] * q[None, :] + + M = ot.dist(F1, F2).astype(np.int32) + Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0) + + G, log = ot.gromov.fused_gromov_wasserstein(M, C1, C2, p, q, 'square_loss', alpha=0.5, G0=G0, log=True, symmetric=False, verbose=True) + Gb, logb = ot.gromov.fused_gromov_wasserstein(Mb, C1b, C2b, pb, qb, 'square_loss', alpha=0.5, log=True, symmetric=None, G0=G0b, verbose=True) + Gb = nx.to_numpy(Gb) + # check constraints + np.testing.assert_allclose(G, Gb, atol=1e-06) + np.testing.assert_allclose(G, 0., atol=1e-06) + + def test_fgw2_gradients(): n_samples = 20 # nb samples