[Model] Fix broken CDGNN example (dmlc#111)
* pretty printer

* Conflicts:
	python/dgl/data/sbm.py

* refined line_graph implementation

* fix broken api calls

* small fix to trigger CI

* requested change
GaiYu0 authored Nov 4, 2018
1 parent b355d1e commit b420a5b
Showing 6 changed files with 164 additions and 143 deletions.
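
The "fix broken api calls" part of this commit tracks DGL's message-passing builtins, which now take explicit field names instead of positional defaults. A minimal sketch of the new calling convention, assuming the 2018-era DGL API that this diff itself uses (`from_scipy_sparse_matrix`, `set_n_repr`, `update_all`, `get_n_repr`):

```python
import scipy.sparse as sp
import torch as th
from dgl import DGLGraph
import dgl.function as fn

# A directed 3-cycle, built the same way sbm.py builds its graphs.
adj = sp.csr_matrix(([1, 1, 1], ([0, 1, 2], [1, 2, 0])), shape=(3, 3))
g = DGLGraph()
g.from_scipy_sparse_matrix(adj)

# New convention: name every field explicitly, as GNNModule.aggregate does.
g.set_n_repr({'z': th.ones(3, 4)})
g.update_all(fn.copy_src(src='z', out='m'),  # message: copy source 'z' into 'm'
             fn.sum(msg='m', out='z'))       # reduce: sum 'm' into 'z'
z = g.get_n_repr()['z']                      # aggregated features, shape (3, 4)
```

The old positional forms (`fn.copy_src()`, `fn.sum()`, `g.set_n_repr(z)`) are exactly what the deleted lines below still used.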
4 changes: 3 additions & 1 deletion examples/pytorch/line_graph/README.md
@@ -15,5 +15,7 @@ python train.py
 
 An experiment on the Stochastic Block Model in customized settings can be run with
 ```bash
-python train.py --batch-size BATCH_SIZE --gpu GPU --n-communities N_COMMUNITIES --n-features N_FEATURES --n-graphs N_GRAPH --n-iterations N_ITERATIONS --n-layers N_LAYER --n-nodes N_NODE --model-path MODEL_PATH --radius RADIUS
+python train.py --batch-size BATCH_SIZE --gpu GPU --n-communities N_COMMUNITIES \
+--n-features N_FEATURES --n-graphs N_GRAPH --n-iterations N_ITERATIONS \
+--n-layers N_LAYER --n-nodes N_NODE --model-path MODEL_PATH --radius RADIUS
 ```
70 changes: 27 additions & 43 deletions examples/pytorch/line_graph/gnn.py
@@ -1,12 +1,3 @@
-"""
-Supervised Community Detection with Hierarchical Graph Neural Networks
-https://arxiv.org/abs/1705.08415
-Deviations from paper:
-- Pm Pd
-"""
-
-
 import copy
 import itertools
 import dgl
@@ -16,59 +7,58 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-
 class GNNModule(nn.Module):
     def __init__(self, in_feats, out_feats, radius):
         super().__init__()
         self.out_feats = out_feats
         self.radius = radius
 
-        new_linear = lambda: nn.Linear(in_feats, out_feats * 2)
-        new_module_list = lambda: nn.ModuleList([new_linear() for i in range(radius)])
+        new_linear = lambda: nn.Linear(in_feats, out_feats)
+        new_linear_list = lambda: nn.ModuleList([new_linear() for i in range(radius)])
 
         self.theta_x, self.theta_deg, self.theta_y = \
             new_linear(), new_linear(), new_linear()
-        self.theta_list = new_module_list()
+        self.theta_list = new_linear_list()
 
         self.gamma_y, self.gamma_deg, self.gamma_x = \
             new_linear(), new_linear(), new_linear()
-        self.gamma_list = new_module_list()
+        self.gamma_list = new_linear_list()
 
         self.bn_x = nn.BatchNorm1d(out_feats)
         self.bn_y = nn.BatchNorm1d(out_feats)
 
     def aggregate(self, g, z):
         z_list = []
-        g.set_n_repr(z)
-        g.update_all(fn.copy_src(), fn.sum())
-        z_list.append(g.get_n_repr())
+        g.set_n_repr({'z' : z})
+        g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
+        z_list.append(g.get_n_repr()['z'])
         for i in range(self.radius - 1):
             for j in range(2 ** i):
-                g.update_all(fn.copy_src(), fn.sum())
-                z_list.append(g.get_n_repr())
+                g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
+                z_list.append(g.get_n_repr()['z'])
         return z_list
 
-    def forward(self, g, lg, x, y, deg_g, deg_lg, eid2nid):
-        xy = F.embedding(eid2nid, x)
+    def forward(self, g, lg, x, y, deg_g, deg_lg, pm_pd):
+        pmpd_x = F.embedding(pm_pd, x)
 
-        x_list = [theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x))]
+        sum_x = sum(theta(z) for theta, z in zip(self.theta_list, self.aggregate(g, x)))
 
-        g.set_e_repr(y)
-        g.update_all(fn.copy_edge(), fn.sum())
-        yx = g.get_n_repr()
+        g.set_e_repr({'y' : y})
+        g.update_all(fn.copy_edge(edge='y', out='m'), fn.sum('m', 'pmpd_y'))
+        pmpd_y = g.pop_n_repr('pmpd_y')
 
-        x = self.theta_x(x) + self.theta_deg(deg_g * x) + sum(x_list) + self.theta_y(yx)
-        x = self.bn_x(x[:, :self.out_feats] + F.relu(x[:, self.out_feats:]))
+        x = self.theta_x(x) + self.theta_deg(deg_g * x) + sum_x + self.theta_y(pmpd_y)
+        n = self.out_feats // 2
+        x = th.cat([x[:, :n], F.relu(x[:, n:])], 1)
+        x = self.bn_x(x)
 
-        y_list = [gamma(z) for gamma, z in zip(self.gamma_list, self.aggregate(lg, y))]
-        lg.set_n_repr(xy)
-        lg.update_all(fn.copy_src(), fn.sum())
-        xy = lg.get_n_repr()
-        y = self.gamma_y(y) + self.gamma_deg(deg_lg * y) + sum(y_list) + self.gamma_x(xy)
-        y = self.bn_y(y[:, :self.out_feats] + F.relu(y[:, self.out_feats:]))
+        sum_y = sum(gamma(z) for gamma, z in zip(self.gamma_list, self.aggregate(lg, y)))
+        y = self.gamma_y(y) + self.gamma_deg(deg_lg * y) + sum_y + self.gamma_x(pmpd_x)
+        y = th.cat([y[:, :n], F.relu(y[:, n:])], 1)
+        y = self.bn_y(y)
 
         return x, y
 
 class GNN(nn.Module):
     def __init__(self, feats, radius, n_classes):
@@ -82,14 +72,8 @@ def __init__(self, feats, radius, n_classes):
         self.module_list = nn.ModuleList([GNNModule(m, n, radius)
                                           for m, n in zip(feats[:-1], feats[1:])])
 
-    def forward(self, g, lg, deg_g, deg_lg, eid2nid):
-        def normalize(x):
-            x = x - th.mean(x, 0)
-            x = x / th.sqrt(th.mean(x * x, 0))
-            return x
-
-        x = normalize(deg_g)
-        y = normalize(deg_lg)
+    def forward(self, g, lg, deg_g, deg_lg, pm_pd):
+        x, y = deg_g, deg_lg
         for module in self.module_list:
-            x, y = module(g, lg, x, y, deg_g, deg_lg, eid2nid)
+            x, y = module(g, lg, x, y, deg_g, deg_lg, pm_pd)
         return self.linear(x)
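
For orientation: `pm_pd` plays the role of the Pm/Pd incidence operators from the paper that the deleted gnn.py docstring alluded to. It is an edge-indexed vector of node ids, so `F.embedding` simply gathers node feature rows onto line-graph nodes (one per edge). A self-contained toy illustration (the graph and sizes here are made up):

```python
import torch as th
import torch.nn.functional as F

x = th.randn(3, 8)               # features for 3 nodes
pm_pd = th.tensor([0, 1, 1, 2])  # hypothetical edge -> node index map
pmpd_x = F.embedding(pm_pd, x)   # (4, 8): one node-feature row per edge
assert th.equal(pmpd_x[2], x[1])
```

Note also that `aggregate` doubles the hop count at each scale (one propagation, then 1, 2, 4, ... more), so `theta_list` consumes the multiscale stack [Ax, A²x, A⁴x, ...] with `radius` terms.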
150 changes: 103 additions & 47 deletions examples/pytorch/line_graph/train.py
@@ -1,68 +1,124 @@
"""
Supervised Community Detection with Hierarchical Graph Neural Networks
https://arxiv.org/abs/1705.08415
Author's implementation: https://github.com/joanbruna/GNN_community
"""

from __future__ import division
import time

import argparse
from itertools import permutations

import networkx as nx
import torch as th
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import dgl
from dgl.data import SBMMixture
import gnn
import utils

parser = argparse.ArgumentParser()
parser.add_argument('--batch-size', type=int,
help='Batch size', default=1)
parser.add_argument('--gpu', type=int,
help='GPU', default=-1)
parser.add_argument('--n-communities', type=int,
help='Number of communities', default=2)
parser.add_argument('--n-features', type=int,
help='Number of features per layer', default=2)
parser.add_argument('--n-graphs', type=int,
help='Number of graphs', default=6000)
parser.add_argument('--n-iterations', type=int,
help='Number of iterations', default=10000)
parser.add_argument('--n-layers', type=int,
help='Number of layers', default=30)
parser.add_argument('--n-nodes', type=int,
help='Number of nodes', default=1000)
parser.add_argument('--model-path', type=str,
help='Path to the checkpoint of model', default='model')
parser.add_argument('--radius', type=int,
help='Radius', default=3)
parser.add_argument('--batch-size', type=int, help='Batch size', default=1)
parser.add_argument('--gpu', type=int, help='GPU index', default=-1)
parser.add_argument('--lr', type=float, help='Learning rate', default=0.001)
parser.add_argument('--n-communities', type=int, help='Number of communities', default=2)
parser.add_argument('--n-epochs', type=int, help='Number of epochs', default=100)
parser.add_argument('--n-features', type=int, help='Number of features', default=16)
parser.add_argument('--n-graphs', type=int, help='Number of graphs', default=10)
parser.add_argument('--n-layers', type=int, help='Number of layers', default=30)
parser.add_argument('--n-nodes', type=int, help='Number of nodes', default=10000)
parser.add_argument('--optim', type=str, help='Optimizer', default='Adam')
parser.add_argument('--radius', type=int, help='Radius', default=3)
parser.add_argument('--verbose', action='store_true')
args = parser.parse_args()

dev = th.device('cpu') if args.gpu < 0 else th.device('cuda:%d' % args.gpu)
K = args.n_communities

training_dataset = SBMMixture(args.n_graphs, args.n_nodes, K)
training_loader = DataLoader(training_dataset, args.batch_size,
collate_fn=training_dataset.collate_fn, drop_last=True)

ones = th.ones(args.n_nodes // K)
y_list = [th.cat([x * ones for x in p]).long().to(dev) for p in permutations(range(K))]

feats = [1] + [args.n_features] * args.n_layers + [K]
model = gnn.GNN(feats, args.radius, K).to(dev)
optimizer = getattr(optim, args.optim)(model.parameters(), lr=args.lr)

def compute_overlap(z_list):
ybar_list = [th.max(z, 1)[1] for z in z_list]
overlap_list = []
for y_bar in ybar_list:
accuracy = max(th.sum(y_bar == y).item() for y in y_list) / args.n_nodes
overlap = (accuracy - 1 / K) / (1 - 1 / K)
overlap_list.append(overlap)
return sum(overlap_list) / len(overlap_list)

def step(i, j, g, lg, deg_g, deg_lg, pm_pd):
""" One step of training. """
t0 = time.time()
z = model(g, lg, deg_g, deg_lg, pm_pd)
t_forward = time.time() - t0

dataset = SBMMixture(args.n_graphs, args.n_nodes, args.n_communities)
loader = utils.cycle(DataLoader(dataset, args.batch_size,
shuffle=True, collate_fn=dataset.collate_fn, drop_last=True))

ones = th.ones(args.n_nodes // args.n_communities)
y_list = [th.cat([th.cat([x * ones for x in p])] * args.batch_size).long().to(dev)
for p in permutations(range(args.n_communities))]

feats = [1] + [args.n_features] * args.n_layers + [args.n_communities]
model = gnn.GNN(feats, args.radius, args.n_communities).to(dev)
opt = optim.Adamax(model.parameters(), lr=0.04)

for i in range(args.n_iterations):
g, lg, deg_g, deg_lg, eid2nid = next(loader)
deg_g = deg_g.to(dev)
deg_lg = deg_lg.to(dev)
eid2nid = eid2nid.to(dev)
y_bar = model(g, lg, deg_g, deg_lg, eid2nid)
loss = min(F.cross_entropy(y_bar, y) for y in y_list)
opt.zero_grad()
z_list = th.chunk(z, args.batch_size, 0)
loss = sum(min(F.cross_entropy(z, y) for y in y_list) for z in z_list) / args.batch_size
overlap = compute_overlap(z_list)

optimizer.zero_grad()
t0 = time.time()
loss.backward()
opt.step()
t_backward = time.time() - t0
optimizer.step()

return loss, overlap, t_forward, t_backward

def test():
p_list =[6, 5.5, 5, 4.5, 1.5, 1, 0.5, 0]
q_list =[0, 0.5, 1, 1.5, 4.5, 5, 5.5, 6]
N = 1
overlap_list = []
for p, q in zip(p_list, q_list):
dataset = SBMMixture(N, args.n_nodes, K, pq=[[p, q]] * N)
loader = DataLoader(dataset, N, collate_fn=dataset.collate_fn)
g, lg, deg_g, deg_lg, pm_pd = next(iter(loader))
deg_g = deg_g.to(dev)
deg_lg = deg_lg.to(dev)
pm_pd = pm_pd.to(dev)
z = model(g, lg, deg_g, deg_lg, pm_pd)
overlap_list.append(compute_overlap(th.chunk(z, N, 0)))
return overlap_list

n_iterations = args.n_graphs // args.batch_size
for i in range(args.n_epochs):
total_loss, total_overlap, s_forward, s_backward = 0, 0, 0, 0
for j, [g, lg, deg_g, deg_lg, pm_pd] in enumerate(training_loader):
deg_g = deg_g.to(dev)
deg_lg = deg_lg.to(dev)
pm_pd = pm_pd.to(dev)
loss, overlap, t_forward, t_backward = step(i, j, g, lg, deg_g, deg_lg, pm_pd)

total_loss += loss
total_overlap += overlap
s_forward += t_forward
s_backward += t_backward

epoch = '0' * (len(str(args.n_epochs)) - len(str(i)))
iteration = '0' * (len(str(n_iterations)) - len(str(j)))
if args.verbose:
print('[epoch %s%d iteration %s%d]loss %.3f | overlap %.3f'
% (epoch, i, iteration, j, loss, overlap))

placeholder = '0' * (len(str(args.n_iterations)) - len(str(i)))
print('[iteration %s%d]loss %f' % (placeholder, i, loss))
epoch = '0' * (len(str(args.n_epochs)) - len(str(i)))
loss = total_loss / (j + 1)
overlap = total_overlap / (j + 1)
t_forward = s_forward / (j + 1)
t_backward = s_backward / (j + 1)
print('[epoch %s%d]loss %.3f | overlap %.3f | forward time %.3fs | backward time %.3fs'
% (epoch, i, loss, overlap, t_forward, t_backward))

th.save(model.state_dict(), args.model_path)
overlap_list = test()
overlap_str = ' - '.join(['%.3f' % overlap for overlap in overlap_list])
print('[epoch %s%d]overlap: %s' % (epoch, i, overlap_str))
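
Two details of the rewritten loop are worth spelling out. Community labels are only identifiable up to relabeling, so the loss takes the minimum cross-entropy over all K! permutations of the target, and `compute_overlap` rescales accuracy so that chance level maps to 0 and perfect recovery (up to permutation) to 1. A self-contained toy check with made-up sizes:

```python
from itertools import permutations

import torch as th
import torch.nn.functional as F

K, n = 2, 6
z = th.randn(n, K)                      # logits for n nodes
y = [0, 0, 0, 1, 1, 1]                  # ground-truth communities
y_list = [th.tensor([p[c] for c in y]) for p in permutations(range(K))]

# Permutation-invariant loss: score the best relabeling of the target.
loss = min(F.cross_entropy(z, y_perm) for y_perm in y_list)

# Overlap: accuracy rescaled so random guessing (1/K) scores 0.
acc = max((z.argmax(1) == y_perm).float().mean().item() for y_perm in y_list)
overlap = (acc - 1 / K) / (1 - 1 / K)
print(loss.item(), overlap)
```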
4 changes: 0 additions & 4 deletions examples/pytorch/line_graph/utils.py

This file was deleted.
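
The deleted helper's body is not shown in this view; from its old call sites (`loader = utils.cycle(DataLoader(...))` and `next(loader)`), it presumably wrapped a finite DataLoader into an endless iterator, roughly:

```python
# Hypothetical reconstruction of the deleted utils.cycle -- inferred
# from its call sites, not from the file contents.
def cycle(iterable):
    while True:
        for x in iterable:
            yield x
```

The epoch-based rewrite of train.py iterates the DataLoader directly, which makes the helper unnecessary.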

37 changes: 23 additions & 14 deletions python/dgl/data/sbm.py
@@ -1,6 +1,7 @@
 import math
 import os
 import pickle
+import random
 
 import numpy as np
 import numpy.random as npr
@@ -68,48 +69,56 @@ class SBMMixture:
         Multiplier.
     avg_deg : int, optional
         Average degree.
-    p : callable or str, optional
-        Random density generator.
+    pq : list of pair of nonnegative float or str, optional
+        Random densities.
     rng : numpy.random.RandomState, optional
         Random number generator.
     """
     def __init__(self, n_graphs, n_nodes, n_communities,
-                 k=2, avg_deg=3, p='Appendix C', rng=None):
+                 k=2, avg_deg=3, pq='Appendix C', rng=None):
         self._n_nodes = n_nodes
         assert n_nodes % n_communities == 0
         block_size = n_nodes // n_communities
-        if type(p) is str:
-            p = {'Appendix C' : self._appendix_c}[p]
         self._k = k
         self._avg_deg = avg_deg
         self._gs = [DGLGraph() for i in range(n_graphs)]
-        adjs = [sbm(n_communities, block_size, *p()) for i in range(n_graphs)]
+        if type(pq) is list:
+            assert len(pq) == n_graphs
+        elif type(pq) is str:
+            generator = {'Appendix C' : self._appendix_c}[pq]
+            pq = [generator() for i in range(n_graphs)]
+        else:
+            raise RuntimeError()
+        adjs = [sbm(n_communities, block_size, *x) for x in pq]
         for g, adj in zip(self._gs, adjs):
             g.from_scipy_sparse_matrix(adj)
-        self._lgs = [g.line_graph() for g in self._gs]
+        self._lgs = [g.line_graph(backtracking=False) for g in self._gs]
         in_degrees = lambda g: g.in_degrees(Index(F.arange(g.number_of_nodes(),
-                                           dtype=F.int64))).unsqueeze(1).float()
+                                                           dtype=F.int64))).unsqueeze(1).float()
         self._g_degs = [in_degrees(g) for g in self._gs]
         self._lg_degs = [in_degrees(lg) for lg in self._lgs]
-        self._eid2nids = list(zip(*[g.edges(sorted=True) for g in self._gs]))[0]
+        self._pm_pds = list(zip(*[g.edges() for g in self._gs]))[0]
 
     def __len__(self):
        return len(self._gs)
 
     def __getitem__(self, idx):
         return self._gs[idx], self._lgs[idx], \
-               self._g_degs[idx], self._lg_degs[idx], self._eid2nids[idx]
+               self._g_degs[idx], self._lg_degs[idx], self._pm_pds[idx]
 
     def _appendix_c(self):
         q = npr.uniform(0, self._avg_deg - math.sqrt(self._avg_deg))
         p = self._k * self._avg_deg - q
-        return p, q
+        if random.random() < 0.5:
+            return p, q
+        else:
+            return q, p
 
     def collate_fn(self, x):
-        g, lg, deg_g, deg_lg, eid2nid = zip(*x)
+        g, lg, deg_g, deg_lg, pm_pd = zip(*x)
         g_batch = batch(g)
         lg_batch = batch(lg)
         degg_batch = F.pack(deg_g)
         deglg_batch = F.pack(deg_lg)
-        eid2nid_batch = F.pack([x + i * self._n_nodes for i, x in enumerate(eid2nid)])
-        return g_batch, lg_batch, degg_batch, deglg_batch, eid2nid_batch
+        pm_pd_batch = F.pack([x + i * self._n_nodes for i, x in enumerate(pm_pd)])
+        return g_batch, lg_batch, degg_batch, deglg_batch, pm_pd_batch
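
For context, each `pq` pair drawn by `_appendix_c` parameterizes one stochastic block model, and the coin flip added above randomizes whether the assortative (p > q) or disassortative ordering is sampled, matching the mixed p/q grid in `test()`. A hypothetical stand-in for the `sbm()` helper called in `__init__` (the 1/n edge-probability scaling is an assumed sparse-SBM convention, not something this diff shows):

```python
import numpy as np
import scipy.sparse as sp

def sample_sbm(n_communities, block_size, p, q, rng=np.random):
    """Hypothetical sbm() stand-in: treat p and q as average degrees,
    giving edge probability p / n inside a block and q / n across
    blocks (an assumed convention)."""
    n = n_communities * block_size
    prob = np.full((n, n), q / n)
    for c in range(n_communities):
        s = slice(c * block_size, (c + 1) * block_size)
        prob[s, s] = p / n
    upper = np.triu(rng.uniform(size=(n, n)) < prob, 1)
    adj = upper | upper.T                # undirected, no self-loops
    return sp.csr_matrix(adj.astype(np.float32))
```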