Move from random seeding to local generators (PythonOT#512)
kachayev authored Aug 25, 2023
1 parent 98e3187 commit 20cc202
Showing 18 changed files with 206 additions and 169 deletions.
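
The same pattern repeats across all 18 files: every function that previously drew from the global np.random state gains a random_state argument and resolves it to a local generator through ot.utils.check_random_state, and tests pin a local np.random.RandomState instead of calling np.random.seed. A minimal sketch of the helper's contract, assuming the sklearn-style semantics it mirrors (the exact POT implementation may differ in detail):

import numpy as np

def check_random_state(seed):
    # None -> the global RandomState (legacy behaviour, not reproducible)
    # int -> a fresh RandomState seeded with that int (reproducible)
    # RandomState -> returned as-is, so a caller can share one stream
    if seed is None:
        return np.random.mtrand._rand
    if isinstance(seed, (int, np.integer)):
        return np.random.RandomState(seed)
    if isinstance(seed, np.random.RandomState):
        return seed
    raise ValueError('{!r} cannot be used to seed a RandomState'.format(seed))
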
10 changes: 7 additions & 3 deletions ot/dr.py
@@ -25,7 +25,7 @@
 import pymanopt.optimizers
 
 from .bregman import sinkhorn as sinkhorn_bregman
-from .utils import dist as dist_utils
+from .utils import dist as dist_utils, check_random_state
 
 
 def dist(x1, x2):
@@ -267,7 +267,7 @@ def proj(X):
     return Popt.point, proj
 
 
-def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
+def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0, random_state=None):
     r"""
     Projection Robust Wasserstein Distance :ref:`[32] <references-projection-robust-wasserstein>`
@@ -303,6 +303,9 @@ def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopTh
         Stop threshold on error (>0)
     verbose : int, optional
         Print information along iterations.
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation for initial value of projection
+        operator when U0 is not given.
     Returns
     -------
@@ -332,7 +335,8 @@ def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopTh
     assert d > k
 
     if U0 is None:
-        U = np.random.randn(d, k)
+        rng = check_random_state(random_state)
+        U = rng.randn(d, k)
         U, _ = np.linalg.qr(U)
     else:
         U = U0
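
Reviewer note: the only randomness in projection_robust_wasserstein is the initial U when U0 is not given, so an int seed now makes the whole call deterministic. An illustrative usage sketch, assuming the optional ot.dr dependencies (autograd, pymanopt) are installed and that the function returns the plan and the projection (pi, U); sizes and tau are made-up toy values:

import numpy as np
from ot.dr import projection_robust_wasserstein

data_rng = np.random.RandomState(0)
n, d = 20, 5
X = data_rng.randn(n, d)
Y = data_rng.randn(n, d) + 1.0
a = np.ones(n) / n
b = np.ones(n) / n

# Same int seed -> same rng.randn(d, k) initialization -> identical outputs.
pi1, U1 = projection_robust_wasserstein(X, Y, a, b, tau=0.1, k=2, random_state=42)
pi2, U2 = projection_robust_wasserstein(X, Y, a, b, tau=0.1, k=2, random_state=42)
assert np.allclose(pi1, pi2) and np.allclose(U1, U2)
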
27 changes: 19 additions & 8 deletions ot/gromov/_dictionary.py
@@ -11,13 +11,13 @@
 import numpy as np
 
 
-from ..utils import unif
+from ..utils import unif, check_random_state
 from ..backend import get_backend
 from ._gw import gromov_wasserstein, fused_gromov_wasserstein
 
 
 def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate=1., Cdict_init=None, projection='nonnegative_symmetric', use_log=True,
-                                           tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+                                           tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, random_state=None, **kwargs):
     r"""
     Infer Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, q) \}_{d \in [D]}` from the list of structures :math:`\{ (\mathbf{C_s},\mathbf{p_s}) \}_s`
@@ -81,6 +81,9 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e
         Maximum number of iterations for the Conjugate Gradient. Default is 200.
     verbose : bool, optional
         Print the reconstruction loss every epoch. Default is False.
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation. Pass an int for reproducible
+        output across multiple function calls.
     Returns
     -------
@@ -90,6 +93,7 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e
         The dictionary leading to the best loss over an epoch is saved and returned.
     log: dict
         If use_log is True, contains loss evolutions by batches and epochs.
+
     References
     -------
     .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
@@ -110,10 +114,11 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e
         q = unif(nt)
     else:
         q = nx.to_numpy(q)
+    rng = check_random_state(random_state)
     if Cdict_init is None:
         # Initialize randomly structures of dictionary atoms based on samples
         dataset_means = [C.mean() for C in Cs]
-        Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+        Cdict = rng.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
     else:
         Cdict = nx.to_numpy(Cdict_init).copy()
         assert Cdict.shape == (D, nt, nt)
@@ -141,7 +146,7 @@ def gromov_wasserstein_dictionary_learning(Cs, D, nt, reg=0., ps=None, q=None, e
 
         for _ in range(iter_by_epoch):
             # batch sampling
-            batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+            batch = rng.choice(range(dataset_size), size=batch_size, replace=False)
             cumulated_loss_over_batch = 0.
             unmixings = np.zeros((batch_size, D))
             Cs_embedded = np.zeros((batch_size, nt, nt))
@@ -469,7 +474,8 @@ def _linesearch_gromov_wasserstein_unmixing(w, grad_w, x, Cdict, Cembedded, cons
 
 def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., ps=None, q=None, epochs=20, batch_size=32, learning_rate_C=1., learning_rate_Y=1.,
                                                  Cdict_init=None, Ydict_init=None, projection='nonnegative_symmetric', use_log=False,
-                                                 tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False, **kwargs):
+                                                 tol_outer=10**(-5), tol_inner=10**(-5), max_iter_outer=20, max_iter_inner=200, use_adam_optimizer=True, verbose=False,
+                                                 random_state=None, **kwargs):
     r"""
     Infer Fused Gromov-Wasserstein linear dictionary :math:`\{ (\mathbf{C_{dict}[d]}, \mathbf{Y_{dict}[d]}, \mathbf{q}) \}_{d \in [D]}` from the list of S attributed structures :math:`\{ (\mathbf{C_s}, \mathbf{Y_s},\mathbf{p_s}) \}_s`
@@ -548,6 +554,9 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
         Maximum number of iterations for the Conjugate Gradient. Default is 200.
     verbose : bool, optional
         Print the reconstruction loss every epoch. Default is False.
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation. Pass an int for reproducible
+        output across multiple function calls.
     Returns
     -------
@@ -560,6 +569,7 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
         The dictionary leading to the best loss over an epoch is saved and returned.
     log: dict
         If use_log is True, contains loss evolutions by batches and epochs.
+
     References
     -------
     .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
@@ -583,17 +593,18 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
     else:
         q = nx.to_numpy(q)
 
+    rng = check_random_state(random_state)
     if Cdict_init is None:
         # Initialize randomly structures of dictionary atoms based on samples
         dataset_means = [C.mean() for C in Cs]
-        Cdict = np.random.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
+        Cdict = rng.normal(loc=np.mean(dataset_means), scale=np.std(dataset_means), size=(D, nt, nt))
     else:
         Cdict = nx.to_numpy(Cdict_init).copy()
         assert Cdict.shape == (D, nt, nt)
     if Ydict_init is None:
         # Initialize randomly features of dictionary atoms based on samples distribution by feature component
         dataset_feature_means = np.stack([F.mean(axis=0) for F in Ys])
-        Ydict = np.random.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d))
+        Ydict = rng.normal(loc=dataset_feature_means.mean(axis=0), scale=dataset_feature_means.std(axis=0), size=(D, nt, d))
     else:
         Ydict = nx.to_numpy(Ydict_init).copy()
         assert Ydict.shape == (D, nt, d)
@@ -626,7 +637,7 @@ def fused_gromov_wasserstein_dictionary_learning(Cs, Ys, D, nt, alpha, reg=0., p
         for _ in range(iter_by_epoch):
 
             # Batch iterations
-            batch = np.random.choice(range(dataset_size), size=batch_size, replace=False)
+            batch = rng.choice(range(dataset_size), size=batch_size, replace=False)
             cumulated_loss_over_batch = 0.
             unmixings = np.zeros((batch_size, D))
             Cs_embedded = np.zeros((batch_size, nt, nt))
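
Reviewer note: both randomized steps of the learning loop, the rng.normal dictionary initialization and the rng.choice batch sampling, now draw from the same local generator, so one argument controls the whole run. A toy sketch with hypothetical sizes chosen only for illustration, assuming the default use_log=True (which returns the log alongside the dictionary):

import numpy as np
from ot.gromov import gromov_wasserstein_dictionary_learning

data_rng = np.random.RandomState(0)
Cs = []
for _ in range(6):
    A = data_rng.rand(10, 10)
    Cs.append((A + A.T) / 2)  # small symmetric structure matrices

# An int random_state gives reproducible init and batches across calls;
# passing a shared RandomState instance would chain the calls instead.
Cdict, log = gromov_wasserstein_dictionary_learning(
    Cs, D=3, nt=6, epochs=2, batch_size=4, random_state=42)
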
28 changes: 20 additions & 8 deletions ot/stochastic.py
@@ -10,7 +10,7 @@
 # License: MIT License
 
 import numpy as np
-from .utils import dist
+from .utils import dist, check_random_state
 from .backend import get_backend
 
 ##############################################################################
@@ -69,7 +69,7 @@ def coordinate_grad_semi_dual(b, M, reg, beta, i):
     return b - khi
 
 
-def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
+def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None, random_state=None):
     r"""
     Compute the SAG algorithm to solve the regularized discrete measures optimal transport max problem
@@ -110,6 +110,9 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
         Number of iteration.
     lr : float
         Learning rate.
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation. Pass an int for reproducible
+        output across multiple function calls.
     Returns
     -------
@@ -129,8 +132,9 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
     cur_beta = np.zeros(n_target)
     stored_gradient = np.zeros((n_source, n_target))
     sum_stored_gradient = np.zeros(n_target)
+    rng = check_random_state(random_state)
     for _ in range(numItermax):
-        i = np.random.randint(n_source)
+        i = rng.randint(n_source)
         cur_coord_grad = a[i] * coordinate_grad_semi_dual(b, M, reg,
                                                           cur_beta, i)
         sum_stored_gradient += (cur_coord_grad - stored_gradient[i])
@@ -139,7 +143,7 @@ def sag_entropic_transport(a, b, M, reg, numItermax=10000, lr=None):
     return cur_beta
 
 
-def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
+def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None, random_state=None):
     r'''
     Compute the ASGD algorithm to solve the regularized semi continous measures optimal transport max problem
@@ -177,6 +181,9 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
         Number of iteration.
     lr : float
         Learning rate.
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation. Pass an int for reproducible
+        output across multiple function calls.
     Returns
     -------
@@ -195,9 +202,10 @@ def averaged_sgd_entropic_transport(a, b, M, reg, numItermax=300000, lr=None):
     n_target = np.shape(M)[1]
     cur_beta = np.zeros(n_target)
     ave_beta = np.zeros(n_target)
+    rng = check_random_state(random_state)
     for cur_iter in range(numItermax):
         k = cur_iter + 1
-        i = np.random.randint(n_source)
+        i = rng.randint(n_source)
         cur_coord_grad = coordinate_grad_semi_dual(b, M, reg, cur_beta, i)
         cur_beta += (lr / np.sqrt(k)) * cur_coord_grad
         ave_beta = (1. / k) * cur_beta + (1 - 1. / k) * ave_beta
@@ -422,7 +430,7 @@ def batch_grad_dual(a, b, M, reg, alpha, beta, batch_size, batch_alpha,
     return grad_alpha, grad_beta
 
 
-def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
+def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr, random_state=None):
     r'''
     Compute the sgd algorithm to solve the regularized discrete measures optimal transport dual problem
@@ -460,6 +468,9 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
         number of iteration
     lr : float
         learning rate
+    random_state : int, RandomState instance or None, default=None
+        Determines random number generation. Pass an int for reproducible
+        output across multiple function calls.
     Returns
     -------
@@ -477,10 +488,11 @@ def sgd_entropic_regularization(a, b, M, reg, batch_size, numItermax, lr):
     n_target = np.shape(M)[1]
     cur_alpha = np.zeros(n_source)
     cur_beta = np.zeros(n_target)
+    rng = check_random_state(random_state)
     for cur_iter in range(numItermax):
         k = np.sqrt(cur_iter + 1)
-        batch_alpha = np.random.choice(n_source, batch_size, replace=False)
-        batch_beta = np.random.choice(n_target, batch_size, replace=False)
+        batch_alpha = rng.choice(n_source, batch_size, replace=False)
+        batch_beta = rng.choice(n_target, batch_size, replace=False)
         update_alpha, update_beta = batch_grad_dual(a, b, M, reg, cur_alpha,
                                                     cur_beta, batch_size,
                                                     batch_alpha, batch_beta)
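
Reviewer note: the three stochastic solvers sample an index (or an index batch) at every iteration, so these are the functions where global seeding hurt reproducibility the most. A quick determinism check on a toy problem (sizes and reg are illustrative; lr=None lets the solver pick its own step size):

import numpy as np
import ot
from ot.stochastic import sag_entropic_transport

n = 7
a = ot.utils.unif(n)
b = ot.utils.unif(n)
data_rng = np.random.RandomState(0)
M = ot.dist(data_rng.randn(n, 2), data_rng.randn(n, 2))

# Same seed -> same sequence of sampled indices i -> same dual vector beta.
beta1 = sag_entropic_transport(a, b, M, reg=1.0, numItermax=5000, random_state=42)
beta2 = sag_entropic_transport(a, b, M, reg=1.0, numItermax=5000, random_state=42)
np.testing.assert_allclose(beta1, beta2)
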
4 changes: 2 additions & 2 deletions test/test_1d_solver.py
@@ -163,8 +163,8 @@ def test_emd_1d_emd2_1d():
     np.testing.assert_allclose(G, G_1d, atol=1e-15)
 
     # check AssertionError is raised if called on non 1d arrays
-    u = np.random.randn(n, 2)
-    v = np.random.randn(m, 2)
+    u = rng.randn(n, 2)
+    v = rng.randn(m, 2)
     with pytest.raises(AssertionError):
         ot.emd_1d(u, v, [], [])
 
19 changes: 12 additions & 7 deletions test/test_bregman.py
@@ -298,7 +298,8 @@ def test_sinkhorn_variants(nx):
 def test_sinkhorn_variants_dtype_device(nx, method):
     n = 100
 
-    x = np.random.randn(n, 2)
+    rng = np.random.RandomState(42)
+    x = rng.randn(n, 2)
     u = ot.utils.unif(n)
 
     M = ot.dist(x, x)
@@ -317,7 +318,8 @@ def test_sinkhorn_variants_dtype_device(nx, method):
 def test_sinkhorn2_variants_dtype_device(nx, method):
     n = 100
 
-    x = np.random.randn(n, 2)
+    rng = np.random.RandomState(42)
+    x = rng.randn(n, 2)
     u = ot.utils.unif(n)
 
     M = ot.dist(x, x)
@@ -337,7 +339,8 @@ def test_sinkhorn2_variants_dtype_device(nx, method):
 def test_sinkhorn2_variants_device_tf(method):
     nx = ot.backend.TensorflowBackend()
     n = 100
-    x = np.random.randn(n, 2)
+    rng = np.random.RandomState(42)
+    x = rng.randn(n, 2)
     u = ot.utils.unif(n)
     M = ot.dist(x, x)
 
@@ -690,11 +693,12 @@ def test_barycenter_stabilization(nx):
 
 @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
 def test_wasserstein_bary_2d(nx, method):
+    rng = np.random.RandomState(42)
     size = 20  # size of a square image
-    a1 = np.random.rand(size, size)
+    a1 = rng.rand(size, size)
     a1 += a1.min()
     a1 = a1 / np.sum(a1)
-    a2 = np.random.rand(size, size)
+    a2 = rng.rand(size, size)
     a2 += a2.min()
     a2 = a2 / np.sum(a2)
     # creating matrix A containing all distributions
@@ -724,11 +728,12 @@ def test_wasserstein_bary_2d(nx, method):
 
 @pytest.mark.parametrize("method", ["sinkhorn", "sinkhorn_log"])
 def test_wasserstein_bary_2d_debiased(nx, method):
+    rng = np.random.RandomState(42)
     size = 20  # size of a square image
-    a1 = np.random.rand(size, size)
+    a1 = rng.rand(size, size)
     a1 += a1.min()
     a1 = a1 / np.sum(a1)
-    a2 = np.random.rand(size, size)
+    a2 = rng.rand(size, size)
     a2 += a2.min()
     a2 = a2 / np.sum(a2)
     # creating matrix A containing all distributions
13 changes: 7 additions & 6 deletions test/test_coot.py
@@ -223,21 +223,22 @@ def test_coot_warmstart(nx):
     xt_nx = nx.from_numpy(xt)
 
     # initialize warmstart
-    init_pi_sample = np.random.rand(n_samples, n_samples)
+    rng = np.random.RandomState(42)
+    init_pi_sample = rng.rand(n_samples, n_samples)
     init_pi_sample = init_pi_sample / np.sum(init_pi_sample)
     init_pi_sample_nx = nx.from_numpy(init_pi_sample)
 
-    init_pi_feature = np.random.rand(2, 2)
+    init_pi_feature = rng.rand(2, 2)
     init_pi_feature /= init_pi_feature / np.sum(init_pi_feature)
     init_pi_feature_nx = nx.from_numpy(init_pi_feature)
 
-    init_duals_sample = (np.random.random(n_samples) * 2 - 1,
-                         np.random.random(n_samples) * 2 - 1)
+    init_duals_sample = (rng.random(n_samples) * 2 - 1,
+                         rng.random(n_samples) * 2 - 1)
     init_duals_sample_nx = (nx.from_numpy(init_duals_sample[0]),
                             nx.from_numpy(init_duals_sample[1]))
 
-    init_duals_feature = (np.random.random(2) * 2 - 1,
-                          np.random.random(2) * 2 - 1)
+    init_duals_feature = (rng.random(2) * 2 - 1,
+                          rng.random(2) * 2 - 1)
     init_duals_feature_nx = (nx.from_numpy(init_duals_feature[0]),
                              nx.from_numpy(init_duals_feature[1]))
 
8 changes: 3 additions & 5 deletions test/test_da.py
@@ -567,12 +567,11 @@ def test_mapping_transport_class_specific_seed(nx):
     # check that it does not crash when derphi is very close to 0
     ns = 20
     nt = 30
-    np.random.seed(39)
-    Xs, ys = make_data_classif('3gauss', ns)
-    Xt, yt = make_data_classif('3gauss2', nt)
+    rng = np.random.RandomState(39)
+    Xs, ys = make_data_classif('3gauss', ns, random_state=rng)
+    Xt, yt = make_data_classif('3gauss2', nt, random_state=rng)
     otda = ot.da.MappingTransport(kernel="gaussian", bias=False)
     otda.fit(Xs=nx.from_numpy(Xs), Xt=nx.from_numpy(Xt))
-    np.random.seed(None)
 
 
 @pytest.skip_backend("jax")
@@ -712,7 +711,6 @@ def test_jcpot_barycenter(nx):
     nt = 50
 
     sigma = 0.1
-    np.random.seed(1985)
 
     ps1 = .2
     ps2 = .9
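
Reviewer note: passing the same rng instance to both make_data_classif calls means the second call continues the first call's stream rather than restarting it, and unlike the old np.random.seed(39) it does so without mutating global state, so the trailing np.random.seed(None) reset is no longer needed. A minimal illustration of that chaining (plain numpy, illustrative values):

import numpy as np

rng = np.random.RandomState(39)
first = rng.rand(3)   # consumes the start of the stream
second = rng.rand(3)  # continues the stream, so it differs from first

assert not np.allclose(first, second)
# Re-seeding would replay the beginning of the stream instead:
assert np.allclose(first, np.random.RandomState(39).rand(3))
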
1 change: 0 additions & 1 deletion test/test_dmmot.py
@@ -10,7 +10,6 @@
 
 
 def create_test_data(nx):
-    np.random.seed(1234)
     n = 4
     a1 = ot.datasets.make_1D_gauss(n, m=20, s=5)
     a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)