Skip to content

Commit

Permalink
[Example] Move Data to GPU before Minibatch Training (dmlc#2453)
Browse files Browse the repository at this point in the history
* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update
  • Loading branch information
mufeili authored Dec 28, 2020
1 parent 72ef642 commit 927d2b3
Show file tree
Hide file tree
Showing 8 changed files with 203 additions and 219 deletions.
74 changes: 40 additions & 34 deletions examples/pytorch/graphsage/train_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,12 @@
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
import argparse
from _thread import start_new_thread
from functools import wraps
from dgl.data import RedditDataset
import tqdm
import traceback

from load_graph import load_reddit, load_ogb, inductive_split
from load_graph import load_reddit, inductive_split

class SAGE(nn.Module):
def __init__(self,
Expand Down Expand Up @@ -47,7 +40,7 @@ def forward(self, blocks, x):
h = self.dropout(h)
return h

def inference(self, g, x, batch_size, device):
def inference(self, g, x, device):
"""
Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
g : the entire graph.
Expand All @@ -62,12 +55,12 @@ def inference(self, g, x, batch_size, device):
# on each layer are of course splitted in batches.
# TODO: can we standardize this?
for l, layer in enumerate(self.layers):
y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)

sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
g,
th.arange(g.number_of_nodes()),
th.arange(g.num_nodes()),
sampler,
batch_size=args.batch_size,
shuffle=True,
Expand Down Expand Up @@ -96,34 +89,35 @@ def compute_acc(pred, labels):
labels = labels.long()
return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)

def evaluate(model, g, inputs, labels, val_nid, batch_size, device):
def evaluate(model, g, nfeat, labels, val_nid, device):
"""
Evaluate the model on the validation set specified by ``val_nid``.
g : The entire graph.
inputs : The features of all the nodes.
labels : The labels of all the nodes.
val_nid : the node Ids for validation.
batch_size : Number of nodes to compute at the same time.
device : The GPU device to evaluate on.
"""
model.eval()
with th.no_grad():
pred = model.inference(g, inputs, batch_size, device)
pred = model.inference(g, nfeat, device)
model.train()
return compute_acc(pred[val_nid], labels[val_nid])
return compute_acc(pred[val_nid], labels[val_nid].to(pred.device))

def load_subtensor(g, seeds, input_nodes, device):
def load_subtensor(nfeat, labels, seeds, input_nodes, device):
"""
Copys features and labels of a set of nodes onto GPU.
Extracts features and labels for a subset of nodes
"""
batch_inputs = g.ndata['features'][input_nodes].to(device)
batch_labels = g.ndata['labels'][seeds].to(device)
batch_inputs = nfeat[input_nodes].to(device)
batch_labels = labels[seeds].to(device)
return batch_inputs, batch_labels

#### Entry point
def run(args, device, data):
# Unpack data
in_feats, n_classes, train_g, val_g, test_g = data
n_classes, train_g, val_g, test_g, train_nfeat, train_labels, \
val_nfeat, val_labels, test_nfeat, test_labels = data
in_feats = train_nfeat.shape[1]
train_nid = th.nonzero(train_g.ndata['train_mask'], as_tuple=True)[0]
val_nid = th.nonzero(val_g.ndata['val_mask'], as_tuple=True)[0]
test_nid = th.nonzero(~(test_g.ndata['train_mask'] | test_g.ndata['val_mask']), as_tuple=True)[0]
Expand All @@ -144,7 +138,6 @@ def run(args, device, data):
model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
model = model.to(device)
loss_fcn = nn.CrossEntropyLoss()
loss_fcn = loss_fcn.to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

# Training loop
Expand All @@ -158,10 +151,9 @@ def run(args, device, data):
tic_step = time.time()
for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
# Load the input features as well as output labels
#batch_inputs, batch_labels = load_subtensor(train_g, seeds, input_nodes, device)
batch_inputs, batch_labels = load_subtensor(train_nfeat, train_labels,
seeds, input_nodes, device)
blocks = [block.int().to(device) for block in blocks]
batch_inputs = blocks[0].srcdata['features']
batch_labels = blocks[-1].dstdata['labels']

# Compute loss and prediction
batch_pred = model(blocks, batch_inputs)
Expand All @@ -183,17 +175,17 @@ def run(args, device, data):
if epoch >= 5:
avg += toc - tic
if epoch % args.eval_every == 0 and epoch != 0:
eval_acc = evaluate(model, val_g, val_g.ndata['features'], val_g.ndata['labels'], val_nid, args.batch_size, device)
eval_acc = evaluate(model, val_g, val_nfeat, val_labels, val_nid, device)
print('Eval Acc {:.4f}'.format(eval_acc))
test_acc = evaluate(model, test_g, test_g.ndata['features'], test_g.ndata['labels'], test_nid, args.batch_size, device)
test_acc = evaluate(model, test_g, test_nfeat, test_labels, test_nid, device)
print('Test Acc: {:.4f}'.format(test_acc))

print('Avg epoch time: {}'.format(avg / (epoch - 4)))

if __name__ == '__main__':
argparser = argparse.ArgumentParser("multi-gpu training")
argparser.add_argument('--gpu', type=int, default=0,
help="GPU device ID. Use -1 for CPU training")
help="GPU device ID. Use -1 for CPU training")
argparser.add_argument('--dataset', type=str, default='reddit')
argparser.add_argument('--num-epochs', type=int, default=20)
argparser.add_argument('--num-hidden', type=int, default=16)
Expand All @@ -205,9 +197,14 @@ def run(args, device, data):
argparser.add_argument('--lr', type=float, default=0.003)
argparser.add_argument('--dropout', type=float, default=0.5)
argparser.add_argument('--num-workers', type=int, default=4,
help="Number of sampling processes. Use 0 for no extra process.")
help="Number of sampling processes. Use 0 for no extra process.")
argparser.add_argument('--inductive', action='store_true',
help="Inductive learning setting")
help="Inductive learning setting")
argparser.add_argument('--data-cpu', action='store_true',
help="By default the script puts all node features and labels "
"on GPU when using it to save time for data copy. This may "
"be undesired if they cannot fit in GPU memory at once. "
"This flag disables that.")
args = argparser.parse_args()

if args.gpu >= 0:
Expand All @@ -217,24 +214,33 @@ def run(args, device, data):

if args.dataset == 'reddit':
g, n_classes = load_reddit()
elif args.dataset == 'ogb-product':
g, n_classes = load_ogb('ogbn-products')
else:
raise Exception('unknown dataset')

in_feats = g.ndata['features'].shape[1]

if args.inductive:
train_g, val_g, test_g = inductive_split(g)
train_nfeat = train_g.ndata.pop('features')
val_nfeat = val_g.ndata.pop('features')
test_nfeat = test_g.ndata.pop('features')
train_labels = train_g.ndata.pop('labels')
val_labels = val_g.ndata.pop('labels')
test_labels = test_g.ndata.pop('labels')
else:
train_g = val_g = test_g = g
train_nfeat = val_nfeat = test_nfeat = g.ndata.pop('features')
train_labels = val_labels = test_labels = g.ndata.pop('labels')

if not args.data_cpu:
train_nfeat = train_nfeat.to(device)
train_labels = train_labels.to(device)

# Create csr/coo/csc formats before launching training processes with multi-gpu.
# This avoids creating certain formats in each sub-process, which saves momory and CPU.
train_g.create_formats_()
val_g.create_formats_()
test_g.create_formats_()
# Pack data
data = in_feats, n_classes, train_g, val_g, test_g
data = n_classes, train_g, val_g, test_g, train_nfeat, train_labels, \
val_nfeat, val_labels, test_nfeat, test_labels

run(args, device, data)
68 changes: 42 additions & 26 deletions examples/pytorch/graphsage/train_sampling_multi_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,12 @@
import torch.nn.functional as F
import torch.optim as optim
import torch.multiprocessing as mp
from torch.utils.data import DataLoader
import dgl.function as fn
import dgl.nn.pytorch as dglnn
import time
import math
import argparse
from dgl.data import RedditDataset
from torch.nn.parallel import DistributedDataParallel
import tqdm
import traceback

from utils import thread_wrapped_func
from load_graph import load_reddit, inductive_split
Expand Down Expand Up @@ -48,7 +44,7 @@ def forward(self, blocks, x):
h = self.dropout(h)
return h

def inference(self, g, x, batch_size, device):
def inference(self, g, x, device):
"""
Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
g : the entire graph.
Expand All @@ -62,14 +58,13 @@ def inference(self, g, x, batch_size, device):
# Therefore, we compute the representation of all nodes layer by layer. The nodes
# on each layer are of course splitted in batches.
# TODO: can we standardize this?
nodes = th.arange(g.number_of_nodes())
for l, layer in enumerate(self.layers):
y = th.zeros(g.number_of_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)
y = th.zeros(g.num_nodes(), self.n_hidden if l != len(self.layers) - 1 else self.n_classes)

sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
dataloader = dgl.dataloading.NodeDataLoader(
g,
th.arange(g.number_of_nodes()),
th.arange(g.num_nodes()),
sampler,
batch_size=args.batch_size,
shuffle=True,
Expand Down Expand Up @@ -97,27 +92,26 @@ def compute_acc(pred, labels):
"""
return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)

def evaluate(model, g, inputs, labels, val_nid, batch_size, device):
def evaluate(model, g, nfeat, labels, val_nid, device):
"""
Evaluate the model on the validation set specified by ``val_nid``.
g : The entire graph.
inputs : The features of all the nodes.
labels : The labels of all the nodes.
val_nid : A node ID tensor indicating which nodes do we actually compute the accuracy for.
batch_size : Number of nodes to compute at the same time.
device : The GPU device to evaluate on.
"""
model.eval()
with th.no_grad():
pred = model.inference(g, inputs, batch_size, device)
pred = model.inference(g, nfeat, device)
model.train()
return compute_acc(pred[val_nid], labels[val_nid])

def load_subtensor(g, labels, seeds, input_nodes, dev_id):
def load_subtensor(nfeat, labels, seeds, input_nodes, dev_id):
"""
Copys features and labels of a set of nodes onto GPU.
Extracts features and labels for a subset of nodes.
"""
batch_inputs = g.ndata['features'][input_nodes].to(dev_id)
batch_inputs = nfeat[input_nodes].to(dev_id)
batch_labels = labels[seeds].to(dev_id)
return batch_inputs, batch_labels

Expand All @@ -137,7 +131,25 @@ def run(proc_id, n_gpus, args, devices, data):
th.cuda.set_device(dev_id)

# Unpack data
in_feats, n_classes, train_g, val_g, test_g = data
n_classes, train_g, val_g, test_g = data

if args.inductive:
train_nfeat = train_g.ndata.pop('features')
val_nfeat = val_g.ndata.pop('features')
test_nfeat = test_g.ndata.pop('features')
train_labels = train_g.ndata.pop('labels')
val_labels = val_g.ndata.pop('labels')
test_labels = test_g.ndata.pop('labels')
else:
train_nfeat = val_nfeat = test_nfeat = g.ndata.pop('features')
train_labels = val_labels = test_labels = g.ndata.pop('labels')

if not args.data_cpu:
train_nfeat = train_nfeat.to(dev_id)
train_labels = train_labels.to(dev_id)

in_feats = train_nfeat.shape[1]

train_mask = train_g.ndata['train_mask']
val_mask = val_g.ndata['val_mask']
test_mask = ~(test_g.ndata['train_mask'] | test_g.ndata['val_mask'])
Expand Down Expand Up @@ -166,7 +178,6 @@ def run(proc_id, n_gpus, args, devices, data):
if n_gpus > 1:
model = DistributedDataParallel(model, device_ids=[dev_id], output_device=dev_id)
loss_fcn = nn.CrossEntropyLoss()
loss_fcn = loss_fcn.to(dev_id)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

# Training loop
Expand All @@ -182,7 +193,8 @@ def run(proc_id, n_gpus, args, devices, data):
tic_step = time.time()

# Load the input features as well as output labels
batch_inputs, batch_labels = load_subtensor(train_g, train_g.ndata['labels'], seeds, input_nodes, dev_id)
batch_inputs, batch_labels = load_subtensor(train_nfeat, train_labels,
seeds, input_nodes, dev_id)
blocks = [block.int().to(dev_id) for block in blocks]
# Compute loss and prediction
batch_pred = model(blocks, batch_inputs)
Expand All @@ -209,14 +221,14 @@ def run(proc_id, n_gpus, args, devices, data):
if epoch % args.eval_every == 0 and epoch != 0:
if n_gpus == 1:
eval_acc = evaluate(
model, val_g, val_g.ndata['features'], val_g.ndata['labels'], val_nid, args.batch_size, devices[0])
model, val_g, val_nfeat, val_labels, val_nid, devices[0])
test_acc = evaluate(
model, test_g, test_g.ndata['features'], test_g.ndata['labels'], test_nid, args.batch_size, devices[0])
model, test_g, test_nfeat, test_labels, test_nid, devices[0])
else:
eval_acc = evaluate(
model.module, val_g, val_g.ndata['features'], val_g.ndata['labels'], val_nid, args.batch_size, devices[0])
model.module, val_g, val_nfeat, val_labels, val_nid, devices[0])
test_acc = evaluate(
model.module, test_g, test_g.ndata['features'], test_g.ndata['labels'], test_nid, args.batch_size, devices[0])
model.module, test_g, test_nfeat, test_labels, test_nid, devices[0])
print('Eval Acc {:.4f}'.format(eval_acc))
print('Test Acc: {:.4f}'.format(test_acc))

Expand All @@ -229,7 +241,7 @@ def run(proc_id, n_gpus, args, devices, data):
if __name__ == '__main__':
argparser = argparse.ArgumentParser("multi-gpu training")
argparser.add_argument('--gpu', type=str, default='0',
help="Comma separated list of GPU device IDs.")
help="Comma separated list of GPU device IDs.")
argparser.add_argument('--num-epochs', type=int, default=20)
argparser.add_argument('--num-hidden', type=int, default=16)
argparser.add_argument('--num-layers', type=int, default=2)
Expand All @@ -240,9 +252,14 @@ def run(proc_id, n_gpus, args, devices, data):
argparser.add_argument('--lr', type=float, default=0.003)
argparser.add_argument('--dropout', type=float, default=0.5)
argparser.add_argument('--num-workers', type=int, default=0,
help="Number of sampling processes. Use 0 for no extra process.")
help="Number of sampling processes. Use 0 for no extra process.")
argparser.add_argument('--inductive', action='store_true',
help="Inductive learning setting")
help="Inductive learning setting")
argparser.add_argument('--data-cpu', action='store_true',
help="By default the script puts all node features and labels "
"on GPU when using it to save time for data copy. This may "
"be undesired if they cannot fit in GPU memory at once. "
"This flag disables that.")
args = argparser.parse_args()

devices = list(map(int, args.gpu.split(',')))
Expand All @@ -251,7 +268,6 @@ def run(proc_id, n_gpus, args, devices, data):
g, n_classes = load_reddit()
# Construct graph
g = dgl.as_heterograph(g)
in_feats = g.ndata['features'].shape[1]

if args.inductive:
train_g, val_g, test_g = inductive_split(g)
Expand All @@ -264,7 +280,7 @@ def run(proc_id, n_gpus, args, devices, data):
val_g.create_formats_()
test_g.create_formats_()
# Pack data
data = in_feats, n_classes, train_g, val_g, test_g
data = n_classes, train_g, val_g, test_g

if n_gpus == 1:
run(0, n_gpus, args, devices, data)
Expand Down
Loading

0 comments on commit 927d2b3

Please sign in to comment.