Commit

Update test_transform.py (dmlc#4190)
mufeili authored Jun 29, 2022
1 parent 32f12ee commit 8b19c28
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions tests/compute/test_transform.py
@@ -2356,6 +2356,8 @@ def test_module_laplacian_pe(idtype):
 @pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature']))
 def test_module_sign(g):
     import torch
 
+    atol = 1e-06
+
     ctx = F.ctx()
     g = g.to(ctx)
@@ -2372,25 +2374,25 @@ def test_module_sign(g):
     transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='raw')
     g = transform(g)
     target = torch.matmul(adj, g.ndata['h'])
-    assert torch.allclose(g.ndata['out_feat_1'], target)
+    assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
 
     transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w', diffuse_op='raw')
     g = transform(g)
     target = torch.matmul(weight_adj, g.ndata['h'])
-    assert torch.allclose(g.ndata['out_feat_1'], target)
+    assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
 
     # rw
     adj_rw = torch.matmul(torch.diag(1 / adj.sum(dim=1)), adj)
     transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='rw')
     g = transform(g)
     target = torch.matmul(adj_rw, g.ndata['h'])
-    assert torch.allclose(g.ndata['out_feat_1'], target)
+    assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
 
     weight_adj_rw = torch.matmul(torch.diag(1 / weight_adj.sum(dim=1)), weight_adj)
     transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w', diffuse_op='rw')
     g = transform(g)
     target = torch.matmul(weight_adj_rw, g.ndata['h'])
-    assert torch.allclose(g.ndata['out_feat_1'], target)
+    assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
 
     # gcn
     raw_eweight = g.edata['scalar_w']
@@ -2401,7 +2403,7 @@ def test_module_sign(g):
     transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='gcn')
     g = transform(g)
     target = torch.matmul(adj_gcn, g.ndata['h'])
-    assert torch.allclose(g.ndata['out_feat_1'], target)
+    assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
 
     gcn_norm = dgl.GCNNorm('scalar_w')
     g = gcn_norm(g)
@@ -2412,20 +2414,20 @@ def test_module_sign(g):
                                   eweight_name='scalar_w', diffuse_op='gcn')
     g = transform(g)
     target = torch.matmul(weight_adj_gcn, g.ndata['h'])
-    assert torch.allclose(g.ndata['out_feat_1'], target)
+    assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
 
     # ppr
     alpha = 0.2
     transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='ppr', alpha=alpha)
     g = transform(g)
     target = (1 - alpha) * torch.matmul(adj_gcn, g.ndata['h']) + alpha * g.ndata['h']
-    assert torch.allclose(g.ndata['out_feat_1'], target)
+    assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
 
     transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w',
                                   diffuse_op='ppr', alpha=alpha)
     g = transform(g)
     target = (1 - alpha) * torch.matmul(weight_adj_gcn, g.ndata['h']) + alpha * g.ndata['h']
-    assert torch.allclose(g.ndata['out_feat_1'], target)
+    assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)
 
 @unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
 @parametrize_idtype
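Note on the change: torch.allclose(a, b) passes when |a - b| <= atol + rtol * |b| holds elementwise, with defaults rtol=1e-05 and atol=1e-08. The default atol is strict for entries near zero, where the message passing inside SIGNDiffusion and the dense torch.matmul reference can differ by float32 accumulation order alone. A minimal sketch of the effect (the tensors are illustrative, not taken from the test):

    import torch

    # Defaults are rtol=1e-05, atol=1e-08; strict for near-zero entries.
    a = torch.tensor([0.0, 1.0])
    b = torch.tensor([5e-7, 1.0])  # tiny absolute error on a near-zero entry

    print(torch.allclose(a, b))             # False: 5e-7 > 1e-8 + 1e-5 * 5e-7
    print(torch.allclose(a, b, atol=1e-6))  # True:  5e-7 <= 1e-6 + 1e-5 * 5e-7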

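For readers unfamiliar with the transform under test, a minimal end-to-end sketch of dgl.SIGNDiffusion on a toy graph (the 4-node cycle and feature size are made up for illustration; 'out_feat_1' follows the default output naming seen in the asserts above):

    import dgl
    import torch

    # Toy directed 4-cycle with random node features (illustrative only).
    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
    g.ndata['h'] = torch.randn(4, 5)

    # One step (k=1) of random-walk-normalized diffusion; the result is
    # stored under g.ndata['out_feat_1'], the key the test asserts on.
    transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='rw')
    g = transform(g)
    print(g.ndata['out_feat_1'].shape)  # torch.Size([4, 5])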