Skip to content

Commit

Permalink
[Feature] Add Device Flag in Data Loaders (dmlc#2450)
Browse files Browse the repository at this point in the history
* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

* Update

Co-authored-by: Minjie Wang <[email protected]>
  • Loading branch information
mufeili and jermainewang authored Jan 14, 2021
1 parent b1fb3c1 commit 5da3439
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 18 deletions.
57 changes: 41 additions & 16 deletions python/dgl/dataloading/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,36 +139,52 @@ def collate(self, items):
_pop_blocks_storage(result[-1], self.g_sampling)
return result

def _to_device(data, device):
if isinstance(data, dict):
for k, v in data.items():
data[k] = v.to(device)
elif isinstance(data, list):
data = [item.to(device) for item in data]
else:
data = data.to(device)
return data

class _NodeDataLoaderIter:
    """Iterator over a :class:`NodeDataLoader`.

    Restores the feature storages on the sampled blocks (detached for
    worker-to-main-process transfer) and copies every minibatch component to
    the dataloader's target device before yielding it.
    """
    def __init__(self, node_dataloader):
        self.device = node_dataloader.device
        self.node_dataloader = node_dataloader
        self.iter_ = iter(node_dataloader.dataloader)

    def __next__(self):
        # result_ is (input_nodes, output_nodes, [items], blocks).
        # NOTE(review): the diff residue duplicated the next() call
        # (fetching and discarding a batch); fixed to fetch exactly once.
        result_ = next(self.iter_)
        # Re-attach the storages of the parent graph to the sampled blocks.
        _restore_blocks_storage(result_[-1], self.node_dataloader.collator.g)

        # Copy every component (tensors, dicts of tensors, blocks) to the
        # requested device.
        result = [_to_device(data, self.device) for data in result_]
        return result

class _EdgeDataLoaderIter:
    """Iterator over an :class:`EdgeDataLoader`.

    Restores the storages on the sampled pair graph(s) and blocks, then
    copies every minibatch component to the dataloader's target device.
    """
    def __init__(self, edge_dataloader):
        self.device = edge_dataloader.device
        self.edge_dataloader = edge_dataloader
        self.iter_ = iter(edge_dataloader.dataloader)

    def __next__(self):
        # NOTE(review): the diff residue interleaved the pre- and post-commit
        # branches, calling next() twice; fixed to fetch exactly once and
        # branch only on the presence of a negative sampler.
        result_ = next(self.iter_)

        if self.edge_dataloader.collator.negative_sampler is not None:
            # input_nodes, pair_graph, neg_pair_graph, [items], blocks
            _restore_subgraph_storage(result_[2], self.edge_dataloader.collator.g)
        # Otherwise, input_nodes, pair_graph, [items], blocks
        _restore_subgraph_storage(result_[1], self.edge_dataloader.collator.g)
        _restore_blocks_storage(result_[-1], self.edge_dataloader.collator.g_sampling)

        # Copy every component to the requested device.
        result = [_to_device(data, self.device) for data in result_]
        return result

class NodeDataLoader:
"""PyTorch dataloader for batch-iterating over a set of nodes, generating the list
Expand All @@ -182,6 +198,9 @@ class NodeDataLoader:
The node set to compute outputs.
block_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler.
device : device context, optional
The device of the generated blocks in each iteration, which should be a
PyTorch device object (e.g., ``torch.device``).
kwargs : dict
Arguments being passed to :py:class:`torch.utils.data.DataLoader`.
Expand All @@ -200,7 +219,7 @@ class NodeDataLoader:
"""
collator_arglist = inspect.getfullargspec(NodeCollator).args

def __init__(self, g, nids, block_sampler, **kwargs):
def __init__(self, g, nids, block_sampler, device='cpu', **kwargs):
collator_kwargs = {}
dataloader_kwargs = {}
for k, v in kwargs.items():
Expand All @@ -210,6 +229,7 @@ def __init__(self, g, nids, block_sampler, **kwargs):
dataloader_kwargs[k] = v

if isinstance(g, DistGraph):
assert device == 'cpu', 'Only cpu is supported in the case of a DistGraph.'
# Distributed DataLoader currently does not support heterogeneous graphs
# and does not copy features. Fallback to normal solution
self.collator = NodeCollator(g, nids, block_sampler, **collator_kwargs)
Expand All @@ -224,6 +244,7 @@ def __init__(self, g, nids, block_sampler, **kwargs):
collate_fn=self.collator.collate,
**dataloader_kwargs)
self.is_distributed = False
self.device = device

