Skip to content

Commit

Permalink
[Enhancement] Add DGLGraph.to for PyTorch and MXNet backend (dmlc#600)
Browse files Browse the repository at this point in the history
* add graph_to

* use backend copy_to

* add test

* fix test

* framework agnostic to() test

* disable pylint complaint

* add examples

* fix docstring

* formatting

* Format

* Update test_to_device.py
  • Loading branch information
HQ01 authored and VoVAllen committed Jun 8, 2019
1 parent baa1623 commit 993fd3f
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source/api/python/graph.rst
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ Converting from/to other format
DGLGraph.adjacency_matrix
DGLGraph.adjacency_matrix_scipy
DGLGraph.incidence_matrix
DGLGraph.to

Using Node/edge features
------------------------
Expand Down
3 changes: 2 additions & 1 deletion python/dgl/backend/pytorch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ def copy_to(input, ctx):
if ctx.type == 'cpu':
return input.cpu()
elif ctx.type == 'cuda':
th.cuda.set_device(ctx.index)
if ctx.index is not None:
th.cuda.set_device(ctx.index)
return input.cuda()
else:
raise RuntimeError('Invalid context', ctx)
Expand Down
25 changes: 25 additions & 0 deletions python/dgl/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3229,3 +3229,28 @@ def __repr__(self):
return ret.format(node=self.number_of_nodes(), edge=self.number_of_edges(),
ndata=str(self.node_attr_schemes()),
edata=str(self.edge_attr_schemes()))

# pylint: disable=invalid-name
def to(self, ctx):
    """Move both ndata and edata to the targeted device (cpu/gpu).

    Framework agnostic: ``ctx`` is whatever context/device object the
    active backend uses, and every node/edge feature tensor is moved
    in place via the backend's ``copy_to``.

    Parameters
    ----------
    ctx : framework-specific context object
        The target device, e.g. ``torch.device('cuda:0')`` for PyTorch
        or ``mx.gpu(0)`` for MXNet.

    Examples
    --------
    The following example uses the PyTorch backend.

    >>> import torch
    >>> G = dgl.DGLGraph()
    >>> G.add_nodes(5, {'h': torch.ones((5, 2))})
    >>> G.add_edges([0, 1], [1, 2], {'m': torch.ones((2, 2))})
    >>> G.to(torch.device('cuda:0'))
    """
    # Move each feature tensor individually; the graph structure itself
    # stays on the host, only ndata/edata storage changes device.
    for key in self.ndata.keys():
        self.ndata[key] = F.copy_to(self.ndata[key], ctx)
    for key in self.edata.keys():
        self.edata[key] = F.copy_to(self.edata[key], ctx)
4 changes: 4 additions & 0 deletions tests/backend/backend_unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ def cuda():
"""Context object for CUDA."""
pass

def is_cuda_available():
    """Check whether CUDA is available.

    Stub in the abstract backend interface: each concrete backend
    package (pytorch/, mxnet/) overrides this with a real check.
    """
    pass

###############################################################################
# Tensor functions on feature data
# --------------------------------
Expand Down
8 changes: 8 additions & 0 deletions tests/backend/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@
def cuda():
    """Return an MXNet GPU context (device 0) for device-placement tests."""
    ctx = mx.gpu()
    return ctx

def is_cuda_available():
    """Return True if MXNet can actually use a CUDA device.

    Older MXNet versions lack a direct availability query, so we probe by
    allocating a tiny array on the GPU.  MXNet executes operations
    asynchronously, so the allocation is forced to complete with
    ``asnumpy()`` — without that synchronization the failure could be
    raised later, outside the ``try`` block, and this check would wrongly
    report True.
    """
    try:
        # .asnumpy() blocks until the GPU op finishes (or fails).
        nd.array([1, 2, 3], ctx=mx.gpu()).asnumpy()
        return True
    except mx.MXNetError:
        return False

def array_equal(a, b):
    """Return True iff NDArrays *a* and *b* match elementwise."""
    comparison = nd.equal(a, b)
    return comparison.asnumpy().all()

Expand Down
3 changes: 3 additions & 0 deletions tests/backend/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
def cuda():
    """Return the PyTorch device object for the first CUDA GPU."""
    device = th.device('cuda:0')
    return device

def is_cuda_available():
    """Report whether PyTorch can see a usable CUDA device."""
    available = th.cuda.is_available()
    return available

def array_equal(a, b):
    """Return True iff tensors *a* and *b* are identical in shape and
    values; both are compared on the CPU so mixed-device inputs work."""
    lhs = a.cpu()
    rhs = b.cpu()
    return th.equal(lhs, rhs)

Expand Down
13 changes: 13 additions & 0 deletions tests/compute/test_to_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import dgl
import backend as F

def test_to_device():
    """Smoke test: DGLGraph.to() moves node/edge features to the GPU.

    Skips the actual transfer when no CUDA device is present, so the
    test still passes on CPU-only CI machines.
    """
    graph = dgl.DGLGraph()
    graph.add_nodes(5, {'h': F.ones((5, 2))})
    graph.add_edges([0, 1], [1, 2], {'m': F.ones((2, 2))})
    if F.is_cuda_available():
        graph.to(F.cuda())


# Allow running this test file directly (outside pytest).
if __name__ == '__main__':
    test_to_device()

0 comments on commit 993fd3f

Please sign in to comment.