Skip to content

Commit

Permalink
[Model][Sampler] GraphSAGE model, bipartite graph conversion & remove…
Browse files Browse the repository at this point in the history
… edges API (dmlc#1297)

* remove edge and to bipartite and graphsage with sampling

* fixes

* fixes

* fixes

* reenable multigpu training

* fixes

* compatibility in DGLGraph

* rename to compact_as_bipartite

* bugfix

* lint

* add offline inference

* skip GPU tests

* fix

* addresses comments

* fix

* fix

* fix

* more tests

* more docs and unit tests

* workaround for empty slice on empty data
  • Loading branch information
BarclayII authored Mar 7, 2020
1 parent ce6e19f commit a9520f7
Show file tree
Hide file tree
Showing 26 changed files with 1,515 additions and 130 deletions.
303 changes: 303 additions & 0 deletions examples/pytorch/graphsage/graphsage_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,303 @@
import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
import argparse
from _thread import start_new_thread
from functools import wraps
from dgl.data import RedditDataset
from torch.nn.parallel import DistributedDataParallel
import tqdm

#### Neighbor sampler

class NeighborSampler(object):
def __init__(self, g, fanouts):
self.g = g
self.fanouts = fanouts

def sample_blocks(self, seeds):
blocks = []
for fanout in self.fanouts:
# For each seed node, sample ``fanout`` neighbors.
frontier = dgl.sampling.sample_neighbors(g, seeds, fanout)
# Then we compact the frontier into a bipartite graph for message passing.
block = dgl.to_block(frontier, seeds)
# Obtain the seed nodes for next layer.
seeds = block.srcdata[dgl.NID]

blocks.insert(0, block)
return blocks

class SAGE(nn.Module):
def __init__(self,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
dropout):
super().__init__()
self.n_layers = n_layers
self.n_hidden = n_hidden
self.n_classes = n_classes
self.layers = nn.ModuleList()
self.layers.append(dglnn.SAGEConv(
in_feats, n_hidden, 'mean', feat_drop=dropout, activation=activation))
for i in range(1, n_layers - 1):
self.layers.append(dglnn.SAGEConv(
n_hidden, n_hidden, 'mean', feat_drop=dropout, activation=activation))
self.layers.append(dglnn.SAGEConv(
n_hidden, n_classes, 'mean', feat_drop=dropout))

def forward(self, blocks, x):
h = x
for layer, block in zip(self.layers, blocks):
# We need to first copy the representation of nodes on the RHS from the
# appropriate nodes on the LHS.
# Note that the shape of h is (num_nodes_LHS, D) and the shape of h_dst
# would be (num_nodes_RHS, D)
h_dst = h[:block.number_of_nodes(block.dsttype)]
# Then we compute the updated representation on the RHS.
# The shape of h now becomes (num_nodes_RHS, D)
h = layer(block, (h, h_dst))
return h

def inference(self, g, x, batch_size, device):
"""
Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
g : the entire graph.
x : the input of entire node set.
The inference code is written in a fashion that it could handle any number of nodes and
layers.
"""
# During inference with sampling, multi-layer blocks are very inefficient because
# lots of computations in the first few layers are repeated.
# Therefore, we compute the representation of all nodes layer by layer. The nodes
# on each layer are of course splitted in batches.
# TODO: can we standardize this?
nodes = th.arange(g.number_of_nodes())
for l, layer in enumerate(self.layers):
y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)

for start in tqdm.trange(0, len(nodes), batch_size):
end = start + batch_size
batch_nodes = nodes[start:end]
block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)
induced_nodes = block.srcdata[dgl.NID]

h = x[induced_nodes].to(device)
h_dst = h[:block.number_of_nodes(block.dsttype)]
h = layer(block, (h, h_dst))

y[start:end] = h.cpu()

x = y
return y

#### Miscellaneous functions

# According to https://github.com/pytorch/pytorch/issues/17199, this decorator
# is necessary to make fork() and openmp work together.
#
# TODO: confirm if this is necessary for MXNet and Tensorflow. If so, we need
# to standardize worker process creation since our operators are implemented with
# OpenMP.
def thread_wrapped_func(func):
"""
Wraps a process entry point to make it work with OpenMP.
"""
@wraps(func)
def decorated_function(*args, **kwargs):
queue = mp.Queue()
def _queue_result():
exception, trace, res = None, None, None
try:
res = func(*args, **kwargs)
except Exception as e:
exception = e
trace = traceback.format_exc()
queue.put((res, exception, trace))

start_new_thread(_queue_result, ())
result, exception, trace = queue.get()
if exception is None:
return result
else:
assert isinstance(exception, Exception)
raise exception.__class__(trace)
return decorated_function

def compute_acc(pred, labels):
"""
Compute the accuracy of prediction given the labels.
"""
return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)

def evaluate(model, g, inputs, labels, val_mask, batch_size, device):
"""
Evaluate the model on the validation set specified by ``val_mask``.
g : The entire graph.
inputs : The features of all the nodes.
labels : The labels of all the nodes.
val_mask : A 0-1 mask indicating which nodes do we actually compute the accuracy for.
batch_size : Number of nodes to compute at the same time.
device : The GPU device to evaluate on.
"""
model.eval()
with th.no_grad():
pred = model.inference(g, inputs, batch_size, device)
model.train()
return compute_acc(pred[val_mask], labels[val_mask])

def load_subtensor(g, labels, seeds, induced_nodes, dev_id):
"""
Copys features and labels of a set of nodes onto GPU.
"""
batch_inputs = g.ndata['features'][induced_nodes].to(dev_id)
batch_labels = labels[seeds].to(dev_id)
return batch_inputs, batch_labels

#### Entry point

@thread_wrapped_func
def run(proc_id, n_gpus, args, devices, data):
dropout = 0.2

# Start up distributed training, if enabled.
dev_id = devices[proc_id]
if n_gpus > 1:
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip='127.0.0.1', master_port='12345')
world_size = n_gpus
th.distributed.init_process_group(backend="nccl",
init_method=dist_init_method,
world_size=world_size,
rank=dev_id)
th.cuda.set_device(dev_id)

# Unpack data
train_mask, val_mask, in_feats, labels, n_classes, g = data
train_nid = th.LongTensor(np.nonzero(train_mask)[0])
val_nid = th.LongTensor(np.nonzero(val_mask)[0])
train_mask = th.BoolTensor(train_mask)
val_mask = th.BoolTensor(val_mask)

# Split train_nid
train_nid = th.split(train_nid, len(train_nid) // n_gpus)[dev_id]

# Create sampler
sampler = NeighborSampler(g, [args.fan_out] * args.num_layers)

# Define model and optimizer
model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, dropout)
model = model.to(dev_id)
if n_gpus > 1:
model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
loss_fcn = nn.CrossEntropyLoss()
loss_fcn = loss_fcn.to(dev_id)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

# Training loop
avg = 0
iter_tput = []
for epoch in range(args.num_epochs):
tic = time.time()
train_nid_batches = train_nid[th.randperm(len(train_nid))]
n_batches = (len(train_nid_batches) + args.batch_size - 1) // args.batch_size
for step in range(n_batches):
seeds = train_nid_batches[step * args.batch_size:(step+1) * args.batch_size]
if proc_id == 0:
tic_step = time.time()

# Sample blocks for message propagation
blocks = sampler.sample_blocks(seeds)
induced_nodes = blocks[0].srcdata[dgl.NID]
# Load the input features as well as output labels
batch_inputs, batch_labels = load_subtensor(g, labels, seeds, induced_nodes, dev_id)

# Compute loss and prediction
batch_pred = model(blocks, batch_inputs)
loss = loss_fcn(batch_pred, batch_labels)
optimizer.zero_grad()
loss.backward()

if n_gpus > 1:
for param in model.parameters():
if param.requires_grad and param.grad is not None:
th.distributed.all_reduce(param.grad.data,
op=th.distributed.ReduceOp.SUM)
param.grad.data /= n_gpus
optimizer.step()

if proc_id == 0:
iter_tput.append(len(seeds) * n_gpus / (time.time() - tic_step))
if step % args.log_every == 0 and proc_id == 0:
acc = compute_acc(batch_pred, batch_labels)
print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f}'.format(
epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:])))

if n_gpus > 1:
th.distributed.barrier()

toc = time.time()
if proc_id == 0:
print('Epoch Time(s): {:.4f}'.format(toc - tic))
if epoch >= 5:
avg += toc - tic
if epoch % args.eval_every == 0 and epoch != 0:
eval_acc = evaluate(model, g, g.ndata['features'], labels, val_mask, args.batch_size, 0)
print('Eval Acc {:.4f}'.format(eval_acc))

if n_gpus > 1:
th.distributed.barrier()
if proc_id == 0:
print('Avg epoch time: {}'.format(avg / (epoch - 4)))

if __name__ == '__main__':
argparser = argparse.ArgumentParser("multi-gpu training")
argparser.add_argument('--gpu', type=str, default='0')
argparser.add_argument('--num-epochs', type=int, default=20)
argparser.add_argument('--num-hidden', type=int, default=16)
argparser.add_argument('--num-layers', type=int, default=2)
argparser.add_argument('--fan-out', type=int, default=10)
argparser.add_argument('--batch-size', type=int, default=1000)
argparser.add_argument('--log-every', type=int, default=20)
argparser.add_argument('--eval-every', type=int, default=5)
argparser.add_argument('--lr', type=float, default=0.003)
args = argparser.parse_args()

devices = list(map(int, args.gpu.split(',')))
n_gpus = len(devices)

# load reddit data
data = RedditDataset(self_loop=True)
train_mask = data.train_mask
val_mask = data.val_mask
features = th.Tensor(data.features)
in_feats = features.shape[1]
labels = th.LongTensor(data.labels)
n_classes = data.num_labels
# Construct graph
g = dgl.graph(data.graph.all_edges())
g.ndata['features'] = features
# Pack data
data = train_mask, val_mask, in_feats, labels, n_classes, g

if n_gpus == 1:
run(0, n_gpus, args, devices, data)
else:
procs = []
for proc_id in range(n_gpus):
p = mp.Process(target=run, args=(proc_id, n_gpus, args, devices, data))
p.start()
procs.append(p)
for p in procs:
p.join()
34 changes: 24 additions & 10 deletions include/dgl/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,22 @@ typedef NDArray TypeArray;
* \brief Sparse format.
*/
enum class SparseFormat {
ANY = 0,
COO = 1,
CSR = 2,
CSC = 3
kAny = 0,
kCOO = 1,
kCSR = 2,
kCSC = 3
};

// Parse sparse format from string.
inline SparseFormat ParseSparseFormat(const std::string& name) {
if (name == "coo")
return SparseFormat::COO;
return SparseFormat::kCOO;
else if (name == "csr")
return SparseFormat::CSR;
return SparseFormat::kCSR;
else if (name == "csc")
return SparseFormat::CSC;
return SparseFormat::kCSC;
else
return SparseFormat::ANY;
return SparseFormat::kAny;
}

// Sparse matrix object that is exposed to python API.
Expand Down Expand Up @@ -328,7 +328,7 @@ struct CSRMatrix {

// Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const {
return SparseMatrix(static_cast<int32_t>(SparseFormat::CSR), num_rows,
return SparseMatrix(static_cast<int32_t>(SparseFormat::kCSR), num_rows,
num_cols, {indptr, indices, data}, {sorted});
}

Expand Down Expand Up @@ -408,7 +408,7 @@ struct COOMatrix {

// Convert to a SparseMatrix object that can return to python.
SparseMatrix ToSparseMatrix() const {
return SparseMatrix(static_cast<int32_t>(SparseFormat::COO), num_rows,
return SparseMatrix(static_cast<int32_t>(SparseFormat::kCOO), num_rows,
num_cols, {row, col, data}, {row_sorted, col_sorted});
}

Expand Down Expand Up @@ -548,6 +548,13 @@ bool CSRHasDuplicate(CSRMatrix csr);
*/
void CSRSort_(CSRMatrix* csr);

/*!
* \brief Remove entries from CSR matrix by entry indices (data indices)
* \return A new CSR matrix as well as a mapping from the new CSR entries to the old CSR
* entries.
*/
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);

/*!
* \brief Randomly select a fixed number of non-zero entries along each given row independently.
*
Expand Down Expand Up @@ -723,6 +730,13 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
*/
COOMatrix COOSort(COOMatrix mat, bool sort_column = false);

/*!
* \brief Remove entries from COO matrix by entry indices (data indices)
* \return A new COO matrix as well as a mapping from the new COO entries to the old COO
* entries.
*/
COOMatrix COORemove(COOMatrix coo, IdArray entries);

/*!
* \brief Randomly select a fixed number of non-zero entries along each given row independently.
*
Expand Down
Loading

0 comments on commit a9520f7

Please sign in to comment.