Skip to content

Commit

Permalink
[Bug Fix] Fix A Bug Related to GroupRevRes (dmlc#4181)
Browse files Browse the repository at this point in the history
Co-authored-by: Minjie Wang <[email protected]>
Co-authored-by: Xin Yao <[email protected]>
  • Loading branch information
3 people authored Jun 28, 2022
1 parent 1518861 commit a25a14f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
3 changes: 2 additions & 1 deletion python/dgl/nn/pytorch/conv/grouprevres.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ def backward(ctx, *grad_outputs):
detached_inputs = tuple(detached_inputs)
temp_output = ctx.fn(*detached_inputs)

filtered_detached_inputs = tuple(filter(lambda x: x.requires_grad, detached_inputs))
filtered_detached_inputs = tuple(filter(lambda x: getattr(x, 'requires_grad', False),
detached_inputs))
gradients = torch.autograd.grad(outputs=(temp_output,),
inputs=filtered_detached_inputs + ctx.weights,
grad_outputs=grad_outputs)
Expand Down
9 changes: 5 additions & 4 deletions tests/pytorch/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,14 +508,14 @@ def test_egat_conv(g, idtype, out_node_feats, out_edge_feats, num_heads):
efeat = F.randn((g.number_of_edges(), 5))
egat = egat.to(ctx)
h, f = egat(g, nfeat, efeat)

th.save(egat, tmp_buffer)

assert h.shape == (g.number_of_nodes(), num_heads, out_node_feats)
assert f.shape == (g.number_of_edges(), num_heads, out_edge_feats)
_, _, attn = egat(g, nfeat, efeat, True)
assert attn.shape == (g.number_of_edges(), num_heads, 1)

@parametrize_idtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_node_feats', [1, 5])
Expand All @@ -533,7 +533,7 @@ def test_egat_conv_bi(g, idtype, out_node_feats, out_edge_feats, num_heads):
efeat = F.randn((g.number_of_edges(), 7))
egat = egat.to(ctx)
h, f = egat(g, nfeat, efeat)

th.save(egat, tmp_buffer)

assert h.shape == (g.number_of_dst_nodes(), num_heads, out_node_feats)
Expand Down Expand Up @@ -1473,7 +1473,8 @@ def test_group_rev_res(idtype):
h = th.randn(num_nodes, feats).to(dev)
conv = nn.GraphConv(feats // groups, feats // groups)
model = nn.GroupRevRes(conv, groups).to(dev)
model(g, h)
result = model(g, h)
result.sum().backward()

@pytest.mark.parametrize('in_size', [16, 32])
@pytest.mark.parametrize('hidden_size', [16, 32])
Expand Down

0 comments on commit a25a14f

Please sign in to comment.