From 20e1bb455b0304d6d39125b9d4f528fe17947c27 Mon Sep 17 00:00:00 2001
From: "Quan (Andy) Gan"
Date: Tue, 10 Mar 2020 09:22:41 +0800
Subject: [PATCH] rewrite to use dataloader (#1333)

Co-authored-by: Minjie Wang
---
 .../pytorch/graphsage/graphsage_sampling.py | 66 ++++++++++++++-----
 1 file changed, 48 insertions(+), 18 deletions(-)

diff --git a/examples/pytorch/graphsage/graphsage_sampling.py b/examples/pytorch/graphsage/graphsage_sampling.py
index 160781f9b41c..39f9aee86baf 100644
--- a/examples/pytorch/graphsage/graphsage_sampling.py
+++ b/examples/pytorch/graphsage/graphsage_sampling.py
@@ -5,6 +5,7 @@
 import torch.nn.functional as F
 import torch.optim as optim
 import torch.multiprocessing as mp
+from torch.utils.data import DataLoader
 import dgl.function as fn
 import dgl.nn.pytorch as dglnn
 import time
@@ -23,10 +24,11 @@ def __init__(self, g, fanouts):
         self.fanouts = fanouts
 
     def sample_blocks(self, seeds):
+        seeds = th.LongTensor(np.asarray(seeds))
         blocks = []
         for fanout in self.fanouts:
             # For each seed node, sample ``fanout`` neighbors.
-            frontier = dgl.sampling.sample_neighbors(g, seeds, fanout)
+            frontier = dgl.sampling.sample_neighbors(g, seeds, fanout, replace=True)
             # Then we compact the frontier into a bipartite graph for message passing.
             block = dgl.to_block(frontier, seeds)
             # Obtain the seed nodes for next layer.
@@ -91,9 +93,9 @@ def inference(self, g, x, batch_size, device):
                 end = start + batch_size
                 batch_nodes = nodes[start:end]
                 block = dgl.to_block(dgl.in_subgraph(g, batch_nodes), batch_nodes)
-                induced_nodes = block.srcdata[dgl.NID]
+                input_nodes = block.srcdata[dgl.NID]
 
-                h = x[induced_nodes].to(device)
+                h = x[input_nodes].to(device)
                 h_dst = h[:block.number_of_nodes(block.dsttype)]
                 h = layer(block, (h, h_dst))
 
@@ -135,6 +137,18 @@ def _queue_result():
             raise exception.__class__(trace)
     return decorated_function
 
+def prepare_mp(g):
+    """
+    Explicitly materialize the CSR, CSC and COO representation of the given graph
+    so that they could be shared via copy-on-write to sampler workers and GPU
+    trainers.
+
+    This is a workaround before full shared memory support on heterogeneous graphs.
+    """
+    g.in_degree(0)
+    g.out_degree(0)
+    g.find_edges([0])
+
 def compute_acc(pred, labels):
     """
     Compute the accuracy of prediction given the labels.
@@ -157,11 +171,11 @@ def evaluate(model, g, inputs, labels, val_mask, batch_size, device):
     model.train()
     return compute_acc(pred[val_mask], labels[val_mask])
 
-def load_subtensor(g, labels, seeds, induced_nodes, dev_id):
+def load_subtensor(g, labels, seeds, input_nodes, dev_id):
     """
     Copys features and labels of a set of nodes onto GPU.
     """
-    batch_inputs = g.ndata['features'][induced_nodes].to(dev_id)
+    batch_inputs = g.ndata['features'][input_nodes].to(dev_id)
     batch_labels = labels[seeds].to(dev_id)
     return batch_inputs, batch_labels
 
@@ -194,7 +208,16 @@ def run(proc_id, n_gpus, args, devices, data):
         train_nid = th.split(train_nid, len(train_nid) // n_gpus)[dev_id]
 
     # Create sampler
-    sampler = NeighborSampler(g, [args.fan_out] * args.num_layers)
+    sampler = NeighborSampler(g, [int(fanout) for fanout in args.fan_out.split(',')])
+
+    # Create PyTorch DataLoader for constructing blocks
+    dataloader = DataLoader(
+        dataset=train_nid.numpy(),
+        batch_size=args.batch_size,
+        collate_fn=sampler.sample_blocks,
+        shuffle=True,
+        drop_last=False,
+        num_workers=args.num_workers_per_gpu)
 
     # Define model and optimizer
     model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, dropout)
@@ -210,18 +233,20 @@ def run(proc_id, n_gpus, args, devices, data):
     iter_tput = []
     for epoch in range(args.num_epochs):
         tic = time.time()
-        train_nid_batches = train_nid[th.randperm(len(train_nid))]
-        n_batches = (len(train_nid_batches) + args.batch_size - 1) // args.batch_size
-        for step in range(n_batches):
-            seeds = train_nid_batches[step * args.batch_size:(step+1) * args.batch_size]
+
+        # Loop over the dataloader to sample the computation dependency graph as a list of
+        # blocks.
+        for step, blocks in enumerate(dataloader):
             if proc_id == 0:
                 tic_step = time.time()
 
-            # Sample blocks for message propagation
-            blocks = sampler.sample_blocks(seeds)
-            induced_nodes = blocks[0].srcdata[dgl.NID]
+            # The nodes for input lies at the LHS side of the first block.
+            # The nodes for output lies at the RHS side of the last block.
+            input_nodes = blocks[0].srcdata[dgl.NID]
+            seeds = blocks[-1].dstdata[dgl.NID]
+
             # Load the input features as well as output labels
-            batch_inputs, batch_labels = load_subtensor(g, labels, seeds, induced_nodes, dev_id)
+            batch_inputs, batch_labels = load_subtensor(g, labels, seeds, input_nodes, dev_id)
 
             # Compute loss and prediction
             batch_pred = model(blocks, batch_inputs)
@@ -241,8 +266,8 @@ def run(proc_id, n_gpus, args, devices, data):
                 iter_tput.append(len(seeds) * n_gpus / (time.time() - tic_step))
             if step % args.log_every == 0 and proc_id == 0:
                 acc = compute_acc(batch_pred, batch_labels)
-                print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f}'.format(
-                    epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:])))
+                print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MiB'.format(
+                    epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), th.cuda.max_memory_allocated() / 1000000))
 
         if n_gpus > 1:
             th.distributed.barrier()
@@ -253,7 +278,10 @@ def run(proc_id, n_gpus, args, devices, data):
         if epoch >= 5:
             avg += toc - tic
         if epoch % args.eval_every == 0 and epoch != 0:
-            eval_acc = evaluate(model, g, g.ndata['features'], labels, val_mask, args.batch_size, 0)
+            if n_gpus == 1:
+                eval_acc = evaluate(model, g, g.ndata['features'], labels, val_mask, args.batch_size, 0)
+            else:
+                eval_acc = evaluate(model.module, g, g.ndata['features'], labels, val_mask, args.batch_size, 0)
             print('Eval Acc {:.4f}'.format(eval_acc))
 
     if n_gpus > 1:
@@ -267,11 +295,12 @@ def run(proc_id, n_gpus, args, devices, data):
     argparser.add_argument('--num-epochs', type=int, default=20)
     argparser.add_argument('--num-hidden', type=int, default=16)
     argparser.add_argument('--num-layers', type=int, default=2)
-    argparser.add_argument('--fan-out', type=int, default=10)
+    argparser.add_argument('--fan-out', type=str, default='10,25')
     argparser.add_argument('--batch-size', type=int, default=1000)
     argparser.add_argument('--log-every', type=int, default=20)
     argparser.add_argument('--eval-every', type=int, default=5)
     argparser.add_argument('--lr', type=float, default=0.003)
+    argparser.add_argument('--num-workers-per-gpu', type=int, default=0)
     args = argparser.parse_args()
 
     devices = list(map(int, args.gpu.split(',')))
@@ -288,6 +317,7 @@ def run(proc_id, n_gpus, args, devices, data):
     # Construct graph
     g = dgl.graph(data.graph.all_edges())
     g.ndata['features'] = features
+    prepare_mp(g)
 
     # Pack data
     data = train_mask, val_mask, in_feats, labels, n_classes, g
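
For reference, the pattern this patch introduces, plugging a DGL neighbor sampler into a standard torch.utils.data.DataLoader through collate_fn, can be condensed into a minimal standalone sketch. This is an illustration rather than part of the patch: it substitutes a small random graph for the Reddit dataset, keeps the graph on the sampler as self.g instead of reading the module-level g, hard-codes the fan-outs, batch size and worker count, and otherwise uses only the calls the patch itself relies on (dgl.sampling.sample_neighbors, dgl.to_block).

    import numpy as np
    import torch as th
    from torch.utils.data import DataLoader
    import dgl

    class NeighborSampler(object):
        def __init__(self, g, fanouts):
            self.g = g              # the sketch stores the graph on the sampler
            self.fanouts = fanouts  # one fan-out per GNN layer

        def sample_blocks(self, seeds):
            # The DataLoader hands the collate_fn a list of node IDs.
            seeds = th.LongTensor(np.asarray(seeds))
            blocks = []
            for fanout in self.fanouts:
                # Sample ``fanout`` in-neighbors (with replacement) for every seed node.
                frontier = dgl.sampling.sample_neighbors(self.g, seeds, fanout, replace=True)
                # Compact the frontier into a bipartite block for message passing.
                block = dgl.to_block(frontier, seeds)
                # The block's source nodes become the seeds of the next (outer) layer.
                seeds = block.srcdata[dgl.NID]
                blocks.insert(0, block)
            return blocks

    # Toy stand-in for the Reddit graph loaded by the example.
    src = th.randint(0, 1000, (20000,))
    dst = th.randint(0, 1000, (20000,))
    g = dgl.graph((src, dst))
    train_nid = th.arange(g.number_of_nodes())

    sampler = NeighborSampler(g, [10, 25])
    dataloader = DataLoader(
        dataset=train_nid.numpy(),
        batch_size=1000,
        collate_fn=sampler.sample_blocks,
        shuffle=True,
        drop_last=False,
        num_workers=0)

    for step, blocks in enumerate(dataloader):
        input_nodes = blocks[0].srcdata[dgl.NID]   # nodes whose features must be gathered
        seeds = blocks[-1].dstdata[dgl.NID]        # nodes whose labels are predicted

Each DataLoader iteration therefore yields the list of blocks for one minibatch; raising num_workers moves block construction into background worker processes, which is what prepare_mp in the patch readies the graph for by materializing its CSR, CSC and COO formats up front.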
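
The consumption side has the same shape as the patched training loop. The sketch below continues the one above and is likewise illustrative: the feature size, hidden size, class count and the random features and labels are invented for the example, the model is fixed at two layers and everything stays on CPU, while the block-wise forward (taking h_dst from the leading rows of h) and the load_subtensor-style gathering mirror what the patch does.

    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    import dgl.nn.pytorch as dglnn

    in_feats, n_hidden, n_classes = 16, 32, 3
    g.ndata['features'] = th.randn(g.number_of_nodes(), in_feats)
    labels = th.randint(0, n_classes, (g.number_of_nodes(),))

    class SAGE(nn.Module):
        def __init__(self, in_feats, n_hidden, n_classes):
            super(SAGE, self).__init__()
            self.layers = nn.ModuleList([
                dglnn.SAGEConv(in_feats, n_hidden, 'mean'),
                dglnn.SAGEConv(n_hidden, n_classes, 'mean')])

        def forward(self, blocks, x):
            h = x
            for l, (layer, block) in enumerate(zip(self.layers, blocks)):
                # dgl.to_block puts the destination nodes first among the source nodes,
                # so their features are the leading rows of ``h``.
                h_dst = h[:block.number_of_nodes(block.dsttype)]
                h = layer(block, (h, h_dst))
                if l != len(self.layers) - 1:
                    h = F.relu(h)
            return h

    model = SAGE(in_feats, n_hidden, n_classes)
    optimizer = optim.Adam(model.parameters(), lr=0.003)

    for blocks in dataloader:
        input_nodes = blocks[0].srcdata[dgl.NID]
        seeds = blocks[-1].dstdata[dgl.NID]
        # Same gathering as load_subtensor, kept on CPU for the sketch.
        batch_inputs = g.ndata['features'][input_nodes]
        batch_labels = labels[seeds]

        batch_pred = model(blocks, batch_inputs)
        loss = F.cross_entropy(batch_pred, batch_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

On a real device the gathered batch_inputs and batch_labels would be moved to the GPU exactly as load_subtensor does with .to(dev_id).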