Skip to content

Commit

Permalink
[MRG] (f)gw barycenter solvers new features and cleaning (PythonOT#578)
Browse files Browse the repository at this point in the history
* add conv_criterion feature to (f)gw barycenter solvers + trying to harmonise these solvers a bit

* correct pep8

* handle different convergence criteria

* fix

* update tests for new conv_criterion feature

* add fixed structure or feature to entropic fgw barycenter + tests

* conv_criterion -> stop_criterion / corrections in the docs / harmonise log behaviour

* Update RELEASES.md

---------

Co-authored-by: Rémi Flamary <[email protected]>
  • Loading branch information
cedricvincentcuaz and rflamary authored Nov 21, 2023
1 parent 31455e5 commit cffb6cf
Show file tree
Hide file tree
Showing 4 changed files with 532 additions and 211 deletions.
2 changes: 2 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
+ Add KL loss to all semi-relaxed (Fused) Gromov-Wasserstein solvers (PR #559)
+ Further upgraded unbalanced OT solvers for more flexibility and future use (PR #551)
+ New API function `ot.solve_sample` for solving OT problems from empirical samples (PR #563)
+ Add `stop_criterion` feature to (un)regularized (f)gw barycenter solvers (PR #578)
+ Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
Expand Down
245 changes: 178 additions & 67 deletions ot/gromov/_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import warnings

from ..bregman import sinkhorn
from ..utils import dist, list_to_array, check_random_state, unif
from ..utils import dist, UndefinedParameter, list_to_array, check_random_state, unif
from ..backend import get_backend

from ._utils import init_matrix, gwloss, gwggrad
Expand Down Expand Up @@ -345,8 +345,9 @@ def entropic_gromov_wasserstein2(

def entropic_gromov_barycenters(
N, Cs, ps=None, p=None, lambdas=None, loss_fun='square_loss',
epsilon=0.1, symmetric=True, max_iter=1000, tol=1e-9, warmstartT=False,
verbose=False, log=False, init_C=None, random_state=None, **kwargs):
epsilon=0.1, symmetric=True, max_iter=1000, tol=1e-9,
stop_criterion='barycenter', warmstartT=False, verbose=False,
log=False, init_C=None, random_state=None, **kwargs):
r"""
Returns the Gromov-Wasserstein barycenters of `S` measured similarity matrices :math:`(\mathbf{C}_s)_{1 \leq s \leq S}`
estimated using Gromov-Wasserstein transports from Sinkhorn projections.
Expand Down Expand Up @@ -388,6 +389,10 @@ def entropic_gromov_barycenters(
Max number of iterations
tol : float, optional
Stop threshold on error (>0)
stop_criterion : str, optional. Default is 'barycenter'.
Convergence criterion taking values in ['barycenter', 'loss']. If set to 'barycenter'
uses absolute norm variations of estimated barycenters. Else if set to 'loss'
uses the relative variations of the loss.
warmstartT: bool, optional
Either to perform warmstart of transport plans in the successive
gromov-wasserstein transport problems.
Expand All @@ -407,7 +412,11 @@ def entropic_gromov_barycenters(
C : array-like, shape (`N`, `N`)
Similarity matrix in the barycenter space (permutated arbitrarily)
log : dict
Log dictionary of error during iterations. Return only if `log=True` in parameters.
Only returned when log=True. It contains the keys:
- :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices
- :math:`\mathbf{p}`: (`N`,) barycenter weights
- values used in convergence evaluation.
References
----------
Expand All @@ -418,6 +427,9 @@ def entropic_gromov_barycenters(
if loss_fun not in ('square_loss', 'kl_loss'):
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

if stop_criterion not in ['barycenter', 'loss']:
raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}.")

Cs = list_to_array(*Cs)
arr = [*Cs]
if ps is not None:
Expand Down Expand Up @@ -446,45 +458,75 @@ def entropic_gromov_barycenters(
C = init_C

cpt = 0
err = 1

error = []
err = 1e15 # either the error on 'barycenter' or 'loss'

if warmstartT:
T = [None] * S

if stop_criterion == 'barycenter':
inner_log = False
else:
inner_log = True
curr_loss = 1e15

if log:
log_ = {}
log_['err'] = []
if stop_criterion == 'loss':
log_['loss'] = []

while (err > tol) and (cpt < max_iter):
Cprev = C
if stop_criterion == 'barycenter':
Cprev = C
else:
prev_loss = curr_loss

# get transport plans
if warmstartT:
T = [entropic_gromov_wasserstein(
res = [entropic_gromov_wasserstein(
C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, T[s],
max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
max_iter, 1e-4, verbose=verbose, log=inner_log, **kwargs) for s in range(S)]
else:
T = [entropic_gromov_wasserstein(
res = [entropic_gromov_wasserstein(
C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, None,
max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
max_iter, 1e-4, verbose=verbose, log=inner_log, **kwargs) for s in range(S)]
if stop_criterion == 'barycenter':
T = res
else:
T = [output[0] for output in res]
curr_loss = np.sum([output[1]['gw_dist'] for output in res])

# update barycenters
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs, nx)
elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs, nx)

if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
# update convergence criterion
if stop_criterion == 'barycenter':
err = nx.norm(C - Cprev)
error.append(err)
if log:
log_['err'].append(err)

if verbose:
if cpt % 200 == 0:
print('{:5s}|{:12s}'.format(
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))
else:
err = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. else np.nan
if log:
log_['loss'].append(curr_loss)
log_['err'].append(err)

if verbose:
if cpt % 200 == 0:
print('{:5s}|{:12s}'.format(
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err))

cpt += 1

if log:
return C, {"err": error}
log_['T'] = T
log_['p'] = p

return C, log_
else:
return C

Expand Down Expand Up @@ -838,8 +880,9 @@ def entropic_fused_gromov_wasserstein2(
def entropic_fused_gromov_barycenters(
N, Ys, Cs, ps=None, p=None, lambdas=None, loss_fun='square_loss',
epsilon=0.1, symmetric=True, alpha=0.5, max_iter=1000, tol=1e-9,
warmstartT=False, verbose=False, log=False, init_C=None, init_Y=None,
random_state=None, **kwargs):
stop_criterion='barycenter', warmstartT=False, verbose=False,
log=False, init_C=None, init_Y=None, fixed_structure=False,
fixed_features=False, random_state=None, **kwargs):
r"""
Returns the Fused Gromov-Wasserstein barycenters of `S` measurable networks with node features :math:`(\mathbf{C}_s, \mathbf{Y}_s, \mathbf{p}_s)_{1 \leq s \leq S}`
estimated using Fused Gromov-Wasserstein transports from Sinkhorn projections.
Expand Down Expand Up @@ -886,6 +929,10 @@ def entropic_fused_gromov_barycenters(
Max number of iterations
tol : float, optional
Stop threshold on error (>0)
stop_criterion : str, optional. Default is 'barycenter'.
Stop criterion taking values in ['barycenter', 'loss']. If set to 'barycenter'
uses absolute norm variations of estimated barycenters. Else if set to 'loss'
uses the relative variations of the loss.
warmstartT: bool, optional
Either to perform warmstart of transport plans in the successive
fused gromov-wasserstein transport problems.
Expand All @@ -898,6 +945,10 @@ def entropic_fused_gromov_barycenters(
init_Y : array-like, shape (N,d), optional
Initialization for the barycenters' features. If not set a
random init is used.
fixed_structure : bool, optional
Whether to fix the structure of the barycenter during the updates.
fixed_features : bool, optional
Whether to fix the feature of the barycenter during the updates
random_state : int or RandomState instance, optional
Fix the seed for reproducibility
**kwargs: dict
Expand All @@ -910,7 +961,12 @@ def entropic_fused_gromov_barycenters(
C : array-like, shape (`N`, `N`)
Similarity matrix in the barycenter space (permutated as Y's rows)
log : dict
Log dictionary of error during iterations. Return only if `log=True` in parameters.
Only returned when log=True. It contains the keys:
- :math:`\mathbf{T}`: list of (`N`, `ns`) transport matrices
- :math:`\mathbf{p}`: (`N`,) barycenter weights
- :math:`(\mathbf{M}_s)_s`: all distance matrices between the feature of the barycenter and the other features :math:`(dist(\mathbf{X}, \mathbf{Y}_s))_s` shape (`N`, `ns`)
- values used in convergence evaluation.
References
----------
Expand All @@ -926,6 +982,9 @@ def entropic_fused_gromov_barycenters(
if loss_fun not in ('square_loss', 'kl_loss'):
raise ValueError(f"Unknown `loss_fun='{loss_fun}'`. Use one of: {'square_loss', 'kl_loss'}.")

if stop_criterion not in ['barycenter', 'loss']:
raise ValueError(f"Unknown `stop_criterion='{stop_criterion}'`. Use one of: {'barycenter', 'loss'}.")

Cs = list_to_array(*Cs)
Ys = list_to_array(*Ys)
arr = [*Cs, *Ys]
Expand All @@ -945,67 +1004,108 @@ def entropic_fused_gromov_barycenters(

d = Ys[0].shape[1] # dimension on the node features

# Initialization of C : random SPD matrix (if not provided by user)
if init_C is None:
generator = check_random_state(random_state)
xalea = generator.randn(N, 2)
C = dist(xalea, xalea)
C /= C.max()
C = nx.from_numpy(C, type_as=p)
# Initialization of C : random euclidean distance matrix (if not provided by user)
if fixed_structure:
if init_C is None:
raise UndefinedParameter('If C is fixed it must be initialized')
else:
C = init_C
else:
C = init_C
if init_C is None:
generator = check_random_state(random_state)
xalea = generator.randn(N, 2)
C = dist(xalea, xalea)
C = nx.from_numpy(C, type_as=ps[0])
else:
C = init_C

# Initialization of Y
if init_Y is None:
Y = nx.zeros((N, d), type_as=ps[0])
if fixed_features:
if init_Y is None:
raise UndefinedParameter('If Y is fixed it must be initialized')
else:
Y = init_Y
else:
Y = init_Y
if init_Y is None:
Y = nx.zeros((N, d), type_as=ps[0])

if warmstartT:
T = [None] * S
else:
Y = init_Y

Ms = [dist(Y, Ys[s]) for s in range(len(Ys))]

if warmstartT:
T = [None] * S

cpt = 0
err = 1

err_feature = 1
err_structure = 1
if stop_criterion == 'barycenter':
inner_log = False
err_feature = 1e15
err_structure = 1e15
err_rel_loss = 0.

else:
inner_log = True
err_feature = 0.
err_structure = 0.
curr_loss = 1e15
err_rel_loss = 1e15

if log:
log_ = {}
log_['err_feature'] = []
log_['err_structure'] = []
log_['Ts_iter'] = []
if stop_criterion == 'barycenter':
log_['err_feature'] = []
log_['err_structure'] = []
log_['Ts_iter'] = []
else:
log_['loss'] = []
log_['err_rel_loss'] = []

while (err > tol) and (cpt < max_iter):
Cprev = C
Yprev = Y
while ((err_feature > tol or err_structure > tol or err_rel_loss > tol) and cpt < max_iter):
if stop_criterion == 'barycenter':
Cprev = C
Yprev = Y
else:
prev_loss = curr_loss

# get transport plans
if warmstartT:
T = [entropic_fused_gromov_wasserstein(
res = [entropic_fused_gromov_wasserstein(
Ms[s], C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, alpha,
T[s], max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
T[s], max_iter, 1e-4, verbose=verbose, log=inner_log, **kwargs) for s in range(S)]

else:
T = [entropic_fused_gromov_wasserstein(
res = [entropic_fused_gromov_wasserstein(
Ms[s], C, Cs[s], p, ps[s], loss_fun, epsilon, symmetric, alpha,
None, max_iter, 1e-4, verbose=verbose, log=False, **kwargs) for s in range(S)]
None, max_iter, 1e-4, verbose=verbose, log=inner_log, **kwargs) for s in range(S)]

if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs, nx)
elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs, nx)

Ys_temp = [y.T for y in Ys]
Y = update_feature_matrix(lambdas, Ys_temp, T, p, nx).T
Ms = [dist(Y, Ys[s]) for s in range(len(Ys))]

if cpt % 10 == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
err_feature = nx.norm(Y - nx.reshape(Yprev, (N, d)))
err_structure = nx.norm(C - Cprev)
if stop_criterion == 'barycenter':
T = res
else:
T = [output[0] for output in res]
curr_loss = np.sum([output[1]['fgw_dist'] for output in res])

# update barycenters
if not fixed_features:
Ys_temp = [y.T for y in Ys]
X = update_feature_matrix(lambdas, Ys_temp, T, p, nx).T
Ms = [dist(X, Ys[s]) for s in range(len(Ys))]

if not fixed_structure:
if loss_fun == 'square_loss':
C = update_square_loss(p, lambdas, T, Cs, nx)

elif loss_fun == 'kl_loss':
C = update_kl_loss(p, lambdas, T, Cs, nx)

# update convergence criterion
if stop_criterion == 'barycenter':
err_feature, err_structure = 0., 0.
if not fixed_features:
err_feature = nx.norm(Y - Yprev)
if not fixed_structure:
err_structure = nx.norm(C - Cprev)
if log:
log_['err_feature'].append(err_feature)
log_['err_structure'].append(err_structure)
Expand All @@ -1017,14 +1117,25 @@ def entropic_fused_gromov_barycenters(
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err_structure))
print('{:5d}|{:8e}|'.format(cpt, err_feature))
else:
err_rel_loss = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. else np.nan
if log:
log_['loss'].append(curr_loss)
log_['err_rel_loss'].append(err_rel_loss)

if verbose:
if cpt % 200 == 0:
print('{:5s}|{:12s}'.format(
'It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(cpt, err_rel_loss))

cpt += 1

if log:
log_['T'] = T # from target to Ys
log_['T'] = T
log_['p'] = p
log_['Ms'] = Ms

if log:
return Y, C, log_
else:
return Y, C
Loading

0 comments on commit cffb6cf

Please sign in to comment.