Skip to content

Commit

Permalink
[Example][Optimization] Performance optimization for graphsage unsupe…
Browse files Browse the repository at this point in the history
…rvised example (dmlc#1531)

* test

* profile

* opt

* Some fix

* upd

* upd

* Add multigpu training support for graphsage unsupervised

* Add share neg

* Fix

* Add profile

* turn on eval

* upd

* Fix

* performance opt

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
7 people authored May 27, 2020
1 parent 7639b5e commit 901e0c2
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 97 deletions.
37 changes: 2 additions & 35 deletions examples/pytorch/graphsage/train_sampling_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
import dgl.nn.pytorch as dglnn
import time
import argparse
from _thread import start_new_thread
from functools import wraps
from dgl.data import RedditDataset
from torch.nn.parallel import DistributedDataParallel
import tqdm
import traceback

from utils import thread_wrapped_func

#### Neighbor sampler

class NeighborSampler(object):
Expand Down Expand Up @@ -110,39 +110,6 @@ def inference(self, g, x, batch_size, device):
x = y
return y

#### Miscellaneous functions

# According to https://github.com/pytorch/pytorch/issues/17199, this decorator
# is necessary to make fork() and openmp work together.
#
# TODO: confirm if this is necessary for MXNet and Tensorflow. If so, we need
# to standardize worker process creation since our operators are implemented with
# OpenMP.
def thread_wrapped_func(func):
"""
Wraps a process entry point to make it work with OpenMP.
"""
@wraps(func)
def decorated_function(*args, **kwargs):
queue = mp.Queue()
def _queue_result():
exception, trace, res = None, None, None
try:
res = func(*args, **kwargs)
except Exception as e:
exception = e
trace = traceback.format_exc()
queue.put((res, exception, trace))

start_new_thread(_queue_result, ())
result, exception, trace = queue.get()
if exception is None:
return result
else:
assert isinstance(exception, Exception)
raise exception.__class__(trace)
return decorated_function

def prepare_mp(g):
"""
Explicitly materialize the CSR, CSC and COO representation of the given graph
Expand Down
179 changes: 117 additions & 62 deletions examples/pytorch/graphsage/train_sampling_unsupervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
from _thread import start_new_thread
from functools import wraps
from dgl.data import RedditDataset
from torch.nn.parallel import DistributedDataParallel
import tqdm
import traceback
import sklearn.linear_model as lm
import sklearn.metrics as skm

from utils import thread_wrapped_func

#### Negative sampler

class NegativeSampler(object):
Expand All @@ -30,18 +33,26 @@ def __call__(self, num_samples):
#### Neighbor sampler

class NeighborSampler(object):
def __init__(self, g, fanouts, num_negs):
def __init__(self, g, fanouts, num_negs, neg_share=False):
self.g = g
self.fanouts = fanouts
self.neg_sampler = NegativeSampler(g)
self.num_negs = num_negs
self.neg_share = neg_share

def sample_blocks(self, seed_edges):
n_edges = len(seed_edges)
seed_edges = th.LongTensor(np.asarray(seed_edges))
heads, tails = self.g.find_edges(seed_edges)
neg_tails = self.neg_sampler(self.num_negs * n_edges)
neg_heads = heads.view(-1, 1).expand(n_edges, self.num_negs).flatten()
if self.neg_share and n_edges % self.num_negs == 0:
neg_tails = self.neg_sampler(n_edges)
neg_tails = neg_tails.view(-1, 1, self.num_negs).expand(n_edges//self.num_negs,
self.num_negs,
self.num_negs).flatten()
neg_heads = heads.view(-1, 1).expand(n_edges, self.num_negs).flatten()
else:
neg_tails = self.neg_sampler(self.num_negs * n_edges)
neg_heads = heads.view(-1, 1).expand(n_edges, self.num_negs).flatten()

# Maintain the correspondence between heads, tails and negative tails as two
# graphs.
Expand All @@ -60,7 +71,7 @@ def sample_blocks(self, seed_edges):
blocks = []
for fanout in self.fanouts:
# For each seed node, sample ``fanout`` neighbors.
frontier = dgl.sampling.sample_neighbors(g, seeds, fanout, replace=True)
frontier = dgl.sampling.sample_neighbors(self.g, seeds, fanout, replace=True)
# Remove all edges between heads and tails, as well as heads and neg_tails.
_, _, edge_ids = frontier.edge_ids(
th.cat([heads, tails, neg_heads, neg_tails]),
Expand All @@ -69,12 +80,24 @@ def sample_blocks(self, seed_edges):
frontier = dgl.remove_edges(frontier, edge_ids)
# Then we compact the frontier into a bipartite graph for message passing.
block = dgl.to_block(frontier, seeds)

# Pre-generate CSR format that it can be used in training directly
block.in_degree(0)
# Obtain the seed nodes for next layer.
seeds = block.srcdata[dgl.NID]

blocks.insert(0, block)

# Pre-generate CSR format that it can be used in training directly
return pos_graph, neg_graph, blocks

def load_subtensor(g, input_nodes, device):
"""
Copys features and labels of a set of nodes onto GPU.
"""
batch_inputs = g.ndata['features'][input_nodes].to(device)
return batch_inputs

class SAGE(nn.Module):
def __init__(self,
in_feats,
Expand Down Expand Up @@ -163,18 +186,6 @@ def forward(self, block_outputs, pos_graph, neg_graph):
loss = F.binary_cross_entropy_with_logits(score, label.float())
return loss

def prepare_mp(g):
"""
Explicitly materialize the CSR, CSC and COO representation of the given graph
so that they could be shared via copy-on-write to sampler workers and GPU
trainers.
This is a workaround before full shared memory support on heterogeneous graphs.
"""
g.in_degree(0)
g.out_degree(0)
g.find_edges([0])

def compute_acc(emb, labels, train_nids, val_nids, test_nids):
"""
Compute the accuracy of prediction given the labels.
Expand Down Expand Up @@ -211,65 +222,84 @@ def evaluate(model, g, inputs, labels, train_nids, val_nids, test_nids, batch_si
"""
model.eval()
with th.no_grad():
pred = model.inference(g, inputs, batch_size, device)
# single gpu
if isinstance(model, SAGE):
pred = model.inference(g, inputs, batch_size, device)
# multi gpu
else:
pred = model.module.inference(g, inputs, batch_size, device)
model.train()
return compute_acc(pred, labels, train_nids, val_nids, test_nids)

def load_subtensor(g, seeds, input_nodes, device):
"""
Copys features and labels of a set of nodes onto GPU.
"""
batch_inputs = g.ndata['features'][input_nodes].to(device)
return batch_inputs

#### Entry point
def run(args, device, data):
def run(proc_id, n_gpus, args, devices, data):
# Unpack data
device = devices[proc_id]
if n_gpus > 1:
dist_init_method = 'tcp://{master_ip}:{master_port}'.format(
master_ip='127.0.0.1', master_port='12345')
world_size = n_gpus
th.distributed.init_process_group(backend="nccl",
init_method=dist_init_method,
world_size=world_size,
rank=proc_id)
train_mask, val_mask, test_mask, in_feats, labels, n_classes, g = data

train_nid = th.LongTensor(np.nonzero(train_mask)[0])
val_nid = th.LongTensor(np.nonzero(val_mask)[0])
test_nid = th.LongTensor(np.nonzero(test_mask)[0])

# Create sampler
sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')], args.num_negs)
sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')], args.num_negs, args.neg_share)

# Create PyTorch DataLoader for constructing blocks
train_seeds = np.arange(g.number_of_edges())
if n_gpus > 0:
num_per_gpu = (train_seeds.shape[0] + n_gpus -1) // n_gpus
train_seeds = train_seeds[proc_id * num_per_gpu :
(proc_id + 1) * num_per_gpu \
if (proc_id + 1) * num_per_gpu < train_seeds.shape[0]
else train_seeds.shape[0]]

dataloader = DataLoader(
dataset=np.arange(g.number_of_edges()),
dataset=train_seeds,
batch_size=args.batch_size,
collate_fn=sampler.sample_blocks,
shuffle=True,
drop_last=False,
pin_memory=True,
num_workers=args.num_workers)

# Define model and optimizer
model = SAGE(in_feats, args.num_hidden, args.num_hidden, args.num_layers, F.relu, args.dropout)
model = model.to(device)
if n_gpus > 1:
model = DistributedDataParallel(model, device_ids=[device], output_device=device)
loss_fcn = CrossEntropyLoss()
loss_fcn = loss_fcn.to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

# Training loop
avg = 0
iter_tput = []
iter_pos = []
iter_neg = []
iter_d = []
iter_t = []
best_eval_acc = 0
best_test_acc = 0
for epoch in range(args.num_epochs):
tic = time.time()

# Loop over the dataloader to sample the computation dependency graph as a list of
# blocks.
for step, (pos_graph, neg_graph, blocks) in enumerate(dataloader):
tic_step = time.time()

tic_step = time.time()
for step, (pos_graph, neg_graph, blocks) in enumerate(dataloader):
# The nodes for input lies at the LHS side of the first block.
# The nodes for output lies at the RHS side of the last block.
input_nodes = blocks[0].srcdata[dgl.NID]
seeds = blocks[-1].dstdata[dgl.NID]

# Load the input features as well as output labels
batch_inputs = load_subtensor(g, seeds, input_nodes, device)
batch_inputs = load_subtensor(g, input_nodes, device)
d_step = time.time()

# Compute loss and prediction
batch_pred = model(blocks, batch_inputs)
Expand All @@ -278,30 +308,74 @@ def run(args, device, data):
loss.backward()
optimizer.step()

iter_tput.append(len(seeds) / (time.time() - tic_step))
t = time.time()
pos_edges = pos_graph.number_of_edges()
neg_edges = neg_graph.number_of_edges()
iter_pos.append(pos_edges / (t - tic_step))
iter_neg.append(neg_edges / (t - tic_step))
iter_d.append(d_step - tic_step)
iter_t.append(t - d_step)
if step % args.log_every == 0:
gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MiB'.format(
epoch, step, loss.item(), np.mean(iter_tput[3:]), gpu_mem_alloc))
print('[{}]Epoch {:05d} | Step {:05d} | Loss {:.4f} | Speed (samples/sec) {:.4f}|{:.4f} | Load {:.4f}| train {:.4f} | GPU {:.1f} MiB'.format(
proc_id, epoch, step, loss.item(), np.mean(iter_pos[3:]), np.mean(iter_neg[3:]), np.mean(iter_d[3:]), np.mean(iter_t[3:]), gpu_mem_alloc))
tic_step = time.time()

if step % args.eval_every == 0:
if step % args.eval_every == 0 and proc_id == 0:
eval_acc, test_acc = evaluate(model, g, g.ndata['features'], labels, train_nid, val_nid, test_nid, args.batch_size, device)
print('Eval Acc {:.4f} Test Acc {:.4f}'.format(eval_acc, test_acc))
if eval_acc > best_eval_acc:
best_eval_acc = eval_acc
best_test_acc = test_acc
print('Best Eval Acc {:.4f} Test Acc {:.4f}'.format(best_eval_acc, best_test_acc))

if n_gpus > 1:
th.distributed.barrier()
print('Avg epoch time: {}'.format(avg / (epoch - 4)))

def main(args, devices):
# load reddit data
data = RedditDataset(self_loop=True)
train_mask = data.train_mask
val_mask = data.val_mask
test_mask = data.test_mask
features = th.Tensor(data.features)
in_feats = features.shape[1]
labels = th.LongTensor(data.labels)
n_classes = data.num_labels
# Construct graph
g = dgl.graph(data.graph.all_edges())
g.ndata['features'] = features
# Pack data
data = train_mask, val_mask, test_mask, in_feats, labels, n_classes, g

n_gpus = len(devices)
if devices[0] == -1:
run(0, 0, args, ['cpu'], data)
if n_gpus == 1:
run(0, n_gpus, args, devices, data)
else:
procs = []
for proc_id in range(n_gpus):
p = mp.Process(target=thread_wrapped_func(run),
args=(proc_id, n_gpus, args, devices, data))
p.start()
procs.append(p)
for p in procs:
p.join()

run(args, device, data)


if __name__ == '__main__':
argparser = argparse.ArgumentParser("multi-gpu training")
argparser.add_argument('--gpu', type=int, default=0,
help="GPU device ID. Use -1 for CPU training")
argparser.add_argument("--gpu", type=str, default='0',
help="GPU, can be a list of gpus for multi-gpu trianing, e.g., 0,1,2,3; -1 for CPU")
argparser.add_argument('--num-epochs', type=int, default=20)
argparser.add_argument('--num-hidden', type=int, default=16)
argparser.add_argument('--num-layers', type=int, default=2)
argparser.add_argument('--num-negs', type=int, default=1)
argparser.add_argument('--neg-share', default=False, action='store_true',
help="sharing neg nodes for positive nodes")
argparser.add_argument('--fan-out', type=str, default='10,25')
argparser.add_argument('--batch-size', type=int, default=10000)
argparser.add_argument('--log-every', type=int, default=20)
Expand All @@ -312,25 +386,6 @@ def run(args, device, data):
help="Number of sampling processes. Use 0 for no extra process.")
args = argparser.parse_args()

if args.gpu >= 0:
device = th.device('cuda:%d' % args.gpu)
else:
device = th.device('cpu')
devices = list(map(int, args.gpu.split(',')))

# load reddit data
data = RedditDataset(self_loop=True)
train_mask = data.train_mask
val_mask = data.val_mask
test_mask = data.test_mask
features = th.Tensor(data.features)
in_feats = features.shape[1]
labels = th.LongTensor(data.labels)
n_classes = data.num_labels
# Construct graph
g = dgl.graph(data.graph.all_edges())
g.ndata['features'] = features
prepare_mp(g)
# Pack data
data = train_mask, val_mask, test_mask, in_feats, labels, n_classes, g

run(args, device, data)
main(args, devices)
Loading

0 comments on commit 901e0c2

Please sign in to comment.