-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathutils.py
89 lines (71 loc) · 2.83 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
from torch_geometric.utils import to_networkx, degree
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
def convert_to_nodeDegreeFeatures(graphs):
    """Replace every graph's node features with one-hot encoded node degrees.

    The width of the one-hot encoding is shared by all graphs: it is the
    largest node degree found in any graph (measured on the undirected
    networkx view of the graph) plus one.  The per-node degree that gets
    encoded is the number of times the node appears in ``edge_index[0]``.

    Args:
        graphs: Iterable of graph objects accepted by
            ``torch_geometric.utils.to_networkx`` that expose
            ``edge_index``, ``num_nodes`` and ``clone()``.

    Returns:
        A list of cloned graphs whose ``x`` entry is a float tensor with
        ``max_degree + 1`` columns.  The input graphs are not modified.
    """
    # Materialise once: the graphs are walked twice (max degree, then encode).
    graph_list = list(graphs)

    max_degree = 0
    for graph in graph_list:
        nx_graph = to_networkx(graph, to_undirected=True)
        # default=0 keeps a graph without any nodes from raising ValueError.
        graph_max = max((d for _, d in nx_graph.degree), default=0)
        max_degree = max(max_degree, graph_max)

    num_classes = max_degree + 1
    new_graphs = []
    for graph in graph_list:
        # NOTE(review): counting only edge_index[0] presumably relies on
        # every undirected edge being stored in both directions - confirm.
        deg = degree(graph.edge_index[0], graph.num_nodes, dtype=torch.long)
        new_graph = graph.clone()
        new_graph['x'] = F.one_hot(deg, num_classes=num_classes).to(torch.float)
        new_graphs.append(new_graph)
    return new_graphs
def get_maxDegree(graphs):
    """Return the largest node degree found in any of the given graphs.

    Degrees are measured on the undirected networkx conversion of each
    graph.

    Args:
        graphs: Iterable of graph objects accepted by
            ``torch_geometric.utils.to_networkx``.

    Returns:
        The maximum node degree as an int; ``0`` when ``graphs`` is empty
        or no graph contains a node.
    """
    maxdegree = 0
    for graph in graphs:
        g = to_networkx(graph, to_undirected=True)
        # default=0 keeps a graph without any nodes from raising ValueError.
        gdegree = max((d for _, d in g.degree), default=0)
        maxdegree = max(maxdegree, gdegree)
    return maxdegree
def use_node_attributes(graphs):
    """Return copies of the graphs whose features keep only the leading columns.

    The number of columns kept is ``graphs.num_node_attributes``; each
    graph's ``x`` is sliced to ``x[:, :num_node_attributes]``.

    Args:
        graphs: Iterable collection of graphs that exposes a
            ``num_node_attributes`` attribute.  Each graph must provide
            ``clone()`` and a two-dimensional ``x``.

    Returns:
        A list of cloned graphs with the truncated ``x``.  The input graphs
        are left untouched.
    """
    attr_count = graphs.num_node_attributes

    def _keep_attribute_columns(original):
        # Clone first so the source graph keeps its full feature matrix.
        trimmed = original.clone()
        trimmed['x'] = original.x[:, :attr_count]
        return trimmed

    return [_keep_attribute_columns(graph) for graph in graphs]
def split_data(graphs, train=None, test=None, shuffle=True, seed=None):
    """Split graphs into two subsets, stratified by graph label when shuffling.

    Args:
        graphs: Collection of graphs; each graph's ``y`` tensor is
            concatenated to form the stratification labels.
        train: Forwarded to ``train_test_split`` as ``train_size``.
        test: Forwarded to ``train_test_split`` as ``test_size``.
        shuffle: Whether to shuffle before splitting.  When ``False`` the
            split is sequential and not stratified.
        seed: Forwarded to ``train_test_split`` as ``random_state``.

    Returns:
        A ``(graphs_tv, graphs_test)`` tuple as produced by
        ``train_test_split``.
    """
    y = torch.cat([graph.y for graph in graphs])
    # scikit-learn rejects stratify together with shuffle=False, so only
    # stratify when the data is actually shuffled.
    stratify = y if shuffle else None
    graphs_tv, graphs_test = train_test_split(
        graphs,
        train_size=train,
        test_size=test,
        stratify=stratify,
        shuffle=shuffle,
        random_state=seed,
    )
    return graphs_tv, graphs_test
def get_numGraphLabels(dataset):
    """Count the distinct graph labels in a dataset.

    Args:
        dataset: Iterable of graphs whose ``y`` holds a single value that
            can be read with ``.item()``.

    Returns:
        The number of different label values encountered.
    """
    distinct_labels = {graph.y.item() for graph in dataset}
    return len(distinct_labels)
def _get_avg_nodes_edges(graphs):
    """Return the average number of nodes and edges per graph.

    Each graph's ``num_edges`` is halved before averaging, i.e. the graphs
    are treated as undirected with every edge counted twice.

    Args:
        graphs: Sized, iterable collection of graphs exposing ``num_nodes``
            and ``num_edges``.

    Returns:
        A ``(avg_nodes, avg_edges)`` tuple of floats.  Both values are NaN
        when ``graphs`` is empty, since an average over nothing is undefined.
    """
    numGraphs = len(graphs)
    if numGraphs == 0:
        # Avoid ZeroDivisionError; NaN marks the statistic as unavailable.
        return float('nan'), float('nan')
    numNodes = sum(g.num_nodes for g in graphs)
    numEdges = sum(g.num_edges / 2. for g in graphs)  # undirected
    return numNodes / numGraphs, numEdges / numGraphs
def get_stats(df, ds, graphs_train, graphs_val=None, graphs_test=None):
    """Record graph-count and average-size statistics for one dataset row.

    For the training split, and for the validation / test splits when they
    are given and non-empty, the number of graphs and the averages returned
    by ``_get_avg_nodes_edges`` are written into row ``ds`` of ``df``.

    Args:
        df: Table supporting ``df.loc[row, column] = value`` assignment
            (presumably a pandas DataFrame - confirm with callers).
        ds: Row label to write to.
        graphs_train: Training graphs; always recorded.
        graphs_val: Optional validation graphs; skipped when falsy.
        graphs_test: Optional test graphs; skipped when falsy.

    Returns:
        The same ``df`` object, updated in place.
    """
    # (graphs, count column, avg-nodes column, avg-edges column, always record?)
    plan = (
        (graphs_train, "#graphs_train", 'avgNodes_train', 'avgEdges_train', True),
        (graphs_val, '#graphs_val', 'avgNodes_val', 'avgEdges_val', False),
        (graphs_test, '#graphs_test', 'avgNodes_test', 'avgEdges_test', False),
    )
    for split, count_col, nodes_col, edges_col, mandatory in plan:
        # Optional splits are recorded only when provided and non-empty.
        if not mandatory and not split:
            continue
        df.loc[ds, count_col] = len(split)
        mean_nodes, mean_edges = _get_avg_nodes_edges(split)
        df.loc[ds, nodes_col] = mean_nodes
        df.loc[ds, edges_col] = mean_edges
    return df