Skip to content

Commit

Permalink
[MRG] correct independence of fgw barycenters to init (PythonOT#566)
Browse files Browse the repository at this point in the history
* correct independence of fgw barycenters to init

* fix pep8 and tests

* correct PR id

* take into account comments
  • Loading branch information
cedricvincentcuaz authored Nov 9, 2023
1 parent 1682b60 commit a56e1b2
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 28 deletions.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
- Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520)
- Handle documentation and warnings when integers are provided to (f)gw solvers based on cg (Issue #530, PR #559)
- Fix `fgw_barycenters` ignoring the `init_C` and `init_X` initializations (Issue #547, PR #566)

## 0.9.1
*August 2023*
Expand Down
41 changes: 13 additions & 28 deletions ot/gromov/_gw.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,6 @@ def df(G):
def df(G):
return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))

# removed since 0.9.2
#if loss_fun == 'kl_loss':
# armijo = True # there is no closed form line-search with KL

if armijo:
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
Expand Down Expand Up @@ -478,10 +474,6 @@ def df(G):
def df(G):
return 0.5 * (gwggrad(constC, hC1, hC2, G, np_) + gwggrad(constCt, hC1t, hC2t, G, np_))

# removed since 0.9.2
#if loss_fun == 'kl_loss':
# armijo = True # there is no closed form line-search with KL

if armijo:
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=np_, **kwargs)
Expand Down Expand Up @@ -827,10 +819,6 @@ def gromov_barycenters(
else:
C = init_C

# removed since 0.9.2
#if loss_fun == 'kl_loss':
# armijo = True

cpt = 0
err = 1

Expand Down Expand Up @@ -1005,16 +993,14 @@ def fgw_barycenters(
else:
if init_X is None:
X = nx.zeros((N, d), type_as=ps[0])

else:
X = init_X

T = [nx.outer(p, q) for q in ps]

Ms = [dist(X, Ys[s]) for s in range(len(Ys))]

# removed since 0.9.2
#if loss_fun == 'kl_loss':
# armijo = True
if warmstartT:
T = [nx.outer(p, q) for q in ps]

cpt = 0
err_feature = 1
Expand All @@ -1030,11 +1016,19 @@ def fgw_barycenters(
Cprev = C
Xprev = X

if warmstartT:
T = [fused_gromov_wasserstein(
Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
G0=T[s], max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
else:
T = [fused_gromov_wasserstein(
Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
G0=None, max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
# T is N,ns
if not fixed_features:
Ys_temp = [y.T for y in Ys]
X = update_feature_matrix(lambdas, Ys_temp, T, p).T

Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
Ms = [dist(X, Ys[s]) for s in range(len(Ys))]

if not fixed_structure:
T_temp = [t.T for t in T]
Expand All @@ -1044,15 +1038,6 @@ def fgw_barycenters(
elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T_temp, Cs)

if warmstartT:
T = [fused_gromov_wasserstein(
Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
G0=T[s], max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
else:
T = [fused_gromov_wasserstein(
Ms[s], C, Cs[s], p, ps[s], loss_fun=loss_fun, alpha=alpha, armijo=armijo, symmetric=symmetric,
G0=None, max_iter=max_iter, tol_rel=1e-5, tol_abs=0., verbose=verbose, **kwargs) for s in range(S)]
# T is N,ns
err_feature = nx.norm(X - nx.reshape(Xprev, (N, d)))
err_structure = nx.norm(C - Cprev)
if log:
Expand Down
15 changes: 15 additions & 0 deletions test/test_gromov.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,6 +1425,13 @@ def test_fgw_barycenter(nx):
init_C /= init_C.max()
init_Cb = nx.from_numpy(init_C)

with pytest.raises(ot.utils.UndefinedParameter): # `UndefinedParameter` must be raised when `fixed_structure=True` and `init_C=None`
Xb, Cb = ot.gromov.fgw_barycenters(
n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None,
alpha=0.5, fixed_structure=True, init_C=None, fixed_features=False,
p=None, loss_fun='square_loss', max_iter=100, tol=1e-3
)

Xb, Cb = ot.gromov.fgw_barycenters(
n_samples, [ysb, ytb], [C1b, C2b], ps=[p1b, p2b], lambdas=None,
alpha=0.5, fixed_structure=True, init_C=init_Cb, fixed_features=False,
Expand All @@ -1437,12 +1444,20 @@ def test_fgw_barycenter(nx):
init_X = rng.randn(n_samples, ys.shape[1])
init_Xb = nx.from_numpy(init_X)

with pytest.raises(ot.utils.UndefinedParameter): # `UndefinedParameter` must be raised when `fixed_features=True` and `init_X=None`
Xb, Cb, logb = ot.gromov.fgw_barycenters(
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5,
fixed_structure=False, fixed_features=True, init_X=None,
p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3,
warmstartT=True, log=True, random_state=98765, verbose=True
)
Xb, Cb, logb = ot.gromov.fgw_barycenters(
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], [.5, .5], 0.5,
fixed_structure=False, fixed_features=True, init_X=init_Xb,
p=pb, loss_fun='square_loss', max_iter=100, tol=1e-3,
warmstartT=True, log=True, random_state=98765, verbose=True
)

X, C = nx.to_numpy(Xb), nx.to_numpy(Cb)
np.testing.assert_allclose(C.shape, (n_samples, n_samples))
np.testing.assert_allclose(X.shape, (n_samples, ys.shape[1]))
Expand Down

0 comments on commit a56e1b2

Please sign in to comment.