[MRG] Gromov_Wasserstein2 not performing backward properly on GPU (PythonOT#352)

* Resolves gromov wasserstein backward bug

* release file updated
ncassereau authored Mar 2, 2022
1 parent 1781472 commit 9412f0a
Showing 3 changed files with 46 additions and 29 deletions.
3 changes: 3 additions & 0 deletions RELEASES.md
@@ -18,6 +18,9 @@
 - Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337,
   PR #338)
 - Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349)
+- Fix bug where gromov_wasserstein2 does not perform backpropagation with CUDA
+  tensors (Issue #351, PR #352)
+

 ## 0.8.1.0
 *December 2021*
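For context, a minimal sketch of the call pattern this release note describes — differentiating `gromov_wasserstein2` through CUDA tensors. The point counts, seed, and use of `ot.dist`/`ot.unif` are illustrative, not taken from the PR:

```python
# Minimal sketch of the fixed behavior (illustrative sizes and helpers;
# assumes PyTorch is installed, CUDA optional).
import numpy as np
import torch
import ot

rng = np.random.RandomState(42)
xs = rng.randn(10, 2)
xt = rng.randn(10, 2)
C1 = ot.dist(xs, xs)  # intra-domain structure matrices
C2 = ot.dist(xt, xt)
p = ot.unif(10)       # uniform weights
q = ot.unif(10)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
p1 = torch.tensor(p, requires_grad=True, device=device)
q1 = torch.tensor(q, requires_grad=True, device=device)
C11 = torch.tensor(C1, requires_grad=True, device=device)
C12 = torch.tensor(C2, requires_grad=True, device=device)

val = ot.gromov_wasserstein2(C11, C12, p1, q1)
val.backward()          # did not propagate correctly for CUDA tensors pre-fix
print(C11.grad.device)  # now matches the input device
```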
12 changes: 8 additions & 4 deletions ot/gromov.py
@@ -546,8 +546,10 @@ def df(G):
     gw = log_gw['gw_dist']

     if loss_fun == 'square_loss':
-        gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
-        gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+        gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
+        gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)
+        gC1 = nx.from_numpy(gC1, type_as=C10)
+        gC2 = nx.from_numpy(gC2, type_as=C10)
         gw = nx.set_gradients(gw, (p0, q0, C10, C20),
                               (log_gw['u'], log_gw['v'], gC1, gC2))

@@ -786,8 +788,10 @@ def df(G):
     log_fgw['T'] = T0

     if loss_fun == 'square_loss':
-        gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
-        gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+        gC1 = 2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)
+        gC2 = 2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)
+        gC1 = nx.from_numpy(gC1, type_as=C10)
+        gC2 = nx.from_numpy(gC2, type_as=C10)
         fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0),
                                     (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0))
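The root cause: `nx.from_numpy(...)` with no `type_as` converts the NumPy gradient into a backend tensor with default placement (CPU, default dtype), so the gradients handed to `nx.set_gradients` did not live on the same device as CUDA inputs. A minimal sketch of the backend behavior the fix relies on, assuming the torch backend (`ref` and `grad_np` are illustrative names, not from the PR):

```python
# Sketch of the backend conversion the fix relies on (torch backend assumed).
import numpy as np
import torch
from ot.backend import get_backend

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ref = torch.rand(5, 5, dtype=torch.float32, device=device)
nx = get_backend(ref)  # backend is selected from the input tensor

grad_np = np.ones((5, 5))        # gradient computed with NumPy internals
g_plain = nx.from_numpy(grad_np) # CPU float64: device/dtype mismatch on GPU
g_typed = nx.from_numpy(grad_np, type_as=ref)  # follows ref's device/dtype

assert g_typed.device == ref.device and g_typed.dtype == ref.dtype
```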
60 changes: 35 additions & 25 deletions test/test_gromov.py
@@ -181,19 +181,24 @@ def test_gromov2_gradients():

     if torch:

-        p1 = torch.tensor(p, requires_grad=True)
-        q1 = torch.tensor(q, requires_grad=True)
-        C11 = torch.tensor(C1, requires_grad=True)
-        C12 = torch.tensor(C2, requires_grad=True)
+        devices = [torch.device("cpu")]
+        if torch.cuda.is_available():
+            devices.append(torch.device("cuda"))
+        for device in devices:
+            p1 = torch.tensor(p, requires_grad=True, device=device)
+            q1 = torch.tensor(q, requires_grad=True, device=device)
+            C11 = torch.tensor(C1, requires_grad=True, device=device)
+            C12 = torch.tensor(C2, requires_grad=True, device=device)

-        val = ot.gromov_wasserstein2(C11, C12, p1, q1)
+            val = ot.gromov_wasserstein2(C11, C12, p1, q1)

-        val.backward()
+            val.backward()

-        assert q1.shape == q1.grad.shape
-        assert p1.shape == p1.grad.shape
-        assert C11.shape == C11.grad.shape
-        assert C12.shape == C12.grad.shape
+            assert val.device == p1.device
+            assert q1.shape == q1.grad.shape
+            assert p1.shape == p1.grad.shape
+            assert C11.shape == C11.grad.shape
+            assert C12.shape == C12.grad.shape


 @pytest.skip_backend("jax", reason="test very slow with jax backend")
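The merged test loops over CPU and, when available, CUDA, with the new `val.device == p1.device` assertion guarding the regression directly. An equivalent formulation using pytest parametrization — a sketch of an alternative, not the test as merged — might look like:

```python
# Hedged sketch: the same CPU/CUDA sweep via pytest parametrization.
import numpy as np
import pytest
import torch
import ot

DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])

@pytest.mark.parametrize("device", DEVICES)
def test_gromov2_gradients_on(device):
    rng = np.random.RandomState(0)
    xs, xt = rng.randn(8, 2), rng.randn(8, 2)
    C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)
    p, q = ot.unif(8), ot.unif(8)

    p1 = torch.tensor(p, requires_grad=True, device=device)
    q1 = torch.tensor(q, requires_grad=True, device=device)
    C11 = torch.tensor(C1, requires_grad=True, device=device)
    C12 = torch.tensor(C2, requires_grad=True, device=device)

    val = ot.gromov_wasserstein2(C11, C12, p1, q1)
    val.backward()

    assert val.device == p1.device
    assert C11.grad.shape == C11.shape
```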
@@ -636,21 +641,26 @@ def test_fgw2_gradients():

     if torch:

-        p1 = torch.tensor(p, requires_grad=True)
-        q1 = torch.tensor(q, requires_grad=True)
-        C11 = torch.tensor(C1, requires_grad=True)
-        C12 = torch.tensor(C2, requires_grad=True)
-        M1 = torch.tensor(M, requires_grad=True)
-
-        val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
-
-        val.backward()
-
-        assert q1.shape == q1.grad.shape
-        assert p1.shape == p1.grad.shape
-        assert C11.shape == C11.grad.shape
-        assert C12.shape == C12.grad.shape
-        assert M1.shape == M1.grad.shape
+        devices = [torch.device("cpu")]
+        if torch.cuda.is_available():
+            devices.append(torch.device("cuda"))
+        for device in devices:
+            p1 = torch.tensor(p, requires_grad=True, device=device)
+            q1 = torch.tensor(q, requires_grad=True, device=device)
+            C11 = torch.tensor(C1, requires_grad=True, device=device)
+            C12 = torch.tensor(C2, requires_grad=True, device=device)
+            M1 = torch.tensor(M, requires_grad=True, device=device)
+
+            val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1)
+
+            val.backward()
+
+            assert val.device == p1.device
+            assert q1.shape == q1.grad.shape
+            assert p1.shape == p1.grad.shape
+            assert C11.shape == C11.grad.shape
+            assert C12.shape == C12.grad.shape
+            assert M1.shape == M1.grad.shape


 def test_fgw_barycenter(nx):
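For completeness, a usage sketch of the fused variant, which additionally differentiates through the feature cost matrix `M` (its gradient is routed as `(1 - alpha) * T` in the fix above); sizes, seed, and `alpha` are illustrative:

```python
# Usage sketch for the fused variant (illustrative sizes; assumes torch).
import numpy as np
import torch
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(8, 2), rng.randn(8, 2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

C1 = torch.tensor(ot.dist(xs, xs), requires_grad=True, device=device)
C2 = torch.tensor(ot.dist(xt, xt), requires_grad=True, device=device)
M = torch.tensor(ot.dist(xs, xt), requires_grad=True, device=device)
p = torch.tensor(ot.unif(8), requires_grad=True, device=device)
q = torch.tensor(ot.unif(8), requires_grad=True, device=device)

val = ot.fused_gromov_wasserstein2(M, C1, C2, p, q, alpha=0.5)
val.backward()
print(M.grad.device, C1.grad.device)  # both match the input device
```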