def __iter__(self):
"""Return the iterator of the data loader."""
Expand Down Expand Up @@ -267,6 +288,9 @@ class EdgeDataLoader:
The edge set in graph :attr:`g` to compute outputs.
block_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler.
device : device context, optional
The device of the generated blocks and graphs in each iteration, which should be a
PyTorch device object (e.g., ``torch.device``).
g_sampling : DGLGraph, optional
The graph where neighborhood sampling is performed.
Expand Down Expand Up @@ -397,7 +421,7 @@ class EdgeDataLoader:
"""
collator_arglist = inspect.getfullargspec(EdgeCollator).args

def __init__(self, g, eids, block_sampler, **kwargs):
def __init__(self, g, eids, block_sampler, device='cpu', **kwargs):
collator_kwargs = {}
dataloader_kwargs = {}
for k, v in kwargs.items():
Expand All @@ -412,6 +436,7 @@ def __init__(self, g, eids, block_sampler, **kwargs):
+ 'Please use DistDataLoader directly.'
self.dataloader = DataLoader(
self.collator.dataset, collate_fn=self.collator.collate, **dataloader_kwargs)
self.device = device

def __iter__(self):
"""Return the iterator of the data loader."""
Expand Down
122 changes: 121 additions & 1 deletion tests/pytorch/test_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import dgl
import backend as F
import numpy as np
import unittest
from torch.utils.data import DataLoader
from collections import defaultdict
Expand Down Expand Up @@ -217,6 +216,127 @@ def test_graph_dataloader():
assert isinstance(graph, dgl.DGLGraph)
assert F.asnumpy(label).shape[0] == batch_size

def _check_device(data):
    """Assert that every tensor/graph held by ``data`` lives on the backend
    test context device (``F.ctx()``).

    ``data`` may be a dict, a list, or a single object with a ``device``
    attribute.
    """
    if isinstance(data, dict):
        items = list(data.values())
    elif isinstance(data, list):
        items = data
    else:
        items = [data]
    for item in items:
        assert item.device == F.ctx()

def test_node_dataloader():
    """NodeDataLoader must place all minibatch components on the requested
    device, for homogeneous and heterogeneous graphs, with and without
    item indices."""
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)

    homo = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
    homo.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu())

    # Without item indices (return_indices = False).
    loader = dgl.dataloading.NodeDataLoader(
        homo, homo.nodes(), sampler, device=F.ctx(), batch_size=homo.num_nodes())
    for input_nodes, output_nodes, blocks in loader:
        _check_device(input_nodes)
        _check_device(output_nodes)
        _check_device(blocks)

    # With item indices (return_indices = True).
    loader = dgl.dataloading.NodeDataLoader(
        homo, homo.nodes(), sampler, device=F.ctx(),
        batch_size=homo.num_nodes(), return_indices=True)
    for input_nodes, output_nodes, items, blocks in loader:
        _check_device(input_nodes)
        _check_device(output_nodes)
        _check_device(items)
        _check_device(blocks)

    hetero = dgl.heterograph({
        ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
        ('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
        ('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
        ('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
    })
    for ntype in hetero.ntypes:
        hetero.nodes[ntype].data['feat'] = F.copy_to(
            F.randn((hetero.num_nodes(ntype), 8)), F.cpu())
    batch_size = max(hetero.num_nodes(nty) for nty in hetero.ntypes)
    seeds = {nty: hetero.nodes(nty) for nty in hetero.ntypes}

    # Without item indices (return_indices = False).
    loader = dgl.dataloading.NodeDataLoader(
        hetero, seeds, sampler, device=F.ctx(), batch_size=batch_size)
    for input_nodes, output_nodes, blocks in loader:
        _check_device(input_nodes)
        _check_device(output_nodes)
        _check_device(blocks)

    # With item indices (return_indices = True).
    loader = dgl.dataloading.NodeDataLoader(
        hetero, seeds, sampler, device=F.ctx(),
        batch_size=batch_size, return_indices=True)
    for input_nodes, output_nodes, items, blocks in loader:
        _check_device(input_nodes)
        _check_device(output_nodes)
        _check_device(items)
        _check_device(blocks)

def test_edge_dataloader():
    """EdgeDataLoader must place all minibatch components on the requested
    device, with and without a negative sampler and item indices."""
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    neg_sampler = dgl.dataloading.negative_sampler.Uniform(2)

    homo = dgl.graph(([0, 0, 0, 1, 1], [1, 2, 3, 3, 4]))
    homo.ndata['feat'] = F.copy_to(F.randn((5, 8)), F.cpu())

    # No item indices, no negative sampler.
    loader = dgl.dataloading.EdgeDataLoader(
        homo, homo.edges(form='eid'), sampler, device=F.ctx(),
        batch_size=homo.num_edges())
    for input_nodes, pos_pair_graph, blocks in loader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(blocks)

    # No item indices, with a negative sampler.
    loader = dgl.dataloading.EdgeDataLoader(
        homo, homo.edges(form='eid'), sampler, device=F.ctx(),
        negative_sampler=neg_sampler, batch_size=homo.num_edges())
    for input_nodes, pos_pair_graph, neg_pair_graph, blocks in loader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(neg_pair_graph)
        _check_device(blocks)

    hetero = dgl.heterograph({
        ('user', 'follow', 'user'): ([0, 0, 0, 1, 1, 1, 2], [1, 2, 3, 0, 2, 3, 0]),
        ('user', 'followed-by', 'user'): ([1, 2, 3, 0, 2, 3, 0], [0, 0, 0, 1, 1, 1, 2]),
        ('user', 'play', 'game'): ([0, 1, 1, 3, 5], [0, 1, 2, 0, 2]),
        ('game', 'played-by', 'user'): ([0, 1, 2, 0, 2], [0, 1, 1, 3, 5])
    })
    for ntype in hetero.ntypes:
        hetero.nodes[ntype].data['feat'] = F.copy_to(
            F.randn((hetero.num_nodes(ntype), 8)), F.cpu())
    batch_size = max(hetero.num_edges(ety) for ety in hetero.canonical_etypes)
    seed_edges = {ety: hetero.edges(form='eid', etype=ety)
                  for ety in hetero.canonical_etypes}

    # Item indices, no negative sampler (return_indices = True).
    loader = dgl.dataloading.EdgeDataLoader(
        hetero, seed_edges, sampler, device=F.ctx(),
        batch_size=batch_size, return_indices=True)
    for input_nodes, pos_pair_graph, items, blocks in loader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(items)
        _check_device(blocks)

    # Item indices, with a negative sampler (return_indices = True).
    loader = dgl.dataloading.EdgeDataLoader(
        hetero, seed_edges, sampler, device=F.ctx(),
        negative_sampler=neg_sampler, batch_size=batch_size,
        return_indices=True)
    for input_nodes, pos_pair_graph, neg_pair_graph, items, blocks in loader:
        _check_device(input_nodes)
        _check_device(pos_pair_graph)
        _check_device(neg_pair_graph)
        _check_device(items)
        _check_device(blocks)

if __name__ == '__main__':
    # Allow running this test module as a plain script (without pytest).
    test_neighbor_sampler_dataloader()
    test_graph_dataloader()
    test_node_dataloader()
    test_edge_dataloader()
1 change: 0 additions & 1 deletion tests/pytorch/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from test_utils import parametrize_dtype
from copy import deepcopy

import numpy as np
import scipy as sp

def _AXWb(A, X, W, b):
Expand Down

0 comments on commit 5da3439

Please sign in to comment.