Revert "Refactor code for retaining formats in message-passing. (dmlc…
Browse files Browse the repository at this point in the history
…#2570)" (dmlc#2583)

This reverts commit a613ad8.
jermainewang authored Jan 28, 2021
1 parent 7bab136 commit 878acdb
Showing 5 changed files with 39 additions and 47 deletions.
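
For orientation, the change being reverted kept a graph's allowed sparse formats when reversing it for the backward pass of message passing: reversing edge direction turns CSR into CSC and vice versa, while COO maps to itself. A minimal user-level sketch of that mapping, assuming only the public dgl.graph, dgl.reverse and DGLGraph.formats API (the example graph, variable names and printed output are illustrative, not part of this commit):

import dgl

# Reversing a graph swaps the roles of rows and columns in its adjacency,
# so a CSR-only graph is naturally available as CSC after reversal.
_inverse_format = {'coo': 'coo', 'csr': 'csc', 'csc': 'csr'}

g = dgl.graph(([0, 1, 2], [1, 2, 3])).formats('csr')   # restrict to CSR
allowed = g.formats()['created'] + g.formats()['not created']

# Ask the reversed graph for the counterpart formats instead of letting a
# new format be created (and cached) on the fly during the backward pass.
rg = dgl.reverse(g).formats([_inverse_format[f] for f in allowed])
print(rg.formats())   # expected to list only 'csc' as allowed
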
python/dgl/backend/mxnet/sparse.py (5 additions, 5 deletions)
@@ -1,7 +1,7 @@
 import mxnet as mx
 import numpy as np
 from mxnet import nd
-from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _reverse
+from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp
 from ...base import dgl_warning, is_all, ALL
 from .tensor import asnumpy, copy_to, zerocopy_from_numpy, context, to_backend_ctx

@@ -132,7 +132,7 @@ def backward(self, dZ):
         X, Y, argX, argY = self.saved_tensors
         gidx, op, reduce_op = self.gidx, self.op, self.reduce_op
         if op != 'copy_rhs':
-            g_rev = _reverse(gidx)
+            g_rev = gidx.reverse()
             if reduce_op == 'sum':
                 if op in ['mul', 'div']:
                     dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0]
@@ -215,7 +215,7 @@ def backward(self, dZ):
         lhs_target, rhs_target = self.lhs_target, self.rhs_target
         if op != 'copy_rhs':
             if lhs_target in ['u', 'v']:
-                _gidx = gidx if self.lhs_target == 'v' else _reverse(gidx)
+                _gidx = gidx if self.lhs_target == 'v' else gidx.reverse()
                 if op in ['add', 'sub', 'copy_lhs']:
                     dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0]
                 else: # mul, div, dot
@@ -235,7 +235,7 @@ def backward(self, dZ):
             dX = nd.zeros_like(X)
         if op != 'copy_lhs':
             if self.rhs_target in ['u', 'v']:
-                _gidx = gidx if rhs_target == 'v' else _reverse(gidx)
+                _gidx = gidx if rhs_target == 'v' else gidx.reverse()
                 if op in ['add', 'sub', 'copy_rhs']:
                     dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))[0]
                 else: # mul, div, dot
@@ -277,7 +277,7 @@ def __init__(self, gidx, eids, norm_by):
         if not is_all(eids):
             gidx = gidx.edge_subgraph([eids], True).graph
         if norm_by == 'src':
-            gidx = _reverse(gidx)
+            gidx = gidx.reverse()
         self.gidx = gidx

     def forward(self, score):
python/dgl/backend/pytorch/sparse.py (29 additions, 1 deletion)
@@ -1,7 +1,7 @@
 import torch as th
 from distutils.version import LooseVersion
 from ...base import is_all, ALL
-from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _reverse
+from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp

 if LooseVersion(th.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import custom_fwd, custom_bwd
@@ -27,6 +27,34 @@ def decorate_bwd(*args, **kwargs):
 __all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce']


+_inverse_format = {
+    'coo': 'coo',
+    'csr': 'csc',
+    'csc': 'csr'
+}
+
+
+def _reverse(gidx):
+    """Reverse the given graph index while retaining its formats.
+    Parameters
+    ----------
+    gidx: HeteroGraphIndex
+    Return
+    ------
+    HeteroGraphIndex
+    """
+
+
+    g_rev = gidx.reverse()
+    original_formats_dict = gidx.formats()
+    original_formats = original_formats_dict['created'] +\
+        original_formats_dict['not created']
+    g_rev = g_rev.formats([_inverse_format[fmt] for fmt in original_formats])
+    return g_rev
+
+
 def _reduce_grad(grad, shape):
     """Reduce gradient on the broadcast dimension
     If there is broadcast in forward pass, gradients need to be reduced on
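
Since the revert puts _reverse back in the PyTorch backend, the format-preserving behaviour can be exercised against it directly. A rough sketch, assuming the PyTorch backend is active; importing the backend module and reaching the HeteroGraphIndex through the private DGLGraph._graph attribute is purely for illustration:

import dgl
from dgl.backend.pytorch.sparse import _reverse  # helper restored by this commit

g = dgl.graph(([0, 1, 2], [1, 2, 3])).formats('csr')
gidx = g._graph              # the HeteroGraphIndex the backend kernels receive

rev = _reverse(gidx)         # reversed index, restricted to the CSC counterpart
print(rev.formats())         # expected to report 'csc' rather than 'csr'
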
python/dgl/backend/tensorflow/sparse.py (5 additions, 5 deletions)
@@ -2,7 +2,7 @@
 import numpy as np
 from .tensor import tensor, copy_to, context, asnumpy, zerocopy_from_numpy
 from ...base import is_all, ALL
-from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp, _reverse
+from ...sparse import _gspmm, _gsddmm, _segment_reduce, _bwd_segment_cmp

 __all__ = ['gspmm', 'gsddmm', 'edge_softmax', 'segment_reduce']

@@ -110,7 +110,7 @@ def gspmm_real(gidx, op, reduce_op, X, Y):
     def grad(dZ):
         dZ = tensor(dZ)
         if op != 'copy_rhs':
-            g_rev = _reverse(gidx)
+            g_rev = gidx.reverse()
             if reduce_op == 'sum':
                 if op in ['mul', 'div']:
                     dX = _gspmm(g_rev, 'mul', 'sum', dZ, _muldiv(op, Y))[0]
@@ -172,7 +172,7 @@ def gsddmm_real(gidx, op, X, Y, lhs_target, rhs_target):
     def grad(dZ):
         if op != 'copy_rhs':
             if lhs_target in ['u', 'v']:
-                _gidx = gidx if lhs_target == 'v' else _reverse(gidx)
+                _gidx = gidx if lhs_target == 'v' else gidx.reverse()
                 if op in ['add', 'sub', 'copy_lhs']:
                     dX = _gspmm(_gidx, 'copy_rhs', 'sum', None, dZ)[0]
                 else: # mul, div, dot
@@ -192,7 +192,7 @@ def grad(dZ):
             dX = tf.zeros_like(X)
         if op != 'copy_lhs':
             if rhs_target in ['u', 'v']:
-                _gidx = gidx if rhs_target == 'v' else _reverse(gidx)
+                _gidx = gidx if rhs_target == 'v' else gidx.reverse()
                 if op in ['add', 'sub', 'copy_rhs']:
                     dY = _gspmm(_gidx, 'copy_rhs', 'sum', None, _addsub(op, dZ))[0]
                 else: # mul, div, dot
@@ -233,7 +233,7 @@ def edge_softmax_real(gidx, score, eids=ALL, norm_by='dst'):
     if not is_all(eids):
         gidx = gidx.edge_subgraph([eids], True).graph
     if norm_by == 'src':
-        gidx = _reverse(gidx)
+        gidx = gidx.reverse()
     score_max = _gspmm(gidx, 'copy_rhs', 'max', None, score)[0]
     score = tf.math.exp(_gsddmm(gidx, 'sub', score, score_max, 'e', 'v'))
     score_sum = _gspmm(gidx, 'copy_rhs', 'sum', None, score)[0]
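
The edge_softmax_real hunk above is the familiar numerically stable softmax over edges: subtract each destination node's maximum incoming score, exponentiate, then normalize by the per-node sum. A short user-level sketch, assuming the PyTorch backend and the edge_softmax exposed under dgl.nn.functional (graph and score shapes are illustrative):

import dgl
import torch
from dgl.nn.functional import edge_softmax

g = dgl.graph(([0, 0, 1], [1, 2, 2]))
scores = torch.randn(g.number_of_edges(), 1)

# Normalizes scores over the incoming edges of each destination node
# (norm_by='dst'); with norm_by='src' the backend reverses the graph index
# first, which is the gidx.reverse() call restored in this commit.
alpha = edge_softmax(g, scores)
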
python/dgl/heterograph.py (0 additions, 7 deletions)
@@ -5449,13 +5449,6 @@ def formats(self, formats=None):
         >>> # Only allowed formats will be displayed in the status query
         >>> csr_g.formats()
         {'created': ['csr'], 'not created': []}
-        Notes
-        -----
-        DGL will create sparse formats (only constrained to the allowed formats, i.e.
-        created formats and not created formats) on-the-fly during the training of Graph
-        Neural Networks. Once a format was created, it would be cached and reused until
-        user changes the graph structure.
         """
         if formats is None:
             # Return the format information
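
The Notes block deleted here described DGL's lazy format handling: formats within the allowed set are created on the fly when an operator needs them and are then cached until the graph structure changes. A small sketch of that behaviour, assuming the PyTorch backend (the graph, feature sizes and exact output are illustrative):

import dgl
import torch

g = dgl.graph(([0, 1, 2], [1, 2, 3]))    # COO is materialized on construction
g.ndata['h'] = torch.ones(4, 2)
print(g.formats())                        # e.g. 'csr'/'csc' under 'not created'

# Message aggregation uses SpMM, which typically needs CSC; DGL creates it on
# the fly and caches it for later calls.
g.update_all(dgl.function.copy_u('h', 'm'), dgl.function.sum('m', 'hn'))
print(g.formats())                        # 'csc' expected to move to 'created'
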
python/dgl/sparse.py (0 additions, 29 deletions)
@@ -7,8 +7,6 @@
 from .base import DGLError
 from . import backend as F

-__all__ = ['_gspmm', '_gsddmm', '_segment_reduce', '_bwd_segment_cmp', '_reverse']
-

 def infer_broadcast_shape(op, shp1, shp2):
     r"""Check the shape validity, and infer the output shape given input shape and operator.
@@ -67,33 +65,6 @@ def to_dgl_nd_for_write(x):
     return nd.NULL['int64'] if x is None else F.zerocopy_to_dgl_ndarray_for_write(x)


-inverse_format = {
-    'coo': 'coo',
-    'csr': 'csc',
-    'csc': 'csr'
-}
-
-
-def _reverse(gidx):
-    """Reverse the given graph index while retaining its formats.
-    ``dgl.reverse`` would not keep graph format information by default.
-    Parameters
-    ----------
-    gidx: HeteroGraphIndex
-    Return
-    ------
-    HeteroGraphIndex
-    """
-    g_rev = gidx.reverse()
-    original_formats_dict = gidx.formats()
-    original_formats = original_formats_dict['created'] +\
-        original_formats_dict['not created']
-    g_rev = g_rev.formats([inverse_format[fmt] for fmt in original_formats])
-    return g_rev
-
-
 target_mapping = {
     'u': 0,
     'e': 1,
