[Examples] Add pure gpu example of graphsage (dmlc#3796)
* Add pure_gpu example of graphsage

* move to advanced directory

Co-authored-by: Quan Gan <[email protected]>
nv-dlasalle and BarclayII authored Mar 11, 2022
1 parent f908f35 commit 57d2f31
Showing 3 changed files with 139 additions and 0 deletions.
2 changes: 2 additions & 0 deletions examples/pytorch/graphsage/README.md
@@ -5,6 +5,8 @@ Inductive Representation Learning on Large Graphs (GraphSAGE)
- Author's code repo: [https://github.com/williamleif/graphsage-simple](https://github.com/williamleif/graphsage-simple). Note that the original code is
a simple reference implementation of GraphSAGE.

Advanced usage, such as pure GPU sampling and training with PyTorch Lightning, is covered in the `advanced` directory.

Requirements
------------
- requests
6 changes: 6 additions & 0 deletions examples/pytorch/graphsage/advanced/README.md
@@ -1,6 +1,12 @@
More Examples for Training GraphSAGE
====================================

### Pure GPU sampling

```bash
python3 pure_gpu_node_classification.py
```
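
The script trains GraphSAGE on `ogbn-products` with the graph, its features, and the seed nodes all resident on the GPU, so neighbor sampling runs entirely on the device. Below is a minimal sketch of that pattern on a made-up toy graph (it assumes DGL 0.8+ with CUDA support; names like `seeds` are illustrative only):

```python
import torch
import dgl

device = 'cuda'

# Hypothetical toy graph standing in for a real dataset.
g = dgl.rand_graph(1000, 5000)
g.ndata['feat'] = torch.randn(1000, 16)

g = g.to(device)                          # graph structure + features on GPU
seeds = torch.arange(100, device=device)  # seed nodes on GPU as well

sampler = dgl.dataloading.NeighborSampler([15, 10, 5])
dataloader = dgl.dataloading.DataLoader(
    g, seeds, sampler, device=device, batch_size=64,
    shuffle=True, drop_last=False,
    num_workers=0,   # GPU sampling requires in-process loading
    use_uva=False)   # the graph is already on the GPU

for input_nodes, output_nodes, blocks in dataloader:
    pass  # sampled blocks arrive already on the GPU
```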

### Unsupervised training

Train with mini-batch sampling in an unsupervised fashion (on the Reddit dataset)
131 changes: 131 additions & 0 deletions examples/pytorch/graphsage/advanced/pure_gpu_node_classification.py
@@ -0,0 +1,131 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import dgl
import dgl.nn as dglnn
import time
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset
import tqdm
import argparse

class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes):
        super().__init__()
        # A three-layer GraphSAGE model with mean aggregation.
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        self.dropout = nn.Dropout(0.5)
        self.n_hidden = n_hidden
        self.n_classes = n_classes

    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

    def inference(self, g, device, batch_size, num_workers, buffer_device=None):
        # The difference between this inference function and the one in the
        # official example is that the intermediate results can also benefit
        # from prefetching.
        feat = g.ndata['feat']
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(
            1, prefetch_node_feats=['feat'])
        dataloader = dgl.dataloading.DataLoader(
            g, torch.arange(g.num_nodes()).to(g.device), sampler, device=device,
            batch_size=batch_size, shuffle=False, drop_last=False,
            num_workers=num_workers)

        if buffer_device is None:
            buffer_device = device

        # Layer-wise inference: compute the representations of *all* nodes one
        # layer at a time, so each layer only needs 1-hop full-neighbor
        # aggregation over the previous layer's output.
        for l, layer in enumerate(self.layers):
            y = torch.empty(
                g.num_nodes(),
                self.n_hidden if l != len(self.layers) - 1 else self.n_classes,
                device=buffer_device,
                # Pinning only applies to CPU buffers; it speeds up the
                # device-to-host copies below.
                pin_memory=(buffer_device != device))
            feat = feat.to(device)
            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                # Use an explicitly contiguous slice.
                x = feat[input_nodes]
                h = layer(blocks[0], x)
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                    h = self.dropout(h)
                # By design (sequential node IDs with shuffle=False), the
                # output nodes are contiguous, so we can take advantage of
                # that here.
                y[output_nodes[0]:output_nodes[-1] + 1] = h.to(buffer_device)
            feat = y
        return y



dataset = DglNodePropPredDataset('ogbn-products')
graph, labels = dataset[0]
graph.ndata['label'] = labels.squeeze()
split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']

device = 'cuda'
train_idx = train_idx.to(device)
valid_idx = valid_idx.to(device)
test_idx = test_idx.to(device)

# Move the whole graph (structure and features) onto the GPU, so that
# neighbor sampling itself runs on the device: this is what makes the
# example "pure GPU".
graph = graph.to(device)

model = SAGE(graph.ndata['feat'].shape[1], 256, dataset.num_classes).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

sampler = dgl.dataloading.NeighborSampler(
    [15, 10, 5], prefetch_node_feats=['feat'], prefetch_labels=['label'])
train_dataloader = dgl.dataloading.DataLoader(
    graph, train_idx, sampler, device=device, batch_size=1024, shuffle=True,
    drop_last=False, num_workers=0, use_uva=False)
valid_dataloader = dgl.dataloading.DataLoader(
    graph, valid_idx, sampler, device=device, batch_size=1024, shuffle=True,
    drop_last=False, num_workers=0, use_uva=False)
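
# Note (an assumption, not part of this example): if the graph were too large
# for GPU memory, DGL 0.8+ can instead keep it in pinned CPU memory and
# sample with use_uva=True, letting the GPU read host memory directly.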

durations = []
for _ in range(10):
    model.train()
    t0 = time.time()
    for it, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
        x = blocks[0].srcdata['feat']
        y = blocks[-1].dstdata['label']
        y_hat = model(blocks, x)
        loss = F.cross_entropy(y_hat, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        if it % 20 == 0:
            acc = MF.accuracy(torch.argmax(y_hat, dim=1), y)
            mem = torch.cuda.max_memory_allocated() / 1000000
            print('Loss', loss.item(), 'Acc', acc.item(), 'GPU Mem', mem, 'MB')
    tt = time.time()
    print(tt - t0)
    durations.append(tt - t0)

    model.eval()
    ys = []
    y_hats = []
    for it, (input_nodes, output_nodes, blocks) in enumerate(valid_dataloader):
        with torch.no_grad():
            x = blocks[0].srcdata['feat']
            ys.append(blocks[-1].dstdata['label'])
            y_hats.append(torch.argmax(model(blocks, x), dim=1))
    acc = MF.accuracy(torch.cat(y_hats), torch.cat(ys))
    print('Validation acc:', acc.item())

# Mean and standard deviation of the per-epoch training times, skipping the
# first few warm-up epochs.
print(np.mean(durations[4:]), np.std(durations[4:]))

# Test accuracy and offline inference on all nodes
model.eval()
with torch.no_grad():
    # Layer-wise offline inference over the full graph; the output buffer is
    # kept on the CPU so the predictions for all nodes fit in host memory.
    pred = model.inference(graph, device, 4096, 0, 'cpu')
    pred = pred[test_idx].to(device)
    label = graph.ndata['label'][test_idx]
    acc = MF.accuracy(torch.argmax(pred, dim=1), label)
    print('Test acc:', acc.item())
