-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgin_conv.patch
45 lines (39 loc) · 1.55 KB
/
gin_conv.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
--- gin_conv_ori.py 2024-07-31 16:28:05.702803897 +0000
+++ gin_conv.py 2024-07-31 16:30:46.620322883 +0000
@@ -54,8 +54,8 @@
:math:`(|\mathcal{V}_t|, F_{out})` if bipartite
"""
def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False,
- **kwargs):
- kwargs.setdefault('aggr', 'add')
+ save_intermediate=False, **kwargs):
+ kwargs.setdefault('aggr', 'max')
super().__init__(**kwargs)
self.nn = nn
self.initial_eps = eps
@@ -64,6 +64,7 @@
else:
self.register_buffer('eps', torch.empty(1))
self.reset_parameters()
+ self.save_intermediate = save_intermediate #InkStream Patch
def reset_parameters(self):
super().reset_parameters()
@@ -79,15 +80,22 @@
if isinstance(x, Tensor):
x = (x, x)
+
+ intermediate_result = {} #InkStream Patch
+ if not self.training and self.save_intermediate: #InkStream Patch
+ intermediate_result["a-"] = x[0].detach() #InkStream Patch
# propagate_type: (x: OptPairTensor)
out = self.propagate(edge_index, x=x, size=size)
+ if not self.training and self.save_intermediate: #InkStream Patch
+ intermediate_result["a"] = out.detach() #InkStream Patch
+
x_r = x[1]
if x_r is not None:
out = out + (1 + self.eps) * x_r
- return self.nn(out)
+ return self.nn(out), intermediate_result #InkStream Patch
def message(self, x_j: Tensor) -> Tensor:
return x_j