-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Model][Sampler] GraphSAGE model, bipartite graph conversion & remove…
… edges API (dmlc#1297) * remove edge and to bipartite and graphsage with sampling * fixes * fixes * fixes * reenable multigpu training * fixes * compatibility in DGLGraph * rename to compact_as_bipartite * bugfix * lint * add offline inference * skip GPU tests * fix * addresses comments * fix * fix * fix * more tests * more docs and unit tests * workaround for empty slice on empty data
- Loading branch information
Showing
26 changed files
with
1,515 additions
and
130 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,303 @@ | ||
import dgl | ||
import numpy as np | ||
import torch as th | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.optim as optim | ||
import torch.multiprocessing as mp | ||
import dgl.function as fn | ||
import dgl.nn.pytorch as dglnn | ||
import time | ||
import argparse | ||
from _thread import start_new_thread | ||
from functools import wraps | ||
from dgl.data import RedditDataset | ||
from torch.nn.parallel import DistributedDataParallel | ||
import tqdm | ||
|
||
#### Neighbor sampler | ||
|
||
class NeighborSampler(object): | ||
def __init__(self, g, fanouts): | ||
self.g = g | ||
self.fanouts = fanouts | ||
|
||
def sample_blocks(self, seeds): | ||
blocks = [] | ||
for fanout in self.fanouts: | ||
# For each seed node, sample ``fanout`` neighbors. | ||
frontier = dgl.sampling.sample_neighbors(g, seeds, fanout) | ||
# Then we compact the frontier into a bipartite graph for message passing. | ||
block = dgl.to_block(frontier, seeds) | ||
# Obtain the seed nodes for next layer. | ||
seeds = block.srcdata[dgl.NID] | ||
|
||
blocks.insert(0, block) | ||
return blocks | ||
|
||
class SAGE(nn.Module): | ||
def __init__(self, | ||
in_feats, | ||
n_hidden, | ||
n_classes, | ||
n_layers, | ||
activation, | ||
dropout): | ||
super().__init__() | ||
self.n_layers = n_layers | ||
self.n_hidden = n_hidden | ||
self.n_classes = n_classes | ||
self.layers = nn.ModuleList() | ||
self.layers.append(dglnn.SAGEConv( | ||
in_feats, n_hidden, 'mean', feat_drop=dropout, activation=activation)) | ||
for i in range(1, n_layers - 1): | ||
self.layers.append(dglnn.SAGEConv( | ||
n_hidden, n_hidden, 'mean', feat_drop=dropout, activation=activation)) | ||
self.layers.append(dglnn.SAGEConv( | ||
n_hidden, n_classes, 'mean', feat_drop=dropout)) | ||
|
||
def forward(self, blocks, x): | ||
h = x | ||
for layer, block in zip(self.layers, blocks): | ||
# We need to first copy the representation of nodes on the RHS from the | ||
# appropriate nodes on the LHS. | ||
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst | ||
# would be (num_nodes_RHS, D) | ||
h_dst = h[:block.number_of_nodes(block.dsttype)] | ||
# Then we compute the updated representation on the RHS. | ||
# The shape of h now becomes (num_nodes_RHS, D) | ||
h = layer(block, (h, h_dst)) | ||
return h | ||
|
||
def inference(self, g, x, batch_size, device): | ||
""" | ||
Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling). | ||
g : the entire graph. | ||
x : the input of entire node set. | ||
The inference code is written in a fashion that it could handle any number of nodes and | ||
layers. | ||
""" | ||
# During inference with sampling, multi-layer blocks are very inefficient because | ||
# lots of computations in the first few layers are repeated. | ||
# Therefore, we compute the representation of all nodes layer by layer. The nodes | ||
# on each layer are of course splitted in batches. | ||
# TODO: can we standardize this? | ||
nodes = th.arange(g.number_of_nodes()) | ||
for l, layer in enumerate(self.layers): | ||
y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes) | ||
|
||
for start in tqdm.trange(0, len(nodes), batch_size): | ||
end = start + batch_size | ||
batch_nodes = nodes[start:end] | ||
block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes) | ||
induced_nodes = block.srcdata[dgl.NID] | ||
|
||
h = x[induced_nodes].to(device) | ||
h_dst = h[:block.number_of_nodes(block.dsttype)] | ||
h = layer(block, (h, h_dst)) | ||
|
||
y[start:end] = h.cpu() | ||
|
||
x = y | ||
return y | ||
|
||
#### Miscellaneous functions | ||
|
||
# According to https://github.com/pytorch/pytorch/issues/17199, this decorator | ||
# is necessary to make fork() and openmp work together. | ||
# | ||
# TODO: confirm if this is necessary for MXNet and Tensorflow. If so, we need | ||
# to standardize worker process creation since our operators are implemented with | ||
# OpenMP. | ||
def thread_wrapped_func(func): | ||
""" | ||
Wraps a process entry point to make it work with OpenMP. | ||
""" | ||
@wraps(func) | ||
def decorated_function(*args, **kwargs): | ||
queue = mp.Queue() | ||
def _queue_result(): | ||
exception, trace, res = None, None, None | ||
try: | ||
res = func(*args, **kwargs) | ||
except Exception as e: | ||
exception = e | ||
trace = traceback.format_exc() | ||
queue.put((res, exception, trace)) | ||
|
||
start_new_thread(_queue_result, ()) | ||
result, exception, trace = queue.get() | ||
if exception is None: | ||
return result | ||
else: | ||
assert isinstance(exception, Exception) | ||
raise exception.__class__(trace) | ||
return decorated_function | ||
|
||
def compute_acc(pred, labels): | ||
""" | ||
Compute the accuracy of prediction given the labels. | ||
""" | ||
return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred) | ||
|
||
def evaluate(model, g, inputs, labels, val_mask, batch_size, device): | ||
""" | ||
Evaluate the model on the validation set specified by ``val_mask``. | ||
g : The entire graph. | ||
inputs : The features of all the nodes. | ||
labels : The labels of all the nodes. | ||
val_mask : A 0-1 mask indicating which nodes do we actually compute the accuracy for. | ||
batch_size : Number of nodes to compute at the same time. | ||
device : The GPU device to evaluate on. | ||
""" | ||
model.eval() | ||
with th.no_grad(): | ||
pred = model.inference(g, inputs, batch_size, device) | ||
model.train() | ||
return compute_acc(pred[val_mask], labels[val_mask]) | ||
|
||
def load_subtensor(g, labels, seeds, induced_nodes, dev_id): | ||
""" | ||
Copys features and labels of a set of nodes onto GPU. | ||
""" | ||
batch_inputs = g.ndata['features'][induced_nodes].to(dev_id) | ||
batch_labels = labels[seeds].to(dev_id) | ||
return batch_inputs, batch_labels | ||
|
||
#### Entry point | ||
|
||
@thread_wrapped_func | ||
def run(proc_id, n_gpus, args, devices, data): | ||
dropout = 0.2 | ||
|
||
# Start up distributed training, if enabled. | ||
dev_id = devices[proc_id] | ||
if n_gpus > 1: | ||
dist_init_method = 'tcp://{master_ip}:{master_port}'.format( | ||
master_ip='127.0.0.1', master_port='12345') | ||
world_size = n_gpus | ||
th.distributed.init_process_group(backend="nccl", | ||
init_method=dist_init_method, | ||
world_size=world_size, | ||
rank=dev_id) | ||
th.cuda.set_device(dev_id) | ||
|
||
# Unpack data | ||
train_mask, val_mask, in_feats, labels, n_classes, g = data | ||
train_nid = th.LongTensor(np.nonzero(train_mask)[0]) | ||
val_nid = th.LongTensor(np.nonzero(val_mask)[0]) | ||
train_mask = th.BoolTensor(train_mask) | ||
val_mask = th.BoolTensor(val_mask) | ||
|
||
# Split train_nid | ||
train_nid = th.split(train_nid, len(train_nid) // n_gpus)[dev_id] | ||
|
||
# Create sampler | ||
sampler = NeighborSampler(g, [args.fan_out] * args.num_layers) | ||
|
||
# Define model and optimizer | ||
model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, dropout) | ||
model = model.to(dev_id) | ||
if n_gpus > 1: | ||
model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id) | ||
loss_fcn = nn.CrossEntropyLoss() | ||
loss_fcn = loss_fcn.to(dev_id) | ||
optimizer = optim.Adam(model.parameters(), lr=args.lr) | ||
|
||
# Training loop | ||
avg = 0 | ||
iter_tput = [] | ||
for epoch in range(args.num_epochs): | ||
tic = time.time() | ||
train_nid_batches = train_nid[th.randperm(len(train_nid))] | ||
n_batches = (len(train_nid_batches) + args.batch_size - 1) // args.batch_size | ||
for step in range(n_batches): | ||
seeds = train_nid_batches[step * args.batch_size:(step+1) * args.batch_size] | ||
if proc_id == 0: | ||
tic_step = time.time() | ||
|
||
# Sample blocks for message propagation | ||
blocks = sampler.sample_blocks(seeds) | ||
induced_nodes = blocks[0].srcdata[dgl.NID] | ||
# Load the input features as well as output labels | ||
batch_inputs, batch_labels = load_subtensor(g, labels, seeds, induced_nodes, dev_id) | ||
|
||
# Compute loss and prediction | ||
batch_pred = model(blocks, batch_inputs) | ||
loss = loss_fcn(batch_pred, batch_labels) | ||
optimizer.zero_grad() | ||
loss.backward() | ||
|
||
if n_gpus > 1: | ||
for param in model.parameters(): | ||
if param.requires_grad and param.grad is not None: | ||
th.distributed.all_reduce(param.grad.data, | ||
op=th.distributed.ReduceOp.SUM) | ||
param.grad.data /= n_gpus | ||
optimizer.step() | ||
|
||
if proc_id == 0: | ||
iter_tput.append(len(seeds) * n_gpus / (time.time() - tic_step)) | ||
if step % args.log_every == 0 and proc_id == 0: | ||
acc = compute_acc(batch_pred, batch_labels) | ||
print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f}'.format( | ||
epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]))) | ||
|
||
if n_gpus > 1: | ||
th.distributed.barrier() | ||
|
||
toc = time.time() | ||
if proc_id == 0: | ||
print('Epoch Time(s): {:.4f}'.format(toc - tic)) | ||
if epoch >= 5: | ||
avg += toc - tic | ||
if epoch % args.eval_every == 0 and epoch != 0: | ||
eval_acc = evaluate(model, g, g.ndata['features'], labels, val_mask, args.batch_size, 0) | ||
print('Eval Acc {:.4f}'.format(eval_acc)) | ||
|
||
if n_gpus > 1: | ||
th.distributed.barrier() | ||
if proc_id == 0: | ||
print('Avg epoch time: {}'.format(avg / (epoch - 4))) | ||
|
||
if __name__ == '__main__': | ||
argparser = argparse.ArgumentParser("multi-gpu training") | ||
argparser.add_argument('--gpu', type=str, default='0') | ||
argparser.add_argument('--num-epochs', type=int, default=20) | ||
argparser.add_argument('--num-hidden', type=int, default=16) | ||
argparser.add_argument('--num-layers', type=int, default=2) | ||
argparser.add_argument('--fan-out', type=int, default=10) | ||
argparser.add_argument('--batch-size', type=int, default=1000) | ||
argparser.add_argument('--log-every', type=int, default=20) | ||
argparser.add_argument('--eval-every', type=int, default=5) | ||
argparser.add_argument('--lr', type=float, default=0.003) | ||
args = argparser.parse_args() | ||
|
||
devices = list(map(int, args.gpu.split(','))) | ||
n_gpus = len(devices) | ||
|
||
# load reddit data | ||
data = RedditDataset(self_loop=True) | ||
train_mask = data.train_mask | ||
val_mask = data.val_mask | ||
features = th.Tensor(data.features) | ||
in_feats = features.shape[1] | ||
labels = th.LongTensor(data.labels) | ||
n_classes = data.num_labels | ||
# Construct graph | ||
g = dgl.graph(data.graph.all_edges()) | ||
g.ndata['features'] = features | ||
# Pack data | ||
data = train_mask, val_mask, in_feats, labels, n_classes, g | ||
|
||
if n_gpus == 1: | ||
run(0, n_gpus, args, devices, data) | ||
else: | ||
procs = [] | ||
for proc_id in range(n_gpus): | ||
p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, data)) | ||
p.start() | ||
procs.append(p) | ||
for p in procs: | ||
p.join() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.