Skip to content

Commit

Permalink
[MRG] fix gpu compatibility of srGW solvers (PythonOT#596)
Browse files Browse the repository at this point in the history
* fix gpu compatibility of srgw solvers

* update release and pep8
  • Loading branch information
cedricvincentcuaz authored Jan 14, 2024
1 parent 336980f commit f395e58
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -354,4 +354,4 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil

[64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). [Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications](https://openreview.net/pdf?id=uqkUguNu40). In Thirty-seventh Conference on Neural Information Processing Systems.

[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf).
[65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf).
4 changes: 2 additions & 2 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

#### Closed issues
- Fixed an issue with cost correction for mismatched labels in `ot.da.BaseTransport` fit methods. This fix addresses the original issue introduced PR #587 (PR #593)

- Fix gpu compatibility of sr(F)GW solvers when `G0 is not None`(PR #596)

## 0.9.2
*December 2023*
Expand Down Expand Up @@ -671,4 +671,4 @@ It provides the following solvers:
* Optimal transport for domain adaptation with group lasso regularization
* Conditional gradient and Generalized conditional gradient for regularized OT.

Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder.
10 changes: 5 additions & 5 deletions ot/gromov/_semirelaxed.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme
else:
q = nx.sum(G0, 0)
# Check first marginal of G0
np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08)

constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)

Expand Down Expand Up @@ -363,8 +363,8 @@ def semirelaxed_fused_gromov_wasserstein(
G0 = nx.outer(p, q)
else:
q = nx.sum(G0, 0)
# Check marginals of G0
np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
# Check first marginal of G0
assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08)

constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)

Expand Down Expand Up @@ -703,7 +703,7 @@ def entropic_semirelaxed_gromov_wasserstein(
else:
q = nx.sum(G0, 0)
# Check first marginal of G0
np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08)

constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)

Expand Down Expand Up @@ -951,7 +951,7 @@ def entropic_semirelaxed_fused_gromov_wasserstein(
else:
q = nx.sum(G0, 0)
# Check first marginal of G0
np.testing.assert_allclose(nx.sum(G0, 1), p, atol=1e-08)
assert nx.allclose(nx.sum(G0, 1), p, atol=1e-08)

constC, hC1, hC2, fC2t = init_matrix_semirelaxed(C1, C2, p, loss_fun, nx)

Expand Down

0 comments on commit f395e58

Please sign in to comment.