-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathbin.py
52 lines (40 loc) · 1.64 KB
/
bin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class KarateClub(InMemoryDataset):
    """Zachary's karate club graph packaged as a single-graph in-memory dataset.

    - Node features are one-hot identity vectors (``torch.eye``).
    - Node labels are the community ids returned by ``C.best_partition``
      (presumably python-louvain's Louvain method; TODO confirm the import).
    - ``train_mask`` marks exactly one node per community: the lowest-index
      member.

    NOTE(review): Louvain community detection is typically randomized, so
    ``y`` and ``train_mask`` may differ between runs unless it is seeded.

    Args:
        transform: optional transform forwarded to ``InMemoryDataset`` and
            applied when items are accessed.
    """

    def __init__(self, transform=None):
        # root='.' and no pre_transform / pre_filter: the graph is built
        # directly in memory below.
        super(KarateClub, self).__init__('.', transform, None, None)

        G = nx.karate_club_graph()
        num_nodes = G.number_of_nodes()

        # One-hot identity features: node i gets the i-th unit vector.
        x = torch.eye(num_nodes, dtype=torch.float)

        # NetworkX 3.0 removed to_scipy_sparse_matrix in favour of
        # to_scipy_sparse_array; use whichever this installation provides.
        to_sparse = getattr(nx, 'to_scipy_sparse_array', None)
        if to_sparse is None:
            to_sparse = nx.to_scipy_sparse_matrix
        adj = to_sparse(G).tocoo()

        # COO row/col indices become the 2 x num_edges edge_index.
        # from_numpy on int64 data already yields a torch.long tensor.
        row = torch.from_numpy(adj.row.astype(np.int64))
        col = torch.from_numpy(adj.col.astype(np.int64))
        edge_index = torch.stack([row, col], dim=0)

        partition = C.best_partition(G)
        y = torch.tensor([partition[i] for i in range(num_nodes)])

        # Label the first (lowest-index) node of each community for
        # training. Iterate the ids actually present, so an id that is
        # unused cannot trigger an IndexError on an empty nonzero().
        train_mask = torch.zeros(y.size(0), dtype=torch.bool)
        for label in torch.unique(y):
            train_mask[(y == label).nonzero(as_tuple=False)[0]] = True

        data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask)
        self.data, self.slices = self.collate([data])
def visualize(h, color, epoch=None, loss=None):
    """Plot either a graph or a 2-D node embedding, then show the figure.

    Args:
        h: either a networkx graph, drawn with a seeded spring layout, or an
            array-like of node coordinates (torch tensor or numpy array) whose
            first two columns are scattered.
        color: per-node colours/labels mapped through the ``Set2`` colormap;
            torch tensors are moved to CPU numpy first.
        epoch: optional epoch number shown under an embedding plot.
        loss: optional loss shown with ``epoch``. It may be a tensor or any
            object with ``.item()``, or a plain number.
    """
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])

    # matplotlib cannot consume GPU tensors, so normalise colours to numpy.
    if torch.is_tensor(color):
        color = color.detach().cpu().numpy()

    if isinstance(h, nx.Graph):
        # Draw the graph that was passed in, not any module-level graph.
        nx.draw_networkx(h, pos=nx.spring_layout(h, seed=42),
                         with_labels=False, node_color=color, cmap="Set2")
    else:
        if torch.is_tensor(h):
            h = h.detach().cpu().numpy()
        plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")
        if epoch is not None and loss is not None:
            loss_value = loss.item() if hasattr(loss, 'item') else float(loss)
            plt.xlabel(f'Epoch: {epoch}, Loss: {loss_value:.4f}', fontsize=16)
    plt.show()
# NOTE(review): neither pipeline below is passed to KarateClub(), so they
# currently have no effect. ``T.RandomEdgeDrop`` is not a transform in
# torchvision or torch_geometric that I know of, and ``T.Normalize(mean, std)``
# is an image transform. Confirm what ``T`` is bound to; as written, these
# lines may raise AttributeError.
node_transforms = T.Compose([
    T.Normalize([0.5], [0.5]),
])
edge_transforms = T.Compose([
    T.RandomEdgeDrop(),
])

dataset = KarateClub()
# Index the dataset rather than reading its internal collated ``data``
# storage directly (discouraged in recent PyG). For this single-graph
# dataset, item 0 is the whole karate-club graph.
data = dataset[0]
x = data.x
edge_idx = data.edge_index
y = data.y