Skip to content

Commit

Permalink
[WIP] Fix matrix feature shape in entropic FGW barycenters (PythonOT#575
Browse files Browse the repository at this point in the history
)

* fix matrix feature shape in entropic FGW barycenter

* fix matrix feature shape in entropic FGW barycenter

* complete tests for gromov.bregman
  • Loading branch information
cedricvincentcuaz authored Nov 11, 2023
1 parent fcd8f05 commit 6f4a40d
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 9 deletions.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
- Correct independence of `fgw_barycenters` to `init_C` and `init_X` (Issue #547, PR #566)
- Avoid precision change when computing norm using PyTorch backend (Discussion #570, PR #572)
- Create `ot/bregman/`repository (Issue #567, PR #569)
- Fix matrix feature shape in `entropic_fused_gromov_barycenters`(Issue #574, PR #573)


## 0.9.1
Expand Down
10 changes: 4 additions & 6 deletions ot/gromov/_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,8 @@ def entropic_fused_gromov_barycenters(
else:
Y = init_Y

T = [nx.outer(p_, p) for p_ in ps]
if warmstartT:
T = [nx.outer(p_, p) for p_ in ps]

Ms = [dist(Ys[s], Y) for s in range(len(Ys))]

Expand All @@ -971,9 +972,6 @@ def entropic_fused_gromov_barycenters(
err_feature = 1
err_structure = 1

if warmstartT:
T = [None] * S

if log:
log_ = {}
log_['err_feature'] = []
Expand All @@ -987,7 +985,7 @@ def entropic_fused_gromov_barycenters(
if warmstartT:
T = [entropic_fused_gromov_wasserstein(
Ms[s], Cs[s], C, ps[s], p, loss_fun, epsilon, symmetric, alpha,
None, max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
T[s], max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]

else:
T = [entropic_fused_gromov_wasserstein(
Expand All @@ -1001,7 +999,7 @@ def entropic_fused_gromov_barycenters(

Ys_temp = [y.T for y in Ys]
T_temp = [Ts.T for Ts in T]
Y = update_feature_matrix(lambdas, Ys_temp, T_temp, p)
Y = update_feature_matrix(lambdas, Ys_temp, T_temp, p).T
Ms = [dist(Ys[s], Y) for s in range(len(Ys))]

if cpt % 10 == 0:
Expand Down
33 changes: 30 additions & 3 deletions test/test_gromov.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,12 @@ def test_entropic_proximal_gromov(nx):

C1b, C2b, pb, qb, G0b = nx.from_numpy(C1, C2, p, q, G0)

with pytest.raises(ValueError):
loss_fun = 'weird_loss_fun'
G, log = ot.gromov.entropic_gromov_wasserstein(
C1, C2, None, q, loss_fun, symmetric=None, G0=G0,
epsilon=1e-1, max_iter=50, solver='PPA', verbose=True, log=True, numItermax=1)

G, log = ot.gromov.entropic_gromov_wasserstein(
C1, C2, None, q, 'square_loss', symmetric=None, G0=G0,
epsilon=1e-1, max_iter=50, solver='PPA', verbose=True, log=True, numItermax=1)
Expand Down Expand Up @@ -606,6 +612,12 @@ def test_entropic_fgw(nx):

Mb, C1b, C2b, pb, qb, G0b = nx.from_numpy(M, C1, C2, p, q, G0)

with pytest.raises(ValueError):
loss_fun = 'weird_loss_fun'
G, log = ot.gromov.entropic_fused_gromov_wasserstein(
M, C1, C2, None, None, loss_fun, symmetric=None, G0=G0,
epsilon=1e-1, max_iter=10, verbose=True, log=True)

G, log = ot.gromov.entropic_fused_gromov_wasserstein(
M, C1, C2, None, None, 'square_loss', symmetric=None, G0=G0,
epsilon=1e-1, max_iter=10, verbose=True, log=True)
Expand Down Expand Up @@ -812,20 +824,28 @@ def test_entropic_fgw_barycenter(nx):
C2 = ot.dist(Xt)
p1 = ot.unif(ns)
p2 = ot.unif(nt)
n_samples = 2
n_samples = 3
p = ot.unif(n_samples)

ysb, ytb, C1b, C2b, p1b, p2b, pb = nx.from_numpy(ys, yt, C1, C2, p1, p2, p)

with pytest.raises(ValueError):
loss_fun = 'weird_loss_fun'
X, C, log = ot.gromov.entropic_fused_gromov_barycenters(
n_samples, [ys, yt], [C1, C2], None, p, [.5, .5], loss_fun, 0.1,
max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42,
solver='PPA', numItermax=10, log=True
)

X, C, log = ot.gromov.entropic_fused_gromov_barycenters(
n_samples, [ys, yt], [C1, C2], None, p, [.5, .5], 'square_loss', 0.1,
max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42,
solver='PPA', numItermax=1, log=True
solver='PPA', numItermax=10, log=True
)
Xb, Cb = ot.gromov.entropic_fused_gromov_barycenters(
n_samples, [ysb, ytb], [C1b, C2b], [p1b, p2b], None, [.5, .5], 'square_loss', 0.1,
max_iter=10, tol=1e-3, verbose=False, warmstartT=True, random_state=42,
solver='PPA', numItermax=1, log=False)
solver='PPA', numItermax=10, log=False)
Xb, Cb = nx.to_numpy(Xb, Cb)

np.testing.assert_allclose(C, Cb, atol=1e-06)
Expand Down Expand Up @@ -1052,6 +1072,13 @@ def test_gromov_entropic_barycenter(nx):

C1b, C2b, p1b, p2b, pb = nx.from_numpy(C1, C2, p1, p2, p)

with pytest.raises(ValueError):
loss_fun = 'weird_loss_fun'
Cb = ot.gromov.entropic_gromov_barycenters(
n_samples, [C1, C2], None, p, [.5, .5], loss_fun, 1e-3,
max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42
)

Cb = ot.gromov.entropic_gromov_barycenters(
n_samples, [C1, C2], None, p, [.5, .5], 'square_loss', 1e-3,
max_iter=10, tol=1e-3, verbose=True, warmstartT=True, random_state=42
Expand Down

0 comments on commit 6f4a40d

Please sign in to comment.