#!/usr/bin/env python3
# Copyright (c) 2022 Graphcore Ltd. All rights reserved.
# Tests for PyG torch_scatter ops integration with PopTorch
from functools import partial
import torch
import pytest
import helpers
import poptorch
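
# torch_scatter is only needed when the tests actually execute; otherwise the
# no-op stubs below keep this module importable (e.g. during test collection
# on a machine without torch_scatter installed).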
if helpers.is_running_tests:
    from torch_scatter import (scatter, scatter_log_softmax, scatter_softmax,
                               scatter_std, scatter_logsumexp, scatter_add,
                               scatter_max, scatter_min, scatter_mul)
else:

    def scatter():
        pass

    def scatter_log_softmax():
        pass

    def scatter_softmax():
        pass

    def scatter_std():
        pass

    def scatter_add():
        pass

    def scatter_max():
        pass

    def scatter_min():
        pass

    def scatter_mul():
        pass

    def scatter_logsumexp():
        pass
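

# Runs `func` both natively and through poptorch.inferenceModel, then checks
# that the IPU result (and any in-place `out` tensor) matches the native one.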
def torch_scatter_harness(func, src, index, out=None):
    dim_size = int(index.max()) + 1

    class Model(torch.nn.Module):
        def forward(self, src, index, out=None):
            if out is None:
                return func(src, index, dim_size=dim_size)
            return func(src, index, out=out, dim_size=dim_size)

    model = Model()
    poptorch_model = poptorch.inferenceModel(model)

    out_in_place_native = None
    if out is not None:
        # Clone `out` so the native run does not clobber the tensor that the
        # IPU run updates in place.
        out_in_place_native = out.clone()
        native_out = func(src,
                          index,
                          out=out_in_place_native,
                          dim_size=dim_size)
        ipu_out = poptorch_model(src, index, out=out)
    else:
        native_out = func(src, index, dim_size=dim_size)
        ipu_out = poptorch_model(src, index)

    helpers.assert_allclose(actual=ipu_out, expected=native_out)
    if out is not None:
        helpers.assert_allclose(actual=out, expected=out_in_place_native)
    poptorch_model.destroy()
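

# Exercises scatter() with each supported reduce mode. For example, with
# src = [1, 3, 2, 4, 5, 6] and index = [0, 1, 0, 1, 1, 3], reduce='sum'
# groups src by index: out[0] = 1 + 2, out[1] = 3 + 4 + 5, out[2] = 0,
# out[3] = 6.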
@pytest.mark.parametrize("reduce", ['sum', 'mean', 'max', 'min', 'mul'])
def test_scatter(reduce):
func = partial(scatter, reduce=reduce)
src = torch.tensor([1, 3, 2, 4, 5, 6]).float()
index = torch.tensor([0, 1, 0, 1, 1, 3]).long()
torch_scatter_harness(func, src, index)
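

# These ops are composites built on the basic scatter reductions, so this
# checks that their lowered compositions also match the native results.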
@pytest.mark.parametrize(
    "func",
    [scatter_log_softmax, scatter_logsumexp, scatter_softmax, scatter_std])
def test_composites(func):
    src = torch.tensor([1, 3, 2, 4, 5, 6]).float()
    index = torch.tensor([0, 1, 0, 1, 5, 3]).long()
    torch_scatter_harness(func, src, index)
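

# Passes a preallocated `out` tensor to check that the in-place variants
# update it on the IPU the same way the native ops do.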
@pytest.mark.parametrize("func", [scatter_max, scatter_min, scatter_mul])
def test_scatter_inplace(func):
src = torch.tensor([1, 3, 2, 4, 5, 6]).float()
index = torch.tensor([0, 1, 4, 2, 3, 5]).long()
out = torch.tensor([10, 1, 11, 1, 23, 1]).float()
torch_scatter_harness(func, src, index, out)
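

# Inspects the TRACE log to confirm PopTorch optimises away the redundant
# zero-initialised output tensor passed to scatter_add.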
@helpers.printCapfdOnExit
@helpers.overridePoptorchLogLevel("TRACE")
def test_scatter_add_zeros_optimized(capfd):
    src = torch.tensor([1, 3, 2, 4, 5, 6]).float()
    index = torch.tensor([0, 1, 0, 1, 1, 3]).long()
    torch_scatter_harness(scatter_add, src, index)

    it = helpers.LogChecker(capfd).createIterator()
    it.findNext("Removing zeros output to scatter_add")
@helpers.printCapfdOnExit
@helpers.overridePoptorchLogLevel("TRACE")
def test_scatter_add_nd_expand_removed(capfd):
    torch.manual_seed(0)
    src = torch.randn(10, 6, 16)
    index = torch.tensor([0, 1, 0, 1, 1, 3]).long()
    func = partial(scatter_add, dim=1)
    torch_scatter_harness(func, src, index)

    it = helpers.LogChecker(capfd).createIterator()
    it.findNext("Removing index expansion node:")
@pytest.mark.parametrize("shape", [(5, ), (2, 5), (2, 5, 5)])
@pytest.mark.parametrize("func", [scatter_max, scatter_min, scatter_mul])
def test_scatter_overloads(shape, func):
torch.manual_seed(0)
x = torch.rand(shape)
ind = torch.randint(3, shape)
torch_scatter_harness(func, x, ind)