diff --git a/python/dgl/backend/backend.py b/python/dgl/backend/backend.py
index 9c6d0114b293..117eb2b964e8 100644
--- a/python/dgl/backend/backend.py
+++ b/python/dgl/backend/backend.py
@@ -77,6 +77,9 @@ def tensor(data, dtype=None):
 def sparse_matrix(data, index, shape, force_format=False):
     """Create a sparse matrix.
 
+    NOTE: Please make sure that the data and index tensors are not
+    copied; avoiding copies is critical for performance.
+
     Parameters
     ----------
     data : Tensor
@@ -482,7 +485,7 @@ def reshape(input, shape):
     """
     pass
 
-def zeros(shape, dtype):
+def zeros(shape, dtype, ctx):
     """Create a zero tensor.
 
     Parameters
@@ -491,6 +494,8 @@ def zeros(shape, dtype):
         The tensor shape.
     dtype : data type
         It should be one of the values in the data type dict.
+    ctx : context
+        The device of the result tensor.
 
     Returns
     -------
@@ -499,7 +504,7 @@ def zeros(shape, dtype):
     """
     pass
 
-def ones(shape, dtype):
+def ones(shape, dtype, ctx):
     """Create a one tensor.
 
     Parameters
@@ -508,6 +513,8 @@ def ones(shape, dtype):
         The tensor shape.
     dtype : data type
         It should be one of the values in the data type dict.
+    ctx : context
+        The device of the result tensor.
 
     Returns
     -------
diff --git a/python/dgl/backend/mxnet/immutable_graph_index.py b/python/dgl/backend/mxnet/immutable_graph_index.py
index 5fbaa0c819fe..b95986878e84 100644
--- a/python/dgl/backend/mxnet/immutable_graph_index.py
+++ b/python/dgl/backend/mxnet/immutable_graph_index.py
@@ -268,7 +268,7 @@ def node_subgraphs(self, vs_arr):
             induced_es.append(induced_e)
         return gis, induced_ns, induced_es
 
-    def adjacency_matrix(self, transpose=False):
+    def adjacency_matrix(self, transpose, ctx):
         """Return the adjacency matrix representation of this graph.
 
         By default, a row of returned adjacency matrix represents the destination
@@ -281,6 +281,8 @@ def adjacency_matrix(self, transpose=False):
         ----------
         transpose : bool
             A flag to tranpose the returned adjacency matrix.
+        ctx : context
+            The device context of the returned matrix.
 
         Returns
         -------
@@ -294,7 +296,7 @@ def adjacency_matrix(self, transpose=False):
         indices = mat.indices
         indptr = mat.indptr
-        data = mx.nd.ones(indices.shape, dtype=np.float32)
+        data = mx.nd.ones(indices.shape, dtype=np.float32, ctx=ctx)
         return mx.nd.sparse.csr_matrix((data, indices, indptr), shape=mat.shape)
 
     def from_coo_matrix(self, out_coo):
diff --git a/python/dgl/backend/mxnet/tensor.py b/python/dgl/backend/mxnet/tensor.py
index 59990dd57301..fa371aa43209 100644
--- a/python/dgl/backend/mxnet/tensor.py
+++ b/python/dgl/backend/mxnet/tensor.py
@@ -114,11 +114,11 @@ def reshape(input, shape):
     # NOTE: the input cannot be a symbol
     return nd.reshape(input ,shape)
 
-def zeros(shape, dtype):
-    return nd.zeros(shape, dtype=dtype)
+def zeros(shape, dtype, ctx):
+    return nd.zeros(shape, dtype=dtype, ctx=ctx)
 
-def ones(shape, dtype):
-    return nd.ones(shape, dtype=dtype)
+def ones(shape, dtype, ctx):
+    return nd.ones(shape, dtype=dtype, ctx=ctx)
 
 def spmm(x, y):
     return nd.dot(x, y)
diff --git a/python/dgl/backend/pytorch/tensor.py b/python/dgl/backend/pytorch/tensor.py
index fb3045a8a5f4..bb540ed33536 100644
--- a/python/dgl/backend/pytorch/tensor.py
+++ b/python/dgl/backend/pytorch/tensor.py
@@ -23,7 +23,8 @@ def sparse_matrix(data, index, shape, force_format=False):
     fmt = index[0]
     if fmt != 'coo':
         raise TypeError('Pytorch backend only supports COO format. But got %s.'
                % fmt)
-    return th.sparse.FloatTensor(index[1], data, shape)
+    # NOTE: use _sparse_coo_tensor_unsafe to avoid an unnecessary bounds check
+    return th._sparse_coo_tensor_unsafe(index[1], data, shape)
 
 def sparse_matrix_indices(spmat):
     return ('coo', spmat._indices())
@@ -98,11 +99,11 @@ def unsqueeze(input, dim):
 def reshape(input, shape):
     return th.reshape(input ,shape)
 
-def zeros(shape, dtype):
-    return th.zeros(shape, dtype=dtype)
+def zeros(shape, dtype, ctx):
+    return th.zeros(shape, dtype=dtype, device=ctx)
 
-def ones(shape, dtype):
-    return th.ones(shape, dtype=dtype)
+def ones(shape, dtype, ctx):
+    return th.ones(shape, dtype=dtype, device=ctx)
 
 def spmm(x, y):
     return th.spmm(x, y)
diff --git a/python/dgl/frame.py b/python/dgl/frame.py
index 244d85db0572..abd8321b10ab 100644
--- a/python/dgl/frame.py
+++ b/python/dgl/frame.py
@@ -183,7 +183,7 @@ def _warn_and_set_initializer(self):
         dgl_warning('Initializer is not set. Use zero initializer instead.'
                     ' To suppress this warning, use `set_initializer` to'
                     ' explicitly specify which initializer to use.')
-        self._initializer = lambda shape, dtype: F.zeros(shape, dtype)
+        self._initializer = lambda shape, dtype, ctx: F.zeros(shape, dtype, ctx)
 
     def set_initializer(self, initializer):
         """Set the initializer for empty values.
@@ -283,9 +283,7 @@ def add_column(self, name, scheme, ctx):
                            ' one column in the frame so number of rows can be inferred.' % name)
         if self.initializer is None:
             self._warn_and_set_initializer()
-        # TODO(minjie): directly init data on the targer device.
-        init_data = self.initializer((self.num_rows,) + scheme.shape, scheme.dtype)
-        init_data = F.copy_to(init_data, ctx)
+        init_data = self.initializer((self.num_rows,) + scheme.shape, scheme.dtype, ctx)
         self._columns[name] = Column(init_data, scheme)
 
     def update_column(self, name, data):
@@ -601,10 +599,10 @@ def add_rows(self, num_rows):
 
         for key in self._frame:
             scheme = self._frame[key].scheme
-
+            ctx = F.context(self._frame[key].data)
             if self._frame.initializer is None:
                 self._frame._warn_and_set_initializer()
-            new_data = self._frame.initializer((num_rows,) + scheme.shape, scheme.dtype)
+            new_data = self._frame.initializer((num_rows,) + scheme.shape, scheme.dtype, ctx)
             feat_placeholders[key] = new_data
         self.append(feat_placeholders)
diff --git a/python/dgl/graph.py b/python/dgl/graph.py
index b3ac107799b5..c5f4ffaa216e 100644
--- a/python/dgl/graph.py
+++ b/python/dgl/graph.py
@@ -733,7 +733,8 @@ def edge_attr_schemes(self):
     def set_n_initializer(self, initializer):
         """Set the initializer for empty node features.
 
-        Initializer is a callable that returns a tensor given the shape and data type.
+        Initializer is a callable that returns a tensor given the shape, data type,
+        and device context.
 
         Parameters
         ----------
@@ -745,7 +746,8 @@ def set_n_initializer(self, initializer):
     def set_e_initializer(self, initializer):
         """Set the initializer for empty edge features.
 
-        Initializer is a callable that returns a tensor given the shape and data type.
+        Initializer is a callable that returns a tensor given the shape, data type,
+        and device context.
 
         Parameters
         ----------
@@ -1509,12 +1511,20 @@ def merge(self, subgraphs, reduce_func='sum'):
                           self._edge_frame.num_rows,
                           reduce_func)
 
-    def adjacency_matrix(self, ctx=F.cpu()):
+    def adjacency_matrix(self, transpose=False, ctx=F.cpu()):
         """Return the adjacency matrix representation of this graph.
 
+        By default, a row of the returned adjacency matrix represents the destination
+        of an edge and a column represents the source.
+
+        When transpose is True, a row represents the source and a column represents
+        the destination.
+
         Parameters
         ----------
-        ctx : optional
+        transpose : bool, optional (default=False)
+            A flag to transpose the returned adjacency matrix.
+        ctx : context, optional (default=cpu)
             The context of returned adjacency matrix.
 
         Returns
@@ -1522,7 +1532,10 @@ def adjacency_matrix(self, ctx=F.cpu()):
         sparse_tensor
             The adjacency matrix.
         """
-        return self._graph.adjacency_matrix().get(ctx)
+        if not isinstance(transpose, bool):
+            raise DGLError('Expected a bool value for the "transpose" arg,'
+                           ' but got %s.' % type(transpose))
+        return self._graph.adjacency_matrix(transpose, ctx)
 
     def incidence_matrix(self, oriented=False, ctx=F.cpu()):
         """Return the incidence matrix representation of this graph.
@@ -1540,7 +1553,10 @@ def incidence_matrix(self, oriented=False, ctx=F.cpu()):
         sparse_tensor
             The incidence matrix.
         """
-        return self._graph.incidence_matrix(oriented).get(ctx)
+        if not isinstance(oriented, bool):
+            raise DGLError('Expected a bool value for the "oriented" arg,'
+                           ' but got %s.' % type(oriented))
+        return self._graph.incidence_matrix(oriented, ctx)
 
     def line_graph(self, backtracking=True, shared=False):
         """Return the line graph of this graph.
diff --git a/python/dgl/graph_index.py b/python/dgl/graph_index.py
index 83b384a5064b..9cf2d387dec6 100644
--- a/python/dgl/graph_index.py
+++ b/python/dgl/graph_index.py
@@ -7,6 +7,7 @@
 from ._ffi.base import c_array
 from ._ffi.function import _init_api
+from .base import DGLError
 from . import backend as F
 from . import utils
 from .immutable_graph_index import create_immutable_graph_index
@@ -347,11 +348,14 @@ def edges(self, sorted=False):
         utils.Index
             The edge ids.
         """
-        edge_array = _CAPI_DGLGraphEdges(self._handle, sorted)
-        src = utils.toindex(edge_array(0))
-        dst = utils.toindex(edge_array(1))
-        eid = utils.toindex(edge_array(2))
-        return src, dst, eid
+        key = 'edges_s%d' % sorted
+        if key not in self._cache:
+            edge_array = _CAPI_DGLGraphEdges(self._handle, sorted)
+            src = utils.toindex(edge_array(0))
+            dst = utils.toindex(edge_array(1))
+            eid = utils.toindex(edge_array(2))
+            self._cache[key] = (src, dst, eid)
+        return self._cache[key]
 
     def in_degree(self, v):
         """Return the in degree of the node.
@@ -470,7 +474,7 @@ def edge_subgraph(self, e):
         induced_nodes = utils.toindex(rst(1))
         return SubgraphIndex(rst(0), self, induced_nodes, e)
 
-    def adjacency_matrix(self, transpose=False):
+    def adjacency_matrix(self, transpose=False, ctx=F.cpu()):
         """Return the adjacency matrix representation of this graph.
 
         By default, a row of returned adjacency matrix represents the destination
@@ -481,31 +485,30 @@ def adjacency_matrix(self, transpose=False):
 
         Parameters
         ----------
-        transpose : bool
+        transpose : bool, optional (default=False)
             A flag to tranpose the returned adjacency matrix.
 
         Returns
         -------
-        utils.CtxCachedObject
-            An object that returns tensor given context.
- """ - key = 'transposed adj' if transpose else 'adj' - if not key in self._cache: - src, dst, _ = self.edges(sorted=False) - src = F.unsqueeze(src.tousertensor(), 0) - dst = F.unsqueeze(dst.tousertensor(), 0) - if transpose: - idx = F.cat([src, dst], dim=0) - else: - idx = F.cat([dst, src], dim=0) - n = self.number_of_nodes() - # FIXME(minjie): data type - dat = F.ones((self.number_of_edges(),), dtype=F.float32) - mat = F.sparse_matrix(dat, ('coo', idx), (n, n)) - self._cache[key] = utils.CtxCachedObject(lambda ctx: F.copy_to(mat, ctx)) - return self._cache[key] - - def incidence_matrix(self, oriented=False): + SparseTensor + The adjacency matrix. + """ + src, dst, _ = self.edges(sorted=False) + src = src.tousertensor(ctx) # the index of the ctx will be cached + dst = dst.tousertensor(ctx) # the index of the ctx will be cached + src = F.unsqueeze(src, dim=0) + dst = F.unsqueeze(dst, dim=0) + if transpose: + idx = F.cat([src, dst], dim=0) + else: + idx = F.cat([dst, src], dim=0) + n = self.number_of_nodes() + # FIXME(minjie): data type + dat = F.ones((self.number_of_edges(),), dtype=F.float32, ctx=ctx) + adj = F.sparse_matrix(dat, ('coo', idx), (n, n)) + return adj + + def incidence_matrix(self, oriented=False, ctx=F.cpu()): """Return the incidence matrix representation of this graph. Parameters @@ -515,38 +518,35 @@ def incidence_matrix(self, oriented=False): Returns ------- - utils.CtxCachedObject - An object that returns tensor given context. - """ - key = ('oriented ' if oriented else '') + 'incidence matrix' - if not key in self._cache: - src, dst, _ = self.edges(sorted=False) - src = src.tousertensor() - dst = dst.tousertensor() - m = self.number_of_edges() - eid = F.arange(0, m) - row = F.unsqueeze(F.cat([src, dst], dim=0), 0) - col = F.unsqueeze(F.cat([eid, eid], dim=0), 0) - idx = F.cat([row, col], dim=0) - - diagonal = (src == dst) - if oriented: - # FIXME(minjie): data type - x = -F.ones((m,), dtype=F.float32) - y = F.ones((m,), dtype=F.float32) - x[diagonal] = 0 - y[diagonal] = 0 - dat = F.cat([x, y], dim=0) - else: - # FIXME(minjie): data type - x = F.ones((m,), dtype=F.float32) - x[diagonal] = 0 - dat = F.cat([x, x], dim=0) - n = self.number_of_nodes() - mat = F.sparse_matrix(dat, ('coo', idx), (n, m)) - self._cache[key] = utils.CtxCachedObject(lambda ctx: F.copy_to(mat, ctx)) - - return self._cache[key] + SparseTensor + The incidence matrix. + """ + src, dst, eid = self.edges(sorted=False) + src = src.tousertensor(ctx) # the index of the ctx will be cached + dst = dst.tousertensor(ctx) # the index of the ctx will be cached + eid = eid.tousertensor(ctx) # the index of the ctx will be cached + n = self.number_of_nodes() + m = self.number_of_edges() + # create index + row = F.unsqueeze(F.cat([src, dst], dim=0), 0) + col = F.unsqueeze(F.cat([eid, eid], dim=0), 0) + idx = F.cat([row, col], dim=0) + # create data + diagonal = (src == dst) + if oriented: + # FIXME(minjie): data type + x = -F.ones((m,), dtype=F.float32, ctx=ctx) + y = F.ones((m,), dtype=F.float32, ctx=ctx) + x[diagonal] = 0 + y[diagonal] = 0 + dat = F.cat([x, y], dim=0) + else: + # FIXME(minjie): data type + x = F.ones((m,), dtype=F.float32, ctx=ctx) + x[diagonal] = 0 + dat = F.cat([x, x], dim=0) + inc = F.sparse_matrix(dat, ('coo', idx), (n, m)) + return inc def to_networkx(self): """Convert to networkx graph. 
diff --git a/python/dgl/immutable_graph_index.py b/python/dgl/immutable_graph_index.py
index b7c209af003f..611341946838 100644
--- a/python/dgl/immutable_graph_index.py
+++ b/python/dgl/immutable_graph_index.py
@@ -429,7 +429,7 @@ def node_subgraphs(self, vs_arr):
         return [ImmutableSubgraphIndex(gi, self, induced_n, induced_e) for gi, induced_n,
                 induced_e in zip(gis, induced_nodes, induced_edges)]
 
-    def adjacency_matrix(self, transpose=False):
+    def adjacency_matrix(self, transpose=False, ctx=F.cpu()):
         """Return the adjacency matrix representation of this graph.
 
         By default, a row of returned adjacency matrix represents the destination
@@ -451,13 +451,7 @@ def adjacency_matrix(self, transpose=False):
         def get_adj(ctx):
             new_mat = self._sparse.adjacency_matrix(transpose)
             return F.copy_to(new_mat, ctx)
-
-        if not transpose and 'in_adj' in self._cache:
-            return self._cache['in_adj']
-        elif transpose and 'out_adj' in self._cache:
-            return self._cache['out_adj']
-        else:
-            return utils.CtxCachedObject(lambda ctx: get_adj(ctx))
+        return self._sparse.adjacency_matrix(transpose, ctx)
 
     def incidence_matrix(self, oriented=False):
         """Return the incidence matrix representation of this graph.
diff --git a/python/dgl/scheduler.py b/python/dgl/scheduler.py
index ebf0e73c1f05..ed8b14582547 100644
--- a/python/dgl/scheduler.py
+++ b/python/dgl/scheduler.py
@@ -276,10 +276,10 @@ def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
             if len(F.shape(dat)) > 1:
                 # The edge feature is of shape (N, 1)
                 dat = F.squeeze(dat, 1)
-            idx = F.sparse_matrix_indices(self.g.adjacency_matrix(ctx))
+            idx = F.sparse_matrix_indices(self.g.adjacency_matrix(ctx=ctx))
             adjmat = F.sparse_matrix(dat, idx, self.graph_shape)
         else:
-            adjmat = self.g.adjacency_matrix(ctx)
+            adjmat = self.g.adjacency_matrix(ctx=ctx)
         return adjmat
 
@@ -347,8 +347,7 @@ def _adj_build_fn(self, edge_field, ctx, use_edge_feat):
                 # edge feature is of shape (N, 1)
                 dat = F.squeeze(dat, dim=1)
         else:
-            # TODO(minjie): data type should be adjusted according t othe usage.
-            dat = F.ones((len(self.u), ), dtype=F.float32)
+            dat = F.ones((len(self.u), ), dtype=F.float32, ctx=ctx)
         adjmat = F.sparse_matrix(dat, ('coo', self.graph_idx), self.graph_shape)
         return F.copy_to(adjmat, ctx)
 
diff --git a/python/dgl/utils.py b/python/dgl/utils.py
index 2d17347add0d..1c0e8a377ed8 100644
--- a/python/dgl/utils.py
+++ b/python/dgl/utils.py
@@ -244,7 +244,7 @@ def build_relabel_map(x):
     x = x.tousertensor()
     unique_x, _ = F.sort_1d(F.unique(x))
     map_len = int(F.max(unique_x, dim=0)) + 1
-    old_to_new = F.zeros(map_len, dtype=F.int64)
+    old_to_new = F.zeros(map_len, dtype=F.int64, ctx=F.cpu())
     F.scatter_row_inplace(old_to_new, unique_x, F.arange(0, len(unique_x)))
     return unique_x, old_to_new
 
diff --git a/tests/mxnet/test_graph_index.py b/tests/mxnet/test_graph_index.py
index 1380a39a19b0..849b1556754d 100644
--- a/tests/mxnet/test_graph_index.py
+++ b/tests/mxnet/test_graph_index.py
@@ -14,8 +14,8 @@ def generate_rand_graph(n):
     return g, ig
 
 def check_graph_equal(g1, g2):
-    adj1 = g1.adjacency_matrix().get(mx.cpu()) != 0
-    adj2 = g2.adjacency_matrix().get(mx.cpu()) != 0
+    adj1 = g1.adjacency_matrix(ctx=mx.cpu()) != 0
+    adj2 = g2.adjacency_matrix(ctx=mx.cpu()) != 0
     assert mx.nd.sum(adj1 - adj2).asnumpy() == 0
 
 def test_graph_gen():
diff --git a/tests/pytorch/test_basics.py b/tests/pytorch/test_basics.py
index 8ae0645ed257..a7eb288c78e1 100644
--- a/tests/pytorch/test_basics.py
+++ b/tests/pytorch/test_basics.py
@@ -40,8 +40,8 @@ def generate_graph(grad=False):
     ecol = Variable(th.randn(17, D), requires_grad=grad)
     g.ndata['h'] = ncol
     g.edata['w'] = ecol
-    g.set_n_initializer(lambda shape, dtype : th.zeros(shape))
-    g.set_e_initializer(lambda shape, dtype : th.zeros(shape))
+    g.set_n_initializer(lambda shape, dtype, ctx : th.zeros(shape, dtype=dtype, device=ctx))
+    g.set_e_initializer(lambda shape, dtype, ctx : th.zeros(shape, dtype=dtype, device=ctx))
     return g
 
 def test_batch_setter_getter():
diff --git a/tests/pytorch/test_graph.py b/tests/pytorch/test_graph.py
new file mode 100644
index 000000000000..ac83de17b9e0
--- /dev/null
+++ b/tests/pytorch/test_graph.py
@@ -0,0 +1,40 @@
+import time
+import math
+import numpy as np
+import scipy.sparse as sp
+import torch as th
+import dgl
+
+def test_adjmat_speed():
+    n = 1000
+    p = 10 * math.log(n) / n
+    a = sp.random(n, n, p, data_rvs=lambda n: np.ones(n))
+    g = dgl.DGLGraph(a)
+    # the first call should construct the adjacency matrix
+    t0 = time.time()
+    g.adjacency_matrix()
+    dur1 = time.time() - t0
+    # the second call should hit the cache and be very fast
+    t0 = time.time()
+    g.adjacency_matrix()
+    dur2 = time.time() - t0
+    assert dur2 < dur1 / 5
+
+def test_incmat_speed():
+    n = 1000
+    p = 10 * math.log(n) / n
+    a = sp.random(n, n, p, data_rvs=lambda n: np.ones(n))
+    g = dgl.DGLGraph(a)
+    # the first call should construct the incidence matrix
+    t0 = time.time()
+    g.incidence_matrix()
+    dur1 = time.time() - t0
+    # the second call should hit the cache and be faster
+    t0 = time.time()
+    g.incidence_matrix()
+    dur2 = time.time() - t0
+    assert dur2 < dur1
+
+if __name__ == '__main__':
+    test_adjmat_speed()
+    test_incmat_speed()
diff --git a/tests/pytorch/test_readout.py b/tests/pytorch/test_readout.py
index 25e763ffbbe5..c4efc167f486 100644
--- a/tests/pytorch/test_readout.py
+++ b/tests/pytorch/test_readout.py
@@ -1,5 +1,6 @@
 import torch as th
 import dgl
+import utils as U
 
 def test_simple_readout():
     g1 = dgl.DGLGraph()
@@ -29,26 +30,26 @@ def test_simple_readout():
     g2.ndata['w'] = w2
     g1.edata['x'] = e1
 
-    assert th.allclose(dgl.sum_nodes(g1, 'x'), s1)
-    assert th.allclose(dgl.sum_nodes(g1, 'x', 'w'), ws1)
-    assert th.allclose(dgl.sum_edges(g1, 'x'), se1)
-    assert th.allclose(dgl.mean_nodes(g1, 'x'), m1)
-    assert th.allclose(dgl.mean_nodes(g1, 'x', 'w'), wm1)
-    assert th.allclose(dgl.mean_edges(g1, 'x'), me1)
+    assert U.allclose(dgl.sum_nodes(g1, 'x'), s1)
+    assert U.allclose(dgl.sum_nodes(g1, 'x', 'w'), ws1)
+    assert U.allclose(dgl.sum_edges(g1, 'x'), se1)
+    assert U.allclose(dgl.mean_nodes(g1, 'x'), m1)
+    assert U.allclose(dgl.mean_nodes(g1, 'x', 'w'), wm1)
+    assert U.allclose(dgl.mean_edges(g1, 'x'), me1)
 
     g = dgl.batch([g1, g2])
     s = dgl.sum_nodes(g, 'x')
     m = dgl.mean_nodes(g, 'x')
-    assert th.allclose(s, th.stack([s1, s2], 0))
-    assert th.allclose(m, th.stack([m1, m2], 0))
+    assert U.allclose(s, th.stack([s1, s2], 0))
+    assert U.allclose(m, th.stack([m1, m2], 0))
     ws = dgl.sum_nodes(g, 'x', 'w')
     wm = dgl.mean_nodes(g, 'x', 'w')
-    assert th.allclose(ws, th.stack([ws1, ws2], 0))
-    assert th.allclose(wm, th.stack([wm1, wm2], 0))
+    assert U.allclose(ws, th.stack([ws1, ws2], 0))
+    assert U.allclose(wm, th.stack([wm1, wm2], 0))
     s = dgl.sum_edges(g, 'x')
     m = dgl.mean_edges(g, 'x')
-    assert th.allclose(s, th.stack([se1, th.zeros(5)], 0))
-    assert th.allclose(m, th.stack([me1, th.zeros(5)], 0))
+    assert U.allclose(s, th.stack([se1, th.zeros(5)], 0))
+    assert U.allclose(m, th.stack([me1, th.zeros(5)], 0))
 
 if __name__ == '__main__':
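
Finally, a short sketch of the new three-argument initializer contract `(shape, dtype, ctx)` that this patch threads through `Frame.add_column`/`add_rows` and the `set_n_initializer`/`set_e_initializer` docs. This is a hypothetical example in the spirit of the updated tests/pytorch/test_basics.py; the `add_nodes` calls and the feature name `'h'` are illustrative assumptions, and the PyTorch backend is assumed.

```python
# Hypothetical sketch of a device-aware initializer (PyTorch backend assumed;
# not part of the patch).
import torch as th
import dgl

g = dgl.DGLGraph()
g.add_nodes(3)
# the initializer now also receives the target context, so placeholder
# rows are allocated directly on the device of the existing column
g.set_n_initializer(lambda shape, dtype, ctx: th.zeros(shape, dtype=dtype, device=ctx))
g.ndata['h'] = th.ones(3, 2)
g.add_nodes(2)  # the two new rows of 'h' are zero-filled by the initializer
assert g.ndata['h'].shape == (5, 2)
```

Before this change the initializer only saw `(shape, dtype)`, so `Frame.add_column` had to allocate on the default device and then `F.copy_to` the result; passing `ctx` through removes that extra copy.
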