Skip to content

Commit

Permalink
[MRG] Update tests and documentation (PythonOT#484)
Browse files Browse the repository at this point in the history
* remove old macos and windows tets update requirements

* speedup ssw and continuaous ot exmaples

* speedup regpath and variane

* speedup conv 2d example + continuous stick

* speedup regpath
  • Loading branch information
rflamary authored Jun 9, 2023
1 parent 5faa4fb commit 6c1e1f3
Show file tree
Hide file tree
Showing 11 changed files with 44 additions and 41 deletions.
16 changes: 8 additions & 8 deletions .github/workflows/build_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ jobs:
pip install -e .
- name: Run tests
run: |
python -m pytest --durations=20 -v test/ ot/ --ignore ot/gpu/ --color=yes
python -m pytest --durations=20 -v test/ ot/ --color=yes
macos:
Expand All @@ -92,7 +92,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
python-version: ["3.10"]

steps:
- uses: actions/checkout@v1
Expand All @@ -107,10 +107,10 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest "pytest-cov<2.6"
pip install pytest
- name: Run tests
run: |
python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes
python -m pytest --durations=20 -v test/ ot/ --color=yes
windows:
Expand All @@ -119,7 +119,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
python-version: ["3.7", "3.8", "3.9", "3.10"]
python-version: ["3.10"]

steps:
- uses: actions/checkout@v1
Expand Down Expand Up @@ -151,8 +151,8 @@ jobs:
- name: Install dependencies
run: |
python -m pip install -r .github/requirements_test_windows.txt
python -m pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install pytest "pytest-cov<2.6"
python -m pip3 install torch torchvision torchaudio
python -m pip install pytest
- name: Run tests
run: |
python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes
python -m pytest --durations=20 -v test/ ot/ --color=yes
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Added the sparsity-constrained OT solver to `ot.smooth` and added ` projection_sparse_simplex` to `ot.utils` (PR #459)
- Add tests on GPU for master branch and approved PR (PR #473)
- Add `median` method to all inherited classes of `backend.Backend` (PR #472)
- Update tests for macOS and Windows, speedup documentation (PR #484)

#### Closed issues

Expand Down
2 changes: 2 additions & 0 deletions docs/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ help:
.PHONY: clean
clean:
rm -rf $(BUILDDIR)/*
rm -rf source/gen_modules/*
rm -rf source/auto_examples/*

.PHONY: html
html:
Expand Down
8 changes: 4 additions & 4 deletions examples/backends/plot_sliced_wass_grad_flow_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@
import ot
import matplotlib.animation as animation

I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2]
I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::4, ::4, 2]
I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::5, ::5, 2]
I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::5, ::5, 2]

sz = I2.shape[0]
XX, YY = np.meshgrid(np.arange(sz), np.arange(sz))
Expand All @@ -67,7 +67,7 @@


lr = 1e3
nb_iter_max = 100
nb_iter_max = 50

x_all = np.zeros((nb_iter_max, x1.shape[0], 2))

Expand Down Expand Up @@ -129,7 +129,7 @@ def _update_plot(i):
xbary_torch = torch.tensor(xbinit).to(device=device).requires_grad_(True)

lr = 1e3
nb_iter_max = 100
nb_iter_max = 50

x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2))

Expand Down
12 changes: 6 additions & 6 deletions examples/backends/plot_ssw_unif_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

torch.manual_seed(1)

N = 1000
N = 500
x0 = torch.rand(N, 3)
x0 = F.normalize(x0, dim=-1)

Expand Down Expand Up @@ -72,8 +72,8 @@ def plot_sphere(ax):
x = x0.clone()
x.requires_grad_(True)

n_iter = 500
lr = 100
n_iter = 100
lr = 150

losses = []
xvisu = torch.zeros(n_iter, N, 3)
Expand All @@ -82,7 +82,7 @@ def plot_sphere(ax):
sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500)
grad_x = torch.autograd.grad(sw, x)[0]

x = x - lr * grad_x
x = x - lr * grad_x / np.sqrt(i / 10 + 1)
x = F.normalize(x, p=2, dim=1)

losses.append(sw.item())
Expand All @@ -102,7 +102,7 @@ def plot_sphere(ax):
# Plot trajectories of generated samples along iterations
# -------------------------------------------------------

ivisu = [0, 25, 50, 75, 100, 150, 200, 350, 499]
ivisu = [0, 10, 20, 30, 40, 50, 60, 70, 80]

fig = pl.figure(3, (10, 10))
for i in range(9):
Expand Down Expand Up @@ -149,5 +149,5 @@ def _update_plot(i):
ax.set_title('Iter. {}'.format(ivisu[i]))


ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=100, repeat_delay=2000)
ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=200, repeat_delay=2000)
# %%
6 changes: 3 additions & 3 deletions examples/backends/plot_stoch_continuous_ot_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
torch.manual_seed(42)
np.random.seed(42)

n_source_samples = 10000
n_target_samples = 10000
n_source_samples = 1000
n_target_samples = 1000
theta = 2 * np.pi / 20
noise_level = 0.1

Expand Down Expand Up @@ -89,7 +89,7 @@ def forward(self, x):
optimizer = torch.optim.Adam(list(u.parameters()) + list(v.parameters()), lr=.005)

# number of iteration
n_iter = 1000
n_iter = 500
n_batch = 500


Expand Down
8 changes: 4 additions & 4 deletions examples/barycenters/plot_convolutional_barycenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
this_file = os.path.realpath('__file__')
data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')

f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2]
f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2]
f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2]
f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2]
f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[::2, ::2, 2]
f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[::2, ::2, 2]
f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[::2, ::2, 2]
f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[::2, ::2, 2]

f1 = f1 / np.sum(f1)
f2 = f2 / np.sum(f2)
Expand Down
2 changes: 1 addition & 1 deletion examples/plot_OT_1D_smooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 6
# sphinx_gallery_thumbnail_number = 5

import numpy as np
import matplotlib.pylab as pl
Expand Down
8 changes: 4 additions & 4 deletions examples/sliced-wasserstein/plot_variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

# %% parameters and data generation

n = 500 # nb samples
n = 200 # nb samples

mu_s = np.array([0, 0])
cov_s = np.array([[1, 0], [0, 1]])
Expand Down Expand Up @@ -58,9 +58,9 @@
# Sliced Wasserstein distance for different seeds and number of projections
# -------------------------------------------------------------------------

n_seed = 50
n_projections_arr = np.logspace(0, 3, 25, dtype=int)
res = np.empty((n_seed, 25))
n_seed = 20
n_projections_arr = np.logspace(0, 3, 10, dtype=int)
res = np.empty((n_seed, 10))

# %% Compute statistics
for seed in range(n_seed):
Expand Down
8 changes: 4 additions & 4 deletions examples/sliced-wasserstein/plot_variance_ssw.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

# %% parameters and data generation

n = 500 # nb samples
n = 200 # nb samples

xs = np.random.randn(n, 3)
xt = np.random.randn(n, 3)
Expand Down Expand Up @@ -81,9 +81,9 @@
# Spherical Sliced Wasserstein for different seeds and number of projections
# --------------------------------------------------------------------------

n_seed = 50
n_projections_arr = np.logspace(0, 3, 25, dtype=int)
res = np.empty((n_seed, 25))
n_seed = 20
n_projections_arr = np.logspace(0, 3, 10, dtype=int)
res = np.empty((n_seed, 10))

# %% Compute statistics
for seed in range(n_seed):
Expand Down
14 changes: 7 additions & 7 deletions examples/unbalanced-partial/plot_regpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

#%% parameters and data generation

n = 50 # nb samples
n = 20 # nb samples

mu_s = np.array([-1, -1])
cov_s = np.array([[1, 0], [0, 1]])
Expand Down Expand Up @@ -63,7 +63,7 @@
# -----------------------------------------------------------

#%%
final_gamma = 1e-8
final_gamma = 1e-6
t, t_list, g_list = ot.regpath.regularization_path(a, b, M, reg=final_gamma,
semi_relaxed=False)
t2, t_list2, g_list2 = ot.regpath.regularization_path(a, b, M, reg=final_gamma,
Expand Down Expand Up @@ -111,7 +111,7 @@
# Animation of the regpath for UOT l2
# -----------------------------------

nv = 100
nv = 50
g_list_v = np.logspace(-.5, -2.5, nv)

pl.figure(3)
Expand Down Expand Up @@ -144,7 +144,7 @@ def _update_plot(iv):
i = 0
_update_plot(i)

ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000)
ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=100, repeat_delay=2000)


##############################################################################
Expand Down Expand Up @@ -183,8 +183,8 @@ def _update_plot(iv):
# Animation of the regpath for semi-relaxed UOT l2
# ------------------------------------------------

nv = 100
g_list_v = np.logspace(2.5, -2, nv)
nv = 50
g_list_v = np.logspace(2, -2, nv)

pl.figure(5)

Expand Down Expand Up @@ -216,4 +216,4 @@ def _update_plot(iv):
i = 0
_update_plot(i)

ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=50, repeat_delay=2000)
ani = animation.FuncAnimation(pl.gcf(), _update_plot, nv, interval=100, repeat_delay=2000)

0 comments on commit 6c1e1f3

Please sign in to comment.