[GraphIndex] refactor graph caching (dmlc#150)
* refactor graph caching

* fix mx test

* fix typo
jermainewang authored Nov 14, 2018
1 parent a9ffb59 commit 048f6d7
Showing 14 changed files with 171 additions and 113 deletions.
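The thread running through these diffs: instead of building adjacency and incidence tensors once on CPU and caching a lazily copied version per device, the backend creation APIs (zeros, ones, sparse_matrix) now take an explicit device context ctx and allocate on the right device directly, so the utils.CtxCachedObject indirection disappears. As a reference point, here is a minimal sketch of that removed cache, reconstructed from its call sites in the diffs below (not the actual dgl.utils source):

    class CtxCachedObject:
        """Build a tensor once per device context and memoize it (sketch)."""
        def __init__(self, generator):
            self._generator = generator  # callable: ctx -> tensor on that device
            self._ctx_dict = {}          # ctx -> cached tensor

        def get(self, ctx):
            if ctx not in self._ctx_dict:
                self._ctx_dict[ctx] = self._generator(ctx)
            return self._ctx_dict[ctx]

After this change only the cheap (src, dst, eid) index tuple is cached; see the edges() hunk in graph_index.py below.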
11 changes: 9 additions & 2 deletions python/dgl/backend/backend.py
@@ -77,6 +77,9 @@ def tensor(data, dtype=None):
def sparse_matrix(data, index, shape, force_format=False):
"""Create a sparse matrix.
NOTE: Please make sure that the data and index tensors are not
copied. This is critical to performance.
Parameters
----------
data : Tensor
@@ -482,7 +485,7 @@ def reshape(input, shape):
"""
pass

def zeros(shape, dtype):
def zeros(shape, dtype, ctx):
"""Create a zero tensor.
Parameters
@@ -491,6 +494,8 @@ def zeros(shape, dtype):
The tensor shape.
dtype : data type
It should be one of the values in the data type dict.
ctx : context
The device of the result tensor.
Returns
-------
@@ -499,7 +504,7 @@ def zeros(shape, dtype):
"""
pass

def ones(shape, dtype):
def ones(shape, dtype, ctx):
"""Create a one tensor.
Parameters
@@ -508,6 +513,8 @@ def ones(shape, dtype):
The tensor shape.
dtype : data type
It should be one of the values in the data type dict.
ctx : context
The device of the result tensor.
Returns
-------
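For context, a minimal usage sketch of the extended backend interface (F is the conventional alias for dgl.backend; the shape and dtype here are arbitrary examples):

    import dgl.backend as F

    # Allocate directly on the target device rather than creating on CPU
    # and copying afterwards.
    zeros = F.zeros((4, 16), F.float32, F.cpu())
    ones  = F.ones((4, 16), F.float32, F.cpu())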
6 changes: 4 additions & 2 deletions python/dgl/backend/mxnet/immutable_graph_index.py
@@ -268,7 +268,7 @@ def node_subgraphs(self, vs_arr):
induced_es.append(induced_e)
return gis, induced_ns, induced_es

def adjacency_matrix(self, transpose=False):
def adjacency_matrix(self, transpose, ctx):
"""Return the adjacency matrix representation of this graph.
By default, a row of returned adjacency matrix represents the destination
@@ -281,6 +281,8 @@ def adjacency_matrix(self, transpose=False):
----------
transpose : bool
A flag to transpose the returned adjacency matrix.
ctx : context
The device context of the returned matrix.
Returns
-------
@@ -294,7 +296,7 @@ def adjacency_matrix(self, transpose=False):

indices = mat.indices
indptr = mat.indptr
data = mx.nd.ones(indices.shape, dtype=np.float32)
data = mx.nd.ones(indices.shape, dtype=np.float32, ctx=ctx)
return mx.nd.sparse.csr_matrix((data, indices, indptr), shape=mat.shape)

def from_coo_matrix(self, out_coo):
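The same construction in isolation: a tiny destination-major CSR adjacency whose data array is allocated on the caller's device. Standalone illustrative code, not DGL source; the graph is 2 nodes with a single edge 0 -> 1:

    import mxnet as mx
    import numpy as np

    ctx = mx.cpu()
    # Edge 0 -> 1 lands in row 1 (destination), column 0 (source).
    indptr  = mx.nd.array([0, 0, 1], dtype=np.int64, ctx=ctx)  # row offsets
    indices = mx.nd.array([0], dtype=np.int64, ctx=ctx)        # column ids
    data    = mx.nd.ones(indices.shape, dtype=np.float32, ctx=ctx)
    adj     = mx.nd.sparse.csr_matrix((data, indices, indptr), shape=(2, 2))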
8 changes: 4 additions & 4 deletions python/dgl/backend/mxnet/tensor.py
@@ -114,11 +114,11 @@ def reshape(input, shape):
# NOTE: the input cannot be a symbol
return nd.reshape(input, shape)

def zeros(shape, dtype):
return nd.zeros(shape, dtype=dtype)
def zeros(shape, dtype, ctx):
return nd.zeros(shape, dtype=dtype, ctx=ctx)

def ones(shape, dtype):
return nd.ones(shape, dtype=dtype)
def ones(shape, dtype, ctx):
return nd.ones(shape, dtype=dtype, ctx=ctx)

def spmm(x, y):
return nd.dot(x, y)
11 changes: 6 additions & 5 deletions python/dgl/backend/pytorch/tensor.py
@@ -23,7 +23,8 @@ def sparse_matrix(data, index, shape, force_format=False):
fmt = index[0]
if fmt != 'coo':
raise TypeError('Pytorch backend only supports COO format. But got %s.' % fmt)
return th.sparse.FloatTensor(index[1], data, shape)
# NOTE: use _sparse_coo_tensor_unsafe to avoid unnecessary boundary check
return th._sparse_coo_tensor_unsafe(index[1], data, shape)

def sparse_matrix_indices(spmat):
return ('coo', spmat._indices())
@@ -98,11 +99,11 @@ def unsqueeze(input, dim):
def reshape(input, shape):
return th.reshape(input, shape)

def zeros(shape, dtype):
return th.zeros(shape, dtype=dtype)
def zeros(shape, dtype, ctx):
return th.zeros(shape, dtype=dtype, device=ctx)

def ones(shape, dtype):
return th.ones(shape, dtype=dtype)
def ones(shape, dtype, ctx):
return th.ones(shape, dtype=dtype, device=ctx)

def spmm(x, y):
return th.spmm(x, y)
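A note on the constructor swap: _sparse_coo_tensor_unsafe is a private PyTorch entry point that skips the index-bounds validation performed by the public constructor, which is safe here because the indices come straight from the graph index. The checked public equivalent looks like this (standalone sketch; the private name may differ across PyTorch versions):

    import torch as th

    i = th.tensor([[0, 1],
                   [1, 0]])   # 2 x nnz COO indices
    v = th.ones(2)            # one value per non-zero entry
    adj = th.sparse_coo_tensor(i, v, (2, 2))  # validates indices against shape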
10 changes: 4 additions & 6 deletions python/dgl/frame.py
@@ -183,7 +183,7 @@ def _warn_and_set_initializer(self):
dgl_warning('Initializer is not set. Use zero initializer instead.'
' To suppress this warning, use `set_initializer` to'
' explicitly specify which initializer to use.')
self._initializer = lambda shape, dtype: F.zeros(shape, dtype)
self._initializer = lambda shape, dtype, ctx: F.zeros(shape, dtype, ctx)

def set_initializer(self, initializer):
"""Set the initializer for empty values.
@@ -283,9 +283,7 @@ def add_column(self, name, scheme, ctx):
' one column in the frame so number of rows can be inferred.' % name)
if self.initializer is None:
self._warn_and_set_initializer()
# TODO(minjie): directly init data on the target device.
init_data = self.initializer((self.num_rows,) + scheme.shape, scheme.dtype)
init_data = F.copy_to(init_data, ctx)
init_data = self.initializer((self.num_rows,) + scheme.shape, scheme.dtype, ctx)
self._columns[name] = Column(init_data, scheme)

def update_column(self, name, data):
@@ -601,10 +599,10 @@ def add_rows(self, num_rows):

for key in self._frame:
scheme = self._frame[key].scheme

ctx = F.context(self._frame[key].data)
if self._frame.initializer is None:
self._frame._warn_and_set_initializer()
new_data = self._frame.initializer((num_rows,) + scheme.shape, scheme.dtype)
new_data = self._frame.initializer((num_rows,) + scheme.shape, scheme.dtype, ctx)
feat_placeholders[key] = new_data

self.append(feat_placeholders)
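With the extra ctx argument, an initializer can materialize new rows directly on the device that already holds the column, which is exactly what add_rows does by passing F.context(...) of the existing data. A sketch of a user-supplied initializer under the new three-argument convention (frame here stands for any Frame/FrameRef instance):

    import dgl.backend as F

    def zero_init(shape, dtype, ctx):
        # shape is (num_rows,) + scheme.shape, e.g. (num_new_rows, feat_dim)
        return F.zeros(shape, dtype, ctx)

    frame.set_initializer(zero_init)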
28 changes: 22 additions & 6 deletions python/dgl/graph.py
@@ -733,7 +733,8 @@ def edge_attr_schemes(self):
def set_n_initializer(self, initializer):
"""Set the initializer for empty node features.
Initializer is a callable that returns a tensor given the shape and data type.
Initializer is a callable that returns a tensor given the shape, data type
and device context.
Parameters
----------
@@ -745,7 +746,8 @@ def set_n_initializer(self, initializer):
def set_e_initializer(self, initializer):
"""Set the initializer for empty edge features.
Initializer is a callable that returns a tensor given the shape and data type.
Initializer is a callable that returns a tensor given the shape, data type
and device context.
Parameters
----------
@@ -1509,20 +1511,31 @@ def merge(self, subgraphs, reduce_func='sum'):
self._edge_frame.num_rows,
reduce_func)

def adjacency_matrix(self, ctx=F.cpu()):
def adjacency_matrix(self, transpose=False, ctx=F.cpu()):
"""Return the adjacency matrix representation of this graph.
By default, a row of returned adjacency matrix represents the destination
of an edge and the column represents the source.
When transpose is True, a row represents the source and a column represents
a destination.
Parameters
----------
ctx : optional
transpose : bool, optional (default=False)
A flag to transpose the returned adjacency matrix.
ctx : context, optional (default=cpu)
The context of returned adjacency matrix.
Returns
-------
sparse_tensor
The adjacency matrix.
"""
return self._graph.adjacency_matrix().get(ctx)
if not isinstance(transpose, bool):
raise DGLError('Expect bool value for "transpose" arg,'
' but got %s.' % (type(transpose)))
return self._graph.adjacency_matrix(transpose, ctx)

def incidence_matrix(self, oriented=False, ctx=F.cpu()):
"""Return the incidence matrix representation of this graph.
Expand All @@ -1540,7 +1553,10 @@ def incidence_matrix(self, oriented=False, ctx=F.cpu()):
sparse_tensor
The incidence matrix.
"""
return self._graph.incidence_matrix(oriented).get(ctx)
if not isinstance(oriented, bool):
raise DGLError('Expect bool value for "oriented" arg,'
' but got %s.' % (type(oriented)))
return self._graph.incidence_matrix(oriented, ctx)

def line_graph(self, backtracking=True, shared=False):
"""Return the line graph of this graph.
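The caller-facing effect, sketched with the DGLGraph construction API of this release (a three-node path graph; assumes dgl and its backend are importable):

    import dgl
    import dgl.backend as F

    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1], [1, 2])

    adj   = g.adjacency_matrix()                    # rows are destinations
    adj_t = g.adjacency_matrix(transpose=True)      # rows are sources
    inc   = g.incidence_matrix(oriented=True, ctx=F.cpu())

Both methods now return the sparse tensor directly; the previous .get(ctx) call on a cached object is gone.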
118 changes: 59 additions & 59 deletions python/dgl/graph_index.py
@@ -7,6 +7,7 @@

from ._ffi.base import c_array
from ._ffi.function import _init_api
from .base import DGLError
from . import backend as F
from . import utils
from .immutable_graph_index import create_immutable_graph_index
@@ -347,11 +348,14 @@ def edges(self, sorted=False):
utils.Index
The edge ids.
"""
edge_array = _CAPI_DGLGraphEdges(self._handle, sorted)
src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
return src, dst, eid
key = 'edges_s%d' % sorted
if key not in self._cache:
edge_array = _CAPI_DGLGraphEdges(self._handle, sorted)
src = utils.toindex(edge_array(0))
dst = utils.toindex(edge_array(1))
eid = utils.toindex(edge_array(2))
self._cache[key] = (src, dst, eid)
return self._cache[key]

def in_degree(self, v):
"""Return the in degree of the node.
@@ -470,7 +474,7 @@ def edge_subgraph(self, e):
induced_nodes = utils.toindex(rst(1))
return SubgraphIndex(rst(0), self, induced_nodes, e)

def adjacency_matrix(self, transpose=False):
def adjacency_matrix(self, transpose=False, ctx=F.cpu()):
"""Return the adjacency matrix representation of this graph.
By default, a row of returned adjacency matrix represents the destination
@@ -481,31 +485,30 @@ def adjacency_matrix(self, transpose=False):
Parameters
----------
transpose : bool
transpose : bool, optional (default=False)
A flag to transpose the returned adjacency matrix.
Returns
-------
utils.CtxCachedObject
An object that returns tensor given context.
"""
key = 'transposed adj' if transpose else 'adj'
if not key in self._cache:
src, dst, _ = self.edges(sorted=False)
src = F.unsqueeze(src.tousertensor(), 0)
dst = F.unsqueeze(dst.tousertensor(), 0)
if transpose:
idx = F.cat([src, dst], dim=0)
else:
idx = F.cat([dst, src], dim=0)
n = self.number_of_nodes()
# FIXME(minjie): data type
dat = F.ones((self.number_of_edges(),), dtype=F.float32)
mat = F.sparse_matrix(dat, ('coo', idx), (n, n))
self._cache[key] = utils.CtxCachedObject(lambda ctx: F.copy_to(mat, ctx))
return self._cache[key]

def incidence_matrix(self, oriented=False):
SparseTensor
The adjacency matrix.
"""
src, dst, _ = self.edges(sorted=False)
src = src.tousertensor(ctx) # the index of the ctx will be cached
dst = dst.tousertensor(ctx) # the index of the ctx will be cached
src = F.unsqueeze(src, dim=0)
dst = F.unsqueeze(dst, dim=0)
if transpose:
idx = F.cat([src, dst], dim=0)
else:
idx = F.cat([dst, src], dim=0)
n = self.number_of_nodes()
# FIXME(minjie): data type
dat = F.ones((self.number_of_edges(),), dtype=F.float32, ctx=ctx)
adj = F.sparse_matrix(dat, ('coo', idx), (n, n))
return adj

def incidence_matrix(self, oriented=False, ctx=F.cpu()):
"""Return the incidence matrix representation of this graph.
Parameters
@@ -515,38 +518,35 @@ def incidence_matrix(self, oriented=False):
Returns
-------
utils.CtxCachedObject
An object that returns tensor given context.
"""
key = ('oriented ' if oriented else '') + 'incidence matrix'
if not key in self._cache:
src, dst, _ = self.edges(sorted=False)
src = src.tousertensor()
dst = dst.tousertensor()
m = self.number_of_edges()
eid = F.arange(0, m)
row = F.unsqueeze(F.cat([src, dst], dim=0), 0)
col = F.unsqueeze(F.cat([eid, eid], dim=0), 0)
idx = F.cat([row, col], dim=0)

diagonal = (src == dst)
if oriented:
# FIXME(minjie): data type
x = -F.ones((m,), dtype=F.float32)
y = F.ones((m,), dtype=F.float32)
x[diagonal] = 0
y[diagonal] = 0
dat = F.cat([x, y], dim=0)
else:
# FIXME(minjie): data type
x = F.ones((m,), dtype=F.float32)
x[diagonal] = 0
dat = F.cat([x, x], dim=0)
n = self.number_of_nodes()
mat = F.sparse_matrix(dat, ('coo', idx), (n, m))
self._cache[key] = utils.CtxCachedObject(lambda ctx: F.copy_to(mat, ctx))

return self._cache[key]
SparseTensor
The incidence matrix.
"""
src, dst, eid = self.edges(sorted=False)
src = src.tousertensor(ctx) # the index of the ctx will be cached
dst = dst.tousertensor(ctx) # the index of the ctx will be cached
eid = eid.tousertensor(ctx) # the index of the ctx will be cached
n = self.number_of_nodes()
m = self.number_of_edges()
# create index
row = F.unsqueeze(F.cat([src, dst], dim=0), 0)
col = F.unsqueeze(F.cat([eid, eid], dim=0), 0)
idx = F.cat([row, col], dim=0)
# create data
diagonal = (src == dst)
if oriented:
# FIXME(minjie): data type
x = -F.ones((m,), dtype=F.float32, ctx=ctx)
y = F.ones((m,), dtype=F.float32, ctx=ctx)
x[diagonal] = 0
y[diagonal] = 0
dat = F.cat([x, y], dim=0)
else:
# FIXME(minjie): data type
x = F.ones((m,), dtype=F.float32, ctx=ctx)
x[diagonal] = 0
dat = F.cat([x, x], dim=0)
inc = F.sparse_matrix(dat, ('coo', idx), (n, m))
return inc

def to_networkx(self):
"""Convert to networkx graph.
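The caching that survives this refactor sits on the edge index rather than on per-device matrices. Condensed to its core (with _compute_edges standing in, hypothetically, for the _CAPI_DGLGraphEdges calls and index wrapping):

    def edges(self, sorted=False):
        key = 'edges_s%d' % sorted   # 'edges_s0' or 'edges_s1'
        if key not in self._cache:
            self._cache[key] = self._compute_edges(sorted)  # (src, dst, eid)
        return self._cache[key]

Since tousertensor(ctx) memoizes its per-context copy (per the inline comments above), repeated adjacency_matrix or incidence_matrix calls on the same device stay cheap without a dedicated matrix cache.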
10 changes: 2 additions & 8 deletions python/dgl/immutable_graph_index.py
@@ -429,7 +429,7 @@ def node_subgraphs(self, vs_arr):
return [ImmutableSubgraphIndex(gi, self, induced_n,
induced_e) for gi, induced_n, induced_e in zip(gis, induced_nodes, induced_edges)]

def adjacency_matrix(self, transpose=False):
def adjacency_matrix(self, transpose=False, ctx=F.cpu()):
"""Return the adjacency matrix representation of this graph.
By default, a row of returned adjacency matrix represents the destination
@@ -451,13 +451,7 @@ def adjacency_matrix(self, transpose=False):
def get_adj(ctx):
new_mat = self._sparse.adjacency_matrix(transpose)
return F.copy_to(new_mat, ctx)

if not transpose and 'in_adj' in self._cache:
return self._cache['in_adj']
elif transpose and 'out_adj' in self._cache:
return self._cache['out_adj']
else:
return utils.CtxCachedObject(lambda ctx: get_adj(ctx))
return self._sparse.adjacency_matrix(transpose, ctx)

def incidence_matrix(self, oriented=False):
"""Return the incidence matrix representation of this graph.