Skip to content

Commit

Permalink
Fix ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
AdrienCorenflos committed Apr 1, 2020
1 parent fa06bb3 commit 592f933
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
2 changes: 1 addition & 1 deletion ot/lp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ def emd_1d(x_a, x_b, a=None, b=None, metric='sqeuclidean', p=1., dense=True,
perm_a = np.argsort(x_a_1d)
perm_b = np.argsort(x_b_1d)

G_sorted, indices, cost = emd_1d_sorted(a, b,
G_sorted, indices, cost = emd_1d_sorted(a[perm_a.flatten()], b[perm_b.flatten()],
x_a_1d[perm_a], x_b_1d[perm_b],
metric=metric, p=p)
G = coo_matrix((G_sorted, (perm_a[indices[:, 0]], perm_b[indices[:, 1]])),
Expand Down
38 changes: 38 additions & 0 deletions test/test_ot.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,44 @@ def test_emd_1d_emd2_1d():
with pytest.raises(AssertionError):
ot.emd_1d(u, v, [], [])

def test_emd_1d_emd2_1d_with_weights():

# test emd1d gives similar results as emd
n = 20
m = 30
rng = np.random.RandomState(0)
u = rng.randn(n, 1)
v = rng.randn(m, 1)

w_u = rng.uniform(0., 1., n)
w_u = w_u / w_u.sum()

w_v = rng.uniform(0., 1., m)
w_v = w_v / w_v.sum()

M = ot.dist(u, v, metric='sqeuclidean')

G, log = ot.emd(w_u, w_v, M, log=True)
wass = log["cost"]
G_1d, log = ot.emd_1d(u, v, w_u, w_v, metric='sqeuclidean', log=True)
wass1d = log["cost"]
wass1d_emd2 = ot.emd2_1d(u, v, w_u, w_v, metric='sqeuclidean', log=False)
wass1d_euc = ot.emd2_1d(u, v, w_u, w_v, metric='euclidean', log=False)

# check loss is similar
np.testing.assert_allclose(wass, wass1d)
np.testing.assert_allclose(wass, wass1d_emd2)

# check loss is similar to scipy's implementation for Euclidean metric
wass_sp = wasserstein_distance(u.reshape((-1,)), v.reshape((-1,)))
np.testing.assert_allclose(wass_sp, wass1d_euc)

# check constraints
np.testing.assert_allclose(w_u, G.sum(1))
np.testing.assert_allclose(w_v, G.sum(0))




def test_wass_1d():
# test emd1d gives similar results as emd
Expand Down

0 comments on commit 592f933

Please sign in to comment.