Skip to content

Commit

Permalink
Spmv partial (dmlc#43)
Browse files Browse the repository at this point in the history
* partial spmv impl and test

* some fix for update edge
  • Loading branch information
jermainewang authored Aug 13, 2018
1 parent ee24169 commit e3bac70
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 51 deletions.
9 changes: 5 additions & 4 deletions python/dgl/backend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
sum = th.sum
max = th.max

def astype(a, ty):
    """Cast tensor ``a`` to the dtype ``ty`` via ``torch.Tensor.type``."""
    casted = a.type(ty)
    return casted

def asnumpy(a):
    """Return tensor ``a`` as a numpy array, moving it to host memory first."""
    host = a.cpu()
    return host.numpy()

Expand Down Expand Up @@ -50,16 +53,14 @@ def broadcast_to(x, to_array):
return x + th.zeros_like(to_array)

nonzero = th.nonzero

def eq_scalar(x, val):
    """Elementwise-compare tensor ``x`` against scalar ``val``.

    ``val`` is coerced to float before the comparison, matching torch's
    scalar-comparison semantics.
    """
    scalar = float(val)
    return th.eq(x, scalar)

# Re-export common tensor ops straight from torch so the backend exposes a
# uniform functional namespace to the rest of dgl.
squeeze = th.squeeze
unsqueeze = th.unsqueeze
reshape = th.reshape
zeros = th.zeros
ones = th.ones
spmm = th.spmm  # sparse @ dense matmul; backbone of the SPMV update path
sort = th.sort
arange = th.arange

def to_context(x, ctx):
if ctx is None:
Expand Down
126 changes: 86 additions & 40 deletions python/dgl/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,24 +436,32 @@ def sendto(self, u, v, message_func=None, batchable=False):

def _nonbatch_sendto(self, u, v, message_func):
    """Trigger messages along the given edges (non-batched path).

    For each edge (src, dst), the message UDF is applied to the source
    node's representation and the edge's representation; the result is
    stored on the edge under the message key.
    """
    msg_fn = _get_message_func(message_func)
    # ALL/ALL means "send along every edge of the graph".
    if is_all(u) and is_all(v):
        u, v = self.cached_graph.edges()
    for src, dst in utils.edge_iter(u, v):
        msg = msg_fn(_get_repr(self.nodes[src]),
                     _get_repr(self.edges[src, dst]))
        self.edges[src, dst][__MSG__] = msg

def _batch_sendto(self, u, v, message_func):
f_msg = _get_message_func(message_func)
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
eid = self.cached_graph.get_edge_id(u, v)
self.msg_graph.add_edges(u, v)
if len(u) != len(v) and len(u) == 1:
u = F.broadcast_to(u, v)
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr_by_id(eid)
msgs = message_func(src_reprs, edge_reprs)
self.msg_graph.add_edges(u, v)
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr()
msgs = message_func(src_reprs, edge_reprs)
else:
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
u, v = utils.edge_broadcasting(u, v)
eid = self.cached_graph.get_edge_id(u, v)
self.msg_graph.add_edges(u, v)
# call UDF
src_reprs = self.get_n_repr(u)
edge_reprs = self.get_e_repr_by_id(eid)
msgs = message_func(src_reprs, edge_reprs)
if isinstance(msgs, dict):
self._msg_frame.append(msgs)
else:
Expand Down Expand Up @@ -490,26 +498,34 @@ def update_edge(self, u, v, edge_func=None, batchable=False):
self._nonbatch_update_edge(u, v, edge_func)

def _nonbatch_update_edge(self, u, v, edge_func):
    """Recompute edge representations with ``edge_func`` (non-batched path).

    The UDF receives (src repr, dst repr, edge repr) for each edge and its
    return value replaces that edge's representation.
    """
    # ALL/ALL means "update every edge of the graph".
    if is_all(u) and is_all(v):
        u, v = self.cached_graph.edges()
    for src, dst in utils.edge_iter(u, v):
        new_repr = edge_func(_get_repr(self.nodes[src]),
                             _get_repr(self.nodes[dst]),
                             _get_repr(self.edges[src, dst]))
        _set_repr(self.edges[src, dst], new_repr)

def _batch_update_edge(self, u, v, edge_func):
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
eid = self.cached_graph.get_edge_id(u, v)
if len(u) != len(v) and len(u) == 1:
u = F.broadcast_to(u, v)
elif len(u) != len(v) and len(v) == 1:
v = F.broadcast_to(v, u)
# call the UDF
src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v)
edge_reprs = self.get_e_repr_by_id(eid)
new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs)
self.set_e_repr_by_id(new_edge_reprs, eid)
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
# call the UDF
src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v)
edge_reprs = self.get_e_repr()
new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs)
self.set_e_repr(new_edge_reprs)
else:
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
u, v = utils.edge_broadcasting(u, v)
eid = self.cached_graph.get_edge_id(u, v)
# call the UDF
src_reprs = self.get_n_repr(u)
dst_reprs = self.get_n_repr(v)
edge_reprs = self.get_e_repr_by_id(eid)
new_edge_reprs = edge_func(src_reprs, dst_reprs, edge_reprs)
self.set_e_repr_by_id(new_edge_reprs, eid)

def recv(self,
u,
Expand Down Expand Up @@ -566,6 +582,8 @@ def recv(self,
def _nonbatch_recv(self, u, reduce_func, update_func):
f_reduce = _get_reduce_func(reduce_func)
f_update = update_func
if is_all(u):
u = list(range(0, self.number_of_nodes()))
for i, uu in enumerate(utils.node_iter(u)):
# reduce phase
msgs_batch = [self.edges[vv, uu].pop(__MSG__)
Expand Down Expand Up @@ -702,6 +720,8 @@ def _nonbatch_update_by_edge(
message_func,
reduce_func,
update_func):
if is_all(u) and is_all(v):
u, v = self.cached_graph.edges()
self._nonbatch_sendto(u, v, message_func)
dst = set()
for uu, vv in utils.edge_iter(u, v):
Expand All @@ -714,26 +734,39 @@ def _batch_update_by_edge(
message_func,
reduce_func,
update_func):
if message_func == 'from_src' and reduce_func == 'sum' \
and is_all(u) and is_all(v):
# TODO(minjie): SPMV is only supported for updating all nodes right now.
adjmat = self.cached_graph.adjmat(self.context)
if is_all(u) and is_all(v):
self.update_all(message_func, reduce_func, update_func, True)
elif message_func == 'from_src' and reduce_func == 'sum':
# TODO(minjie): check the validity of edges u->v
u = utils.convert_to_id_tensor(u)
v = utils.convert_to_id_tensor(v)
# TODO(minjie): broadcasting is optional for many-one input.
u, v = utils.edge_broadcasting(u, v)
# relabel destination nodes.
new2old, old2new = utils.build_relabel_map(v)
# TODO(minjie): should not directly use []
new_v = old2new[v]
# create adj mat
idx = F.pack([F.unsqueeze(new_v, 0), F.unsqueeze(u, 0)])
dat = F.ones((len(u),))
n = self.number_of_nodes()
m = len(new2old)
adjmat = F.sparse_tensor(idx, dat, [m, n])
adjmat = F.to_context(adjmat, self.context)
# TODO(minjie): use lazy dict for reduced_msgs
reduced_msgs = {}
for key in self._node_frame.schemes:
col = self._node_frame[key]
reduced_msgs[key] = F.spmm(adjmat, col)
node_repr = self.get_n_repr()
if len(reduced_msgs) == 1 and __REPR__ in reduced_msgs:
reduced_msgs = reduced_msgs[__REPR__]
self.set_n_repr(update_func(node_repr, reduced_msgs))
node_repr = self.get_n_repr(new2old)
new_node_repr = update_func(node_repr, reduced_msgs)
self.set_n_repr(new_node_repr, new2old)
else:
if is_all(u) and is_all(v):
self._batch_sendto(u, v, message_func)
self._batch_recv(v, reduce_func, update_func)
else:
self._batch_sendto(u, v, message_func)
unique_v = F.unique(v)
self._batch_recv(unique_v, reduce_func, update_func)
self._batch_sendto(u, v, message_func)
unique_v = F.unique(v)
self._batch_recv(unique_v, reduce_func, update_func)

def update_to(self,
v,
Expand Down Expand Up @@ -845,11 +878,24 @@ def update_all(self,
assert reduce_func is not None
assert update_func is not None
if batchable:
self._batch_update_by_edge(ALL, ALL,
message_func, reduce_func, update_func)
if message_func == 'from_src' and reduce_func == 'sum':
# TODO(minjie): use lazy dict for reduced_msgs
adjmat = self.cached_graph.adjmat(self.context)
reduced_msgs = {}
for key in self._node_frame.schemes:
col = self._node_frame[key]
reduced_msgs[key] = F.spmm(adjmat, col)
if len(reduced_msgs) == 1 and __REPR__ in reduced_msgs:
reduced_msgs = reduced_msgs[__REPR__]
node_repr = self.get_n_repr()
self.set_n_repr(update_func(node_repr, reduced_msgs))
else:
self._batch_sendto(ALL, ALL, message_func)
self._batch_recv(ALL, reduce_func, update_func)
else:
u = [uu for uu, _ in self.edges]
v = [vv for _, vv in self.edges]
u, v = zip(*self.edges)
u = list(u)
v = list(v)
self._nonbatch_sendto(u, v, message_func)
self._nonbatch_recv(list(self.nodes()), reduce_func, update_func)

Expand Down
41 changes: 41 additions & 0 deletions python/dgl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,21 @@
from dgl.backend import Tensor, SparseTensor

def is_id_tensor(u):
    """Return True iff ``u`` is a supported id tensor (1-D integer tensor)."""
    if not isinstance(u, Tensor):
        return False
    # Ids must be integral and flat (rank 1).
    return F.isinteger(u) and len(F.shape(u)) == 1

def is_id_container(u):
    """Return whether the input is a supported id container.

    Only plain Python lists (and list subclasses) are accepted.
    """
    return isinstance(u, list)

def node_iter(n):
    """Yield node ids one by one from ``n`` (any accepted id format)."""
    yield from convert_to_id_container(n)

def edge_iter(u, v):
"""Return an iterator that loops over the given edges."""
u = convert_to_id_container(u)
v = convert_to_id_container(v)
if len(u) == len(v):
Expand All @@ -35,6 +39,7 @@ def edge_iter(u, v):
raise ValueError('Error edges:', u, v)

def convert_to_id_container(x):
"""Convert the input to id container."""
if is_id_container(x):
return x
elif is_id_tensor(x):
Expand All @@ -47,6 +52,7 @@ def convert_to_id_container(x):
return None

def convert_to_id_tensor(x, ctx=None):
"""Convert the input to id tensor."""
if is_id_container(x):
ret = F.tensor(x, dtype=F.int64)
elif is_id_tensor(x):
Expand Down Expand Up @@ -81,3 +87,38 @@ def __iter__(self):

def __len__(self):
    # Number of entries held; delegates to the underlying key list.
    return len(self._keys)

def build_relabel_map(x):
    """Relabel the input ids to consecutive ids starting from zero.

    Parameters
    ----------
    x : int, tensor or container
        The input ids.

    Returns
    -------
    new_to_old : tensor
        Mapping from new id to old id (the sorted unique old ids).
    old_to_new : tensor
        Mapping from old id to new id; a vector of length max(x) + 1.
        Advanced indexing converts an old id tensor to a new id tensor:
        ``new_id = old_to_new[old_id]``.
    """
    ids = convert_to_id_tensor(x)
    # Sorted unique old ids double as the new->old map: position == new id.
    new_to_old, _ = F.sort(F.unique(ids))
    table_len = int(F.max(new_to_old)) + 1
    old_to_new = F.zeros(table_len, dtype=F.int64)
    new_ids = F.astype(F.arange(len(new_to_old)), F.int64)
    # TODO(minjie): should not directly use []
    old_to_new[new_to_old] = new_ids
    return new_to_old, old_to_new

def edge_broadcasting(u, v):
    """Expand one-many / many-one edge specs into explicit many-many pairs.

    When one endpoint is a single id and the other has length n, the single
    id is broadcast to length n. Equal-length inputs pass through unchanged;
    any other length mismatch trips the assertion.
    """
    n_u, n_v = len(u), len(v)
    if n_u == n_v:
        return u, v
    if n_u == 1:
        u = F.broadcast_to(u, v)
    elif n_v == 1:
        v = F.broadcast_to(v, u)
    else:
        assert len(u) == len(v)
    return u, v
20 changes: 13 additions & 7 deletions tests/pytorch/test_specialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,22 @@ def generate_graph():

def test_spmv_specialize():
g = generate_graph()
g.register_message_func('from_src', batchable=True)
g.register_reduce_func('sum', batchable=True)
g.register_update_func(update_func, batchable=True)
# update all
v1 = g.get_n_repr()
g.update_all()
g.update_all('from_src', 'sum', update_func, batchable=True)
v2 = g.get_n_repr()
g.set_n_repr(v1)
g.register_message_func(message_func, batchable=True)
g.register_reduce_func(reduce_func, batchable=True)
g.update_all()
g.update_all(message_func, reduce_func, update_func, batchable=True)
v3 = g.get_n_repr()
check_eq(v2, v3)
# partial update
u = th.tensor([0, 0, 0, 3, 4, 9])
v = th.tensor([1, 2, 3, 9, 9, 0])
v1 = g.get_n_repr()
g.update_by_edge(u, v, 'from_src', 'sum', update_func, batchable=True)
v2 = g.get_n_repr()
g.set_n_repr(v1)
g.update_by_edge(u, v, message_func, reduce_func, update_func, batchable=True)
v3 = g.get_n_repr()
check_eq(v2, v3)

Expand Down

0 comments on commit e3bac70

Please sign in to comment.