Commit ac6be7d (parent 0b386a1)

[Benchmark] update node classification for multigpu benchmark (dmlc#5674)
Co-authored-by: Hongzhi (Steve), Chen <[email protected]>

Showing 2 changed files with 319 additions and 0 deletions.
@@ -0,0 +1,23 @@
# Multiple GPU Training

## Requirements

```bash
pip install torchmetrics==0.11.4
```

## How to run

### Node classification

Run with the following command (available datasets: "ogbn-products", "ogbn-arxiv"):

```bash
python3 node_classification_sage.py --dataset_name ogbn-products
```
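
To train on multiple GPUs, pass a comma-separated list of device ids via `--gpu`. `--mode mixed` (the default) keeps the graph on the CPU and samples with UVA, while `--mode puregpu` moves the whole graph onto the GPUs. A minimal sketch, assuming four visible GPUs:

```bash
python3 node_classification_sage.py --dataset_name ogbn-products --gpu 0,1,2,3
```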

#### __Results__ with default arguments
```
* Test Accuracy of "ogbn-products": ~0.7716
* Test Accuracy of "ogbn-arxiv": ~0.6994
```
@@ -0,0 +1,296 @@
import argparse
import os
import time

import dgl
import dgl.nn as dglnn

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics.functional as MF
import tqdm
from dgl.data import AsNodePredDataset
from dgl.dataloading import (
    DataLoader,
    MultiLayerFullNeighborSampler,
    NeighborSampler,
)
from dgl.multiprocessing import shared_tensor
from ogb.nodeproppred import DglNodePropPredDataset
from torch.nn.parallel import DistributedDataParallel


class SAGE(nn.Module):
    def __init__(self, in_size, hid_size, out_size):
        super().__init__()
        self.layers = nn.ModuleList()
        # three-layer GraphSAGE-mean
        self.layers.append(dglnn.SAGEConv(in_size, hid_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hid_size, hid_size, "mean"))
        self.layers.append(dglnn.SAGEConv(hid_size, out_size, "mean"))
        self.dropout = nn.Dropout(0.5)
        self.hid_size = hid_size
        self.out_size = out_size

    def forward(self, blocks, x):
        h = x
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            h = layer(block, h)
            if l != len(self.layers) - 1:
                h = F.relu(h)
                h = self.dropout(h)
        return h

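    # Offline inference: compute representations for all nodes one layer at a
    # time, using full-neighbor sampling within each layer.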
    def inference(self, g, device, batch_size, use_uva):
        g.ndata["h"] = g.ndata["feat"]
        sampler = MultiLayerFullNeighborSampler(1, prefetch_node_feats=["h"])
        for l, layer in enumerate(self.layers):
            dataloader = DataLoader(
                g,
                torch.arange(g.num_nodes(), device=device),
                sampler,
                device=device,
                batch_size=batch_size,
                shuffle=False,
                drop_last=False,
                num_workers=0,
                use_ddp=True,
                use_uva=use_uva,
            )
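            # `use_ddp=True` splits the node IDs across the DDP ranks, so each
            # GPU computes embeddings for a disjoint slice of the graph.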
            # in order to prevent running out of GPU memory, allocate a
            # shared output tensor 'y' in host memory
            y = shared_tensor(
                (
                    g.num_nodes(),
                    self.hid_size
                    if l != len(self.layers) - 1
                    else self.out_size,
                )
            )
            for input_nodes, output_nodes, blocks in (
                tqdm.tqdm(dataloader) if dist.get_rank() == 0 else dataloader
            ):
                x = blocks[0].srcdata["h"]
                h = layer(blocks[0], x)  # len(blocks) = 1
                if l != len(self.layers) - 1:
                    h = F.relu(h)
                    h = self.dropout(h)
                # non_blocking (with pinned memory) to accelerate data transfer
                y[output_nodes] = h.to(y.device, non_blocking=True)
            # make sure all GPUs are done writing to 'y'
            dist.barrier()
            g.ndata["h"] = y if use_uva else y.to(device)

        g.ndata.pop("h")
        return y


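# Mini-batch evaluation: accumulate predictions and labels over the sampled
# blocks and compute multiclass accuracy with torchmetrics.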
def evaluate(model, g, num_classes, dataloader):
    model.eval()
    ys = []
    y_hats = []
    for it, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
        with torch.no_grad():
            x = blocks[0].srcdata["feat"]
            ys.append(blocks[-1].dstdata["label"])
            y_hats.append(model(blocks, x))
    return MF.accuracy(
        torch.cat(y_hats),
        torch.cat(ys),
        task="multiclass",
        num_classes=num_classes,
    )


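# Offline, layer-wise inference over the test nodes; only rank 0 reports the
# final test accuracy.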
def layerwise_infer(
    proc_id, device, g, num_classes, nid, model, use_uva, batch_size=2**10
):
    model.eval()
    with torch.no_grad():
        pred = model.module.inference(g, device, batch_size, use_uva)
        pred = pred[nid]
        labels = g.ndata["label"][nid].to(pred.device)
        if proc_id == 0:
            acc = MF.accuracy(
                pred, labels, task="multiclass", num_classes=num_classes
            )
            print("Test accuracy {:.4f}".format(acc.item()))


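# Per-epoch training loop: each rank trains on its own shard of `train_idx`
# (`use_ddp=True`) while DistributedDataParallel keeps gradients in sync.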
def train(
    proc_id,
    nprocs,
    device,
    g,
    num_classes,
    train_idx,
    val_idx,
    model,
    use_uva,
    num_epochs,
):
    sampler = NeighborSampler(
        [10, 10, 10], prefetch_node_feats=["feat"], prefetch_labels=["label"]
    )
    train_dataloader = DataLoader(
        g,
        train_idx,
        sampler,
        device=device,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=0,
        use_ddp=True,
        use_uva=use_uva,
    )
    val_dataloader = DataLoader(
        g,
        val_idx,
        sampler,
        device=device,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=0,
        use_ddp=True,
        use_uva=use_uva,
    )
    opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)
    for epoch in range(num_epochs):
        t0 = time.time()
        model.train()
        total_loss = 0
        for it, (_, _, blocks) in enumerate(train_dataloader):
            x = blocks[0].srcdata["feat"]
            y = blocks[-1].dstdata["label"]
            y_hat = model(blocks, x)
            loss = F.cross_entropy(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss
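        # Each rank evaluates on its own shard of the validation set; the
        # per-rank accuracies are averaged on rank 0 via dist.reduce below.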
        acc = (
            evaluate(model, g, num_classes, val_dataloader).to(device) / nprocs
        )
        t1 = time.time()
        dist.reduce(acc, 0)
        if proc_id == 0:
            print(
                "Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f} | "
                "Time {:.4f}".format(
                    epoch, total_loss / (it + 1), acc.item(), t1 - t0
                )
            )


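# Entry point for each worker process spawned by mp.spawn: `proc_id` is the
# worker's rank and also indexes into `devices` to pick its GPU.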
def run(proc_id, nprocs, devices, g, data, mode, num_epochs):
    # find corresponding device for my rank
    device = devices[proc_id]
    torch.cuda.set_device(device)
    # initialize process group and unpack data for sub-processes
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:12345",
        world_size=nprocs,
        rank=proc_id,
    )
    num_classes, train_idx, val_idx, test_idx = data
    train_idx = train_idx.to(device)
    val_idx = val_idx.to(device)
    g = g.to(device if mode == "puregpu" else "cpu")
    # create GraphSAGE model (distributed)
    in_size = g.ndata["feat"].shape[1]
    model = SAGE(in_size, 256, num_classes).to(device)
    model = DistributedDataParallel(
        model, device_ids=[device], output_device=device
    )
    # training + testing
    use_uva = mode == "mixed"
    train(
        proc_id,
        nprocs,
        device,
        g,
        num_classes,
        train_idx,
        val_idx,
        model,
        use_uva,
        num_epochs,
    )
    layerwise_infer(proc_id, device, g, num_classes, test_idx, model, use_uva)
    # cleanup process group
    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode",
        default="mixed",
        choices=["mixed", "puregpu"],
        help="Training mode. 'mixed' for CPU-GPU mixed training, "
        "'puregpu' for pure-GPU training.",
    )
    parser.add_argument(
        "--gpu",
        type=str,
        default="0",
        help="GPU(s) in use. Can be a list of gpu ids for multi-gpu training,"
        " e.g., 0,1,2,3.",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=20,
        help="Number of training epochs.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default="ogbn-products",
        help="Dataset name.",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="dataset",
        help="Root directory of the dataset.",
    )
    args = parser.parse_args()
    devices = list(map(int, args.gpu.split(",")))
    nprocs = len(devices)
    assert (
        torch.cuda.is_available()
    ), "Must have GPUs to enable multi-gpu training."
    print(f"Training in {args.mode} mode using {nprocs} GPU(s)")

    # load and preprocess dataset
    print("Loading data")
    dataset = AsNodePredDataset(
        DglNodePropPredDataset(args.dataset_name, root=args.dataset_dir)
    )
    g = dataset[0]
    # avoid creating certain graph formats in each sub-process to save memory
    g.create_formats_()
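    # ogbn-arxiv is a directed citation graph: make it bidirected and add
    # self-loops so every node aggregates messages from both edge directions
    # and from itself.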
    if args.dataset_name == "ogbn-arxiv":
        g = dgl.to_bidirected(g, copy_ndata=True)
        g = dgl.add_self_loop(g)
    # thread limiting to avoid resource competition
    os.environ["OMP_NUM_THREADS"] = str(mp.cpu_count() // 2 // nprocs)
    data = (
        dataset.num_classes,
        dataset.train_idx,
        dataset.val_idx,
        dataset.test_idx,
    )

    mp.spawn(
        run,
        args=(nprocs, devices, g, data, args.mode, args.num_epochs),
        nprocs=nprocs,
    )