diff --git a/tests/compute/test_transform.py b/tests/compute/test_transform.py
index 77487717a14b..a18333bd5c15 100644
--- a/tests/compute/test_transform.py
+++ b/tests/compute/test_transform.py
@@ -2356,6 +2356,8 @@ def test_module_laplacian_pe(idtype):

 @pytest.mark.parametrize('g', get_cases(['has_scalar_e_feature']))
 def test_module_sign(g):
     import torch
+
+    atol = 1e-06
     ctx = F.ctx()
     g = g.to(ctx)
@@ -2372,25 +2374,25 @@ def test_module_sign(g):
     transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='raw')
     g = transform(g)
     target = torch.matmul(adj, g.ndata['h'])
-    assert torch.allclose(g.ndata['out_feat_1'], target)
+    assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)

     transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w', diffuse_op='raw')
     g = transform(g)
     target = torch.matmul(weight_adj, g.ndata['h'])
-    assert torch.allclose(g.ndata['out_feat_1'], target)
+    assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)

     # rw
     adj_rw = torch.matmul(torch.diag(1 / adj.sum(dim=1)), adj)
     transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='rw')
     g = transform(g)
     target = torch.matmul(adj_rw, g.ndata['h'])
-    assert torch.allclose(g.ndata['out_feat_1'], target)
+    assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)

     weight_adj_rw = torch.matmul(torch.diag(1 / weight_adj.sum(dim=1)), weight_adj)
     transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w', diffuse_op='rw')
     g = transform(g)
     target = torch.matmul(weight_adj_rw, g.ndata['h'])
-    assert torch.allclose(g.ndata['out_feat_1'], target)
+    assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)

     # gcn
     raw_eweight = g.edata['scalar_w']
@@ -2401,7 +2403,7 @@ def test_module_sign(g):
     transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='gcn')
     g = transform(g)
     target = torch.matmul(adj_gcn, g.ndata['h'])
-    assert torch.allclose(g.ndata['out_feat_1'], target)
+    assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)

     gcn_norm = dgl.GCNNorm('scalar_w')
     g = gcn_norm(g)
@@ -2412,20 +2414,20 @@ def test_module_sign(g):
                                   eweight_name='scalar_w', diffuse_op='gcn')
     g = transform(g)
     target = torch.matmul(weight_adj_gcn, g.ndata['h'])
-    assert torch.allclose(g.ndata['out_feat_1'], target)
+    assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)

     # ppr
     alpha = 0.2
     transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', diffuse_op='ppr', alpha=alpha)
     g = transform(g)
     target = (1 - alpha) * torch.matmul(adj_gcn, g.ndata['h']) + alpha * g.ndata['h']
-    assert torch.allclose(g.ndata['out_feat_1'], target)
+    assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)

     transform = dgl.SIGNDiffusion(k=1, in_feat_name='h', eweight_name='scalar_w',
                                   diffuse_op='ppr', alpha=alpha)
     g = transform(g)
     target = (1 - alpha) * torch.matmul(weight_adj_gcn, g.ndata['h']) + alpha * g.ndata['h']
-    assert torch.allclose(g.ndata['out_feat_1'], target)
+    assert torch.allclose(g.ndata['out_feat_1'], target, atol=atol)

 @unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')
 @parametrize_idtype
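
Note on the change: the patch only relaxes the exact `torch.allclose` checks in `test_module_sign` to an absolute tolerance of `atol = 1e-06`. A minimal sketch of the underlying issue follows, with hypothetical shapes not taken from the test: the graph aggregation inside `SIGNDiffusion` and the dense `torch.matmul` reference can accumulate float32 sums in different orders, so two mathematically equal results may drift past `allclose`'s default tolerances.

    import torch

    torch.manual_seed(0)
    adj = (torch.rand(64, 64) < 0.1).float()  # hypothetical adjacency, not from the test
    h = torch.rand(64, 16)

    dense = torch.matmul(adj, h)
    # Summing the same products in a different order emulates what a
    # sparse aggregation kernel may do relative to a dense matmul.
    perm = torch.randperm(64)
    reordered = torch.matmul(adj[:, perm], h[perm])

    print(torch.allclose(dense, reordered))              # may be False on float32
    print(torch.allclose(dense, reordered, atol=1e-06))  # passes within the new tolerance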