-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Model] Fix broken CDGNN example (dmlc#111)
* pretty printer
* Conflicts: python/dgl/data/sbm.py
* refined line_graph implementation
* fix broken api calls
* small fix to trigger CI
* requested change
- Loading branch information
Showing
6 changed files
with
164 additions
and
143 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,68 +1,124 @@ | ||
""" | ||
Supervised Community Detection with Hierarchical Graph Neural Networks | ||
https://arxiv.org/abs/1705.08415 | ||
Author's implementation: https://github.com/joanbruna/GNN_community | ||
""" | ||
|
||
from __future__ import division | ||
import time | ||
|
||
import argparse | ||
from itertools import permutations | ||
|
||
import networkx as nx | ||
import torch as th | ||
import torch.nn.functional as F | ||
import torch.optim as optim | ||
from torch.utils.data import DataLoader | ||
|
||
import dgl | ||
from dgl.data import SBMMixture | ||
import gnn | ||
import utils | ||
|
||
# Command-line configuration for the CDGNN community-detection example.
# NOTE: the scraped diff contained two argparse blocks (pre- and post-merge);
# argparse raises ArgumentError on duplicate option strings, so only one
# definition per option is kept here (the newer form, plus --model-path,
# which the checkpointing code below still reads).
parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int, help='Batch size', default=1)
parser.add_argument('--gpu', type=int, help='GPU index', default=-1)
parser.add_argument('--lr', type=float, help='Learning rate', default=0.001)
parser.add_argument('--model-path', type=str,
                    help='Path to the checkpoint of model', default='model')
parser.add_argument('--n-communities', type=int, help='Number of communities', default=2)
parser.add_argument('--n-epochs', type=int, help='Number of epochs', default=100)
parser.add_argument('--n-features', type=int, help='Number of features', default=16)
parser.add_argument('--n-graphs', type=int, help='Number of graphs', default=10)
parser.add_argument('--n-layers', type=int, help='Number of layers', default=30)
parser.add_argument('--n-nodes', type=int, help='Number of nodes', default=10000)
parser.add_argument('--optim', type=str, help='Optimizer', default='Adam')
parser.add_argument('--radius', type=int, help='Radius', default=3)
parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()

# Select CPU or the requested CUDA device.
dev = th.device('cpu') if args.gpu < 0 else th.device('cuda:%d' % args.gpu)
K = args.n_communities

# Stochastic-block-model mixture dataset; drop_last keeps every batch full.
training_dataset = SBMMixture(args.n_graphs, args.n_nodes, K)
training_loader = DataLoader(training_dataset, args.batch_size,
                             collate_fn=training_dataset.collate_fn, drop_last=True)

# Ground-truth labels are only defined up to a permutation of the K
# communities, so precompute one label tensor per permutation.
ones = th.ones(args.n_nodes // K)
y_list = [th.cat([x * ones for x in p]).long().to(dev) for p in permutations(range(K))]

# Layer widths: scalar input feature, n_layers hidden layers, K output logits.
feats = [1] + [args.n_features] * args.n_layers + [K]
model = gnn.GNN(feats, args.radius, K).to(dev)
# Look the optimizer class up by name (e.g. 'Adam' -> torch.optim.Adam).
optimizer = getattr(optim, args.optim)(model.parameters(), lr=args.lr)
|
||
def compute_overlap(z_list):
    """Return the mean overlap of the predictions in *z_list*.

    For each logit tensor, accuracy is maximized over every permutation of
    the community labels (``y_list``), then rescaled so that random guessing
    scores 0 and a perfect assignment scores 1:
    ``overlap = (acc - 1/K) / (1 - 1/K)``.
    """
    overlaps = []
    for z in z_list:
        pred = th.max(z, 1)[1]  # predicted community per node
        best_acc = max(th.sum(pred == y).item() for y in y_list) / args.n_nodes
        overlaps.append((best_acc - 1 / K) / (1 - 1 / K))
    return sum(overlaps) / len(overlaps)
|
||
def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
    """One training step on a batch of SBM graphs.

    The scraped span interleaved removed-version code (a second dataset /
    loader / model / ``opt = Adamax`` construction and a duplicate loss
    computation); only the post-merge step body is kept here.

    Parameters: epoch index ``i``, iteration index ``j``, the batched graph
    ``g``, its line graph ``lg``, both degree tensors, and the
    edge-to-node incidence ``pm_pd``.

    Returns ``(loss, overlap, t_forward, t_backward)``.
    """
    t0 = time.time()
    z = model(g, lg, deg_g, deg_lg, pm_pd)
    t_forward = time.time() - t0

    # Labels are defined only up to a permutation of the K communities:
    # take the minimum cross-entropy over all permuted labelings, per graph.
    z_list = th.chunk(z, args.batch_size, 0)
    loss = sum(min(F.cross_entropy(z, y) for y in y_list) for z in z_list) / args.batch_size
    overlap = compute_overlap(z_list)

    optimizer.zero_grad()
    t0 = time.time()
    loss.backward()
    t_backward = time.time() - t0
    optimizer.step()

    return loss, overlap, t_forward, t_backward
|
||
def test():
    """Evaluate the model across a sweep of SBM (p, q) settings.

    Sweeps from strongly assortative (p=6, q=0) to strongly disassortative
    (p=0, q=6) graphs, one freshly sampled graph per setting, and returns
    the list of overlaps in sweep order.
    """
    pq_settings = [(6, 0), (5.5, 0.5), (5, 1), (4.5, 1.5),
                   (1.5, 4.5), (1, 5), (0.5, 5.5), (0, 6)]
    N = 1  # one graph per (p, q) setting
    results = []
    for p, q in pq_settings:
        dataset = SBMMixture(N, args.n_nodes, K, pq=[[p, q]] * N)
        loader = DataLoader(dataset, N, collate_fn=dataset.collate_fn)
        g, lg, deg_g, deg_lg, pm_pd = next(iter(loader))
        z = model(g, lg, deg_g.to(dev), deg_lg.to(dev), pm_pd.to(dev))
        results.append(compute_overlap(th.chunk(z, N, 0)))
    return results
|
||
# Main training loop. The scraped span interleaved a stale removed-version
# print block (it referenced args.n_iterations, which the new parser no
# longer defines) and computed the epoch padding twice; only the post-merge
# epoch/iteration loop is kept here.
n_iterations = args.n_graphs // args.batch_size
for i in range(args.n_epochs):
    total_loss, total_overlap, s_forward, s_backward = 0, 0, 0, 0
    for j, [g, lg, deg_g, deg_lg, pm_pd] in enumerate(training_loader):
        deg_g = deg_g.to(dev)
        deg_lg = deg_lg.to(dev)
        pm_pd = pm_pd.to(dev)
        loss, overlap, t_forward, t_backward = step(i, j, g, lg, deg_g, deg_lg, pm_pd)

        total_loss += loss
        total_overlap += overlap
        s_forward += t_forward
        s_backward += t_backward

        if args.verbose:
            # Zero-pad the indices so per-iteration log lines align.
            epoch = '0' * (len(str(args.n_epochs)) - len(str(i)))
            iteration = '0' * (len(str(n_iterations)) - len(str(j)))
            print('[epoch %s%d iteration %s%d]loss %.3f | overlap %.3f'
                  % (epoch, i, iteration, j, loss, overlap))

    # Per-epoch averages over the (j + 1) iterations actually run.
    epoch = '0' * (len(str(args.n_epochs)) - len(str(i)))
    loss = total_loss / (j + 1)
    overlap = total_overlap / (j + 1)
    t_forward = s_forward / (j + 1)
    t_backward = s_backward / (j + 1)
    print('[epoch %s%d]loss %.3f | overlap %.3f | forward time %.3fs | backward time %.3fs'
          % (epoch, i, loss, overlap, t_forward, t_backward))

    # NOTE(review): the scrape loses indentation, so checkpoint-and-test is
    # placed per-epoch here as the diff's line order suggests — confirm that
    # once-at-the-end saving was not intended.
    th.save(model.state_dict(), args.model_path)
    overlap_list = test()
    overlap_str = ' - '.join(['%.3f' % overlap for overlap in overlap_list])
    print('[epoch %s%d]overlap: %s' % (epoch, i, overlap_str))
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.