Skip to content

Commit

Permalink
[Feature] Reduce messages with scatter_add in PyTorch (dmlc#427)
Browse files Browse the repository at this point in the history
* implement pytorch spmm with gather and scatter add

* fix

* replace torch take with index_select

* comments

* comment about pytorch __getitem__ operator pitfall

* typo
  • Loading branch information
lingfanyu authored and jermainewang committed Mar 2, 2019
1 parent fe44ffe commit 3cc32a9
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions python/dgl/backend/pytorch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,17 @@ def zeros_like(input):
def ones(shape, dtype, ctx):
return th.ones(shape, dtype=dtype, device=ctx)

if TH_VERSION.version[0] == 0:
# TODO(minjie): note this does not support autograd on the `x` tensor.
# should adopt a workaround using custom op.
def spmm(x, y):
return th.spmm(x, y)
else:
# torch v1.0+
def spmm(x, y):
return th.sparse.mm(x, y)
def spmm(x, y):
dst, src = x._indices()
# scatter index
index = dst.view(-1, 1).expand(-1, y.shape[1])
# zero tensor to be scatter_add to
out = y.new_full((x.shape[0], y.shape[1]), 0)
# look up src features and multiply by edge features
# Note: using y[src] instead of index_select will lead to terrible
# performance in backward
feature = th.index_select(y, 0, src) * x._values().unsqueeze(-1)
return out.scatter_add(0, index, feature)

def unsorted_1d_segment_sum(input, seg_id, n_segs, dim):
y = th.zeros(n_segs, *input.shape[1:]).to(input)
Expand Down

0 comments on commit 3cc32a9

Please sign in to comment.