Skip to content

Commit

Permalink
[MRG] Add Wasserstein barycenter for Gaussian distribution (PythonOT#582
Browse files Browse the repository at this point in the history
)

* add barycenter functions

* add tests

* update release
  • Loading branch information
tgnassou authored Nov 30, 2023
1 parent 89e6010 commit f809253
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 0 deletions.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
+ Wrapper for `geomloss` solver on empirical samples (PR #571)
+ Add `stop_criterion` feature to (un)regularized (f)gw barycenter solvers (PR #578)
+ Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578)
+ Add Bures-Wasserstein barycenter in `ot.gaussian` (PR #582)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
Expand Down
182 changes: 182 additions & 0 deletions ot/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,188 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None,
return W


def bures_wasserstein_barycenter(m, C, weights=None, num_iter=1000, eps=1e-7, log=False):
    r"""Return the Bures-Wasserstein barycenter of Gaussian distributions.

    The function estimates the Wasserstein barycenter of the Gaussian
    distributions
    :math:`\left\{\mathcal{N}(\mu_i,\Sigma_i)\right\}_{i=1}^n`
    with a fixed-point algorithm
    :ref:`[1] <references-bures-wasserstein-barycenter>`.

    The barycenter still follows a Gaussian distribution
    :math:`\mathcal{N}(\mu_b,\Sigma_b)` where:

    .. math::
        \mu_b = \sum_{i=1}^n w_i \mu_i

    and the barycentric covariance is the solution of the following
    fixed-point equation:

    .. math::
        \Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i\Sigma_b^{1/2}\right)^{1/2}

    Parameters
    ----------
    m : array-like (k,d)
        mean of k distributions
    C : array-like (k,d,d)
        covariance of k distributions
    weights : array-like (k), optional
        weights for each distribution, uniform weights if None
    num_iter : int, optional
        maximum number of iterations of the fixed-point algorithm
    eps : float, optional
        tolerance of the fixed-point algorithm: iterations stop as soon as
        the norm of the difference between two successive covariance
        iterates is lower than or equal to `eps`
    log : bool, optional
        record log if True

    Returns
    -------
    mb : (d,) array-like
        mean of the barycenter, i.e. the weighted sum of `m` along its first
        axis (extra axes of `m`, if any, are preserved)
    Cb : (d, d) array-like
        covariance of the barycenter
    log : dict
        log dictionary returned only if log==True in parameters; it holds
        the index of the last iteration (``'num_iter'``) and the last
        difference between successive iterates (``'final_diff'``)


    .. _references-bures-wasserstein-barycenter:
    References
    ----------
    .. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space",
        SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924,
        2011.
    """
    nx = get_backend(*C, *m,)

    # The weights are needed for the mean as well, so default them first.
    if weights is None:
        weights = nx.ones(len(C), type_as=C[0]) / len(C)

    # Compute the mean barycenter as the weighted sum of the means. The
    # weights are reshaped to (k, 1, ..., 1) so that they broadcast along the
    # first axis of m whatever its number of trailing axes.
    w_mean = nx.reshape(weights, (-1,) + (1,) * (len(m.shape) - 1))
    mb = nx.sum(m * w_mean, axis=0)

    # Init the covariance barycenter (only a starting point of the iterations)
    Cb = nx.mean(C, axis=0)

    # Placeholders so that the log can be built even if the loop body never
    # runs (num_iter == 0).
    it, diff = 0, None

    for it in range(num_iter):
        # fixed point update: Cb12 @ C broadcasts over the k covariances
        Cb12 = nx.sqrtm(Cb)
        inner = Cb12 @ C @ Cb12
        roots = nx.stack([nx.sqrtm(inner[i]) for i in range(len(C))], axis=0)
        Cnew = nx.sum(roots * weights[:, None, None], axis=0)

        # check convergence
        diff = nx.norm(Cb - Cnew)
        if diff <= eps:
            break
        Cb = Cnew
    else:
        # reached only when the loop ends without hitting the tolerance
        print("Did not converge.")

    if log:
        log = {}
        log['num_iter'] = it
        log['final_diff'] = diff
        return mb, Cb, log
    else:
        return mb, Cb


def empirical_bures_wasserstein_barycenter(
    X, reg=1e-6, weights=None, num_iter=1000, eps=1e-7,
    w=None, bias=True, log=False
):
    r"""Return the Bures-Wasserstein barycenter estimated from samples.

    The mean and the (regularized) covariance of each empirical distribution
    are estimated from its samples, then the barycenter of the corresponding
    Gaussian distributions
    :math:`\left\{\mathcal{N}(\mu_i,\Sigma_i)\right\}_{i=1}^n`
    is computed with the fixed-point algorithm of
    :func:`bures_wasserstein_barycenter`
    :ref:`[1] <references-empirical-bures-wasserstein-barycenter>`.

    The barycenter still follows a Gaussian distribution
    :math:`\mathcal{N}(\mu_b,\Sigma_b)` where:

    .. math::
        \mu_b = \sum_{i=1}^n w_i \mu_i

    and the barycentric covariance is the solution of the following
    fixed-point equation:

    .. math::
        \Sigma_b = \sum_{i=1}^n w_i \left(\Sigma_b^{1/2}\Sigma_i\Sigma_b^{1/2}\right)^{1/2}

    Parameters
    ----------
    X : list of array-like (n,d)
        samples in each distribution
    reg : float,optional
        regularization added to the diagonals of covariances (>0)
    weights : array-like (k,), optional
        weights for each of the k distributions
    num_iter : int, optional
        maximum number of iterations of the fixed-point algorithm
    eps : float, optional
        tolerance for the fixed point algorithm
    w : list of array-like (n,1), optional
        weights for each sample in each distribution, uniform weights if
        None; each element is multiplied with the matching (n,d) samples
    bias: boolean, optional
        if True, the empirical mean of each distribution is estimated and
        removed from its samples before computing the covariance, else the
        means are set to zero (default:True)
    log : bool, optional
        record log if True

    Returns
    -------
    mb : (d,) array-like
        mean of the barycenter
    Cb : (d, d) array-like
        covariance of the barycenter
    log : dict
        log dictionary returned only if log==True in parameters


    .. _references-empirical-bures-wasserstein-barycenter:
    References
    ----------
    .. [1] M. Agueh and G. Carlier, "Barycenters in the Wasserstein space",
        SIAM Journal on Mathematical Analysis, vol. 43, no. 2, pp. 904-924,
        2011.
    """
    X = list_to_array(*X)
    nx = get_backend(*X)

    k = len(X)
    d = [X[i].shape[1] for i in range(k)]

    # Means are kept as (1, d) rows so that they broadcast against the samples
    if bias:
        m = [nx.mean(X[i], axis=0)[None, :] for i in range(k)]
        X = [X[i] - m[i] for i in range(k)]
    else:
        m = [nx.zeros((1, d[i]), type_as=X[i]) for i in range(k)]

    if w is None:
        w = [nx.ones((X[i].shape[0], 1), type_as=X[i]) / X[i].shape[0] for i in range(k)]

    # Weighted covariance of each distribution, regularized on the diagonal
    C = [
        nx.dot((X[i] * w[i]).T, X[i]) / nx.sum(w[i]) + reg * nx.eye(d[i], type_as=X[i])
        for i in range(k)
    ]

    # Concatenating the (1, d) rows gives the (k, d) layout documented by
    # bures_wasserstein_barycenter (stacking would add a singleton axis).
    m = nx.concatenate(m, axis=0)
    C = nx.stack(C, axis=0)

    # The solver already returns (mb, Cb) or (mb, Cb, log) depending on `log`
    return bures_wasserstein_barycenter(
        m, C, weights=weights, num_iter=num_iter, eps=eps, log=log
    )


def gaussian_gromov_wasserstein_distance(Cov_s, Cov_t, log=False):
r""" Return the Gaussian Gromov-Wasserstein value from [57].
Expand Down
65 changes: 65 additions & 0 deletions test/test_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,71 @@ def test_empirical_bures_wasserstein_distance(nx, bias):
np.testing.assert_allclose(10 * bias, nx.to_numpy(Wb), rtol=1e-2, atol=1e-2)


def test_bures_wasserstein_barycenter(nx):
    """Check log/no-log consistency, the weights argument and a closed form."""
    n = 50
    k = 10
    m = []
    C = []
    for _ in range(k):
        X_, _ = make_data_classif('3gauss', n)
        m.append(np.mean(X_, axis=0))
        C.append(np.cov(X_.T))
    # m is (k, d) and C is (k, d, d), as documented by the solver
    m = nx.from_numpy(np.array(m))
    C = nx.from_numpy(np.array(C))

    mblog, Cblog, log = ot.gaussian.bures_wasserstein_barycenter(m, C, log=True)
    mb, Cb = ot.gaussian.bures_wasserstein_barycenter(m, C, log=False)

    # Convert to numpy before comparing so the checks also work for backends
    # whose tensors numpy cannot read directly (e.g. tensors on a GPU).
    np.testing.assert_allclose(nx.to_numpy(Cb), nx.to_numpy(Cblog), rtol=1e-2, atol=1e-2)
    np.testing.assert_allclose(nx.to_numpy(mb), nx.to_numpy(mblog), rtol=1e-2, atol=1e-2)

    # Test weights argument: explicit uniform weights must match the default
    weights = nx.ones(k, type_as=C) / k
    mbw, Cbw = ot.gaussian.bures_wasserstein_barycenter(m, C, weights=weights, log=False)
    np.testing.assert_allclose(nx.to_numpy(Cbw), nx.to_numpy(Cb), rtol=1e-2, atol=1e-2)

    # test with closed form for diagonal covariance matrices: the barycenter
    # covariance is the square of the average of the square roots
    Cdiag = nx.stack([nx.diag(nx.diag(C[i])) for i in range(k)], axis=0)
    mbdiag, Cbdiag = ot.gaussian.bures_wasserstein_barycenter(m, Cdiag, log=False)

    Cdiag_sqrt = nx.stack([nx.sqrtm(Cdiag[i]) for i in range(k)], axis=0)
    Cdiag_mean = nx.mean(Cdiag_sqrt, axis=0)
    Cdiag_cf = Cdiag_mean @ Cdiag_mean

    np.testing.assert_allclose(nx.to_numpy(Cbdiag), nx.to_numpy(Cdiag_cf), rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize("bias", [True, False])
def test_empirical_bures_wasserstein_barycenter(nx, bias):
    """Check that the log and no-log code paths return the same barycenter."""
    n = 50
    k = 10
    X = []
    for _ in range(k):
        X_, _ = make_data_classif('3gauss', n)
        X.append(X_)

    X = nx.from_numpy(*X)

    mblog, Cblog, log = ot.gaussian.empirical_bures_wasserstein_barycenter(X, log=True, bias=bias)
    mb, Cb = ot.gaussian.empirical_bures_wasserstein_barycenter(X, log=False, bias=bias)

    # Convert to numpy before comparing so the checks also work for backends
    # whose tensors numpy cannot read directly (e.g. tensors on a GPU).
    np.testing.assert_allclose(nx.to_numpy(Cb), nx.to_numpy(Cblog), rtol=1e-2, atol=1e-2)
    np.testing.assert_allclose(nx.to_numpy(mb), nx.to_numpy(mblog), rtol=1e-2, atol=1e-2)


@pytest.mark.parametrize("d_target", [1, 2, 3, 10])
def test_gaussian_gromov_wasserstein_distance(nx, d_target):
ns = 400
Expand Down

0 comments on commit f809253

Please sign in to comment.