Skip to content

Commit

Permalink
backward now stores DGLGraph index, not DGLGraph object with attached da…
Browse files Browse the repository at this point in the history
…ta (dmlc#3410)

Co-authored-by: Israt Nisa <[email protected]>
  • Loading branch information
isratnisa and Israt Nisa authored Oct 11, 2021
1 parent aef96df commit 532eaa8
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 52 deletions.
58 changes: 28 additions & 30 deletions python/dgl/backend/pytorch/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from distutils.version import LooseVersion
from ...base import is_all, ALL
from ...sparse import _gspmm, _gspmm_hetero, _gsddmm, _gsddmm_hetero, _segment_reduce, _bwd_segment_cmp, _scatter_add
from ...sparse import _csrmm, _csrsum, _csrmask, get_typeid_by_target
from ...sparse import _csrmm, _csrsum, _csrmask
from ...heterograph_index import create_unitgraph_from_csr

if LooseVersion(th.__version__) >= LooseVersion("1.6.0"):
Expand Down Expand Up @@ -192,20 +192,20 @@ def backward(ctx, dZ):
class GSpMM_hetero(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, g, op, reduce_op, X_len, *feats): # feats = lhs_data + rhs_data
out, (argX, argY) = _gspmm_hetero(g, op, reduce_op, X_len, feats)
def forward(ctx, gidx, op, reduce_op, X_len, *feats): # feats = lhs_data + rhs_data
out, (argX, argY) = _gspmm_hetero(gidx, op, reduce_op, X_len, feats)
X, Y = feats[:X_len], feats[X_len:]
# TODO (Israt): check target to decide src_id/dst_id?
# checking the first relation to decide for all the relations
src_id, dst_id = g._graph.metagraph.find_edge(0)
src_id, dst_id = gidx.metagraph.find_edge(0)
reduce_last = _need_reduce_last_dim(X[src_id], Y[dst_id])
X_shape = tuple([X[i].shape if X[i] is not None else None
for i in range(X_len)])
Y_shape = tuple([Y[i].shape if Y[i] is not None else None
for i in range(len(Y))])
dtype = X[src_id].dtype if X[src_id] is not None else Y[dst_id].dtype
device = X[src_id].device if X[src_id] is not None else Y[dst_id].device
ctx.backward_cache = g, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len
ctx.backward_cache = gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len
req_grad_X = tuple([X[i].requires_grad if X[i] is not None else False
for i in range(X_len)])
req_grad_Y = tuple([Y[i].requires_grad if Y[i] is not None else False
Expand All @@ -223,14 +223,14 @@ def forward(ctx, g, op, reduce_op, X_len, *feats): # feats = lhs_data + rhs_data
@staticmethod
@custom_bwd
def backward(ctx, *dZ):
g, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len = ctx.backward_cache
gidx, op, reduce_op, X_shape, Y_shape, dtype, device, reduce_last, X_len = ctx.backward_cache
feats = ctx.saved_tensors[:-2]
argX = ctx.saved_tensors[-2]
argY = ctx.saved_tensors[-1]
X, Y = feats[:X_len], feats[X_len:]

if op != 'copy_rhs' and any([x is not None for x in X]):
g_rev = g.reverse()
g_rev = gidx.reverse()
# TODO(Israt): implement other combinations of message and reduce functions
if reduce_op == 'sum':
if op == 'mul':
Expand All @@ -251,11 +251,11 @@ def backward(ctx, *dZ):
for i in range(len(dZ))])
tpl_X_dZ = tuple(X + tpl_dZ)
if op == 'mul' and reduce_last:
dY = gsddmm_hetero(g, 'dot', X_len, 'u', 'v', *tpl_X_dZ)
dY = gsddmm_hetero(gidx, 'dot', X_len, 'u', 'v', *tpl_X_dZ)
elif op == 'mul':
dY = gsddmm_hetero(g, 'mul', X_len, 'u', 'v', *tpl_X_dZ)
dY = gsddmm_hetero(gidx, 'mul', X_len, 'u', 'v', *tpl_X_dZ)
elif op in ['add', 'copy_rhs']:
dY = gsddmm_hetero(g, 'copy_rhs', X_len, 'u', 'v', *tpl_X_dZ)
dY = gsddmm_hetero(gidx, 'copy_rhs', X_len, 'u', 'v', *tpl_X_dZ)
dY = tuple([_reduce_grad(dY[i], Y_shape[i]) if Y[i] is not None else None
for i in range(len(Y))])
else: # Y has no gradient
Expand Down Expand Up @@ -345,77 +345,75 @@ def backward(ctx, dZ):
class GSDDMM_hetero(th.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=th.float16)
def forward(ctx, g, op, X_len, lhs_target, rhs_target, *feats): # feats = X+Y
out = _gsddmm_hetero(g, op, X_len, lhs_target, rhs_target, feats)
def forward(ctx, gidx, op, X_len, lhs_target, rhs_target, *feats): # feats = X+Y
out = _gsddmm_hetero(gidx, op, X_len, lhs_target, rhs_target, feats)
X, Y = feats[:X_len], feats[X_len:]
X_shape = tuple([X[i].shape if X[i] is not None else None
for i in range(len(X))])
Y_shape = tuple([Y[i].shape if Y[i] is not None else None
for i in range(len(Y))])
ctx.backward_cache = g, op, lhs_target, rhs_target, X_shape, Y_shape, X_len
ctx.backward_cache = gidx, op, lhs_target, rhs_target, X_shape, Y_shape, X_len
req_grad_X = tuple([X[i].requires_grad if X[i] is not None else False
for i in range(len(X))])
req_grad_Y = tuple([Y[i].requires_grad if Y[i] is not None else False
for i in range(len(Y))])
lhs_id = get_typeid_by_target(g, g.canonical_etypes[0], lhs_target)
rhs_id = get_typeid_by_target(g, g.canonical_etypes[0], rhs_target)
ctx.save_for_backward(*feats)
return out

@staticmethod
@custom_bwd
# TODO(Israt): Implement the complete backward operator
def backward(ctx, *dZ):
g, op, lhs_target, rhs_target, X_shape, Y_shape, X_len = ctx.backward_cache
gidx, op, lhs_target, rhs_target, X_shape, Y_shape, X_len = ctx.backward_cache
feats = ctx.saved_tensors
X, Y = feats[:X_len], feats[X_len:]
if op != 'copy_rhs' and any([x is not None for x in X]):
if lhs_target in ['u', 'v']:
_g = g if lhs_target == 'v' else g.reverse()
_gidx = gidx if lhs_target == 'v' else gidx.reverse()
tpl_of_None = tuple([None] * len(X))
if op in ['add', 'copy_lhs']:
dX = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ)))
dX = gspmm_hetero(_gidx, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ)))
else: # mul, dot
if rhs_target == lhs_target:
dX = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ))) * Y
dX = gspmm_hetero(_gidx, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ))) * Y
elif rhs_target == 'e':
dZ_mul_Y = tuple([dZ[i] * Y[i] if dZ[i] is not None else None
for i in range(len(Y))])
dX = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ_mul_Y)))
dX = gspmm_hetero(_gidx, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ_mul_Y)))
else: # rhs_target = !lhs_target
dX = gspmm_hetero(_g, 'mul', 'sum', len(X), *tuple(Y + dZ))
dX = gspmm_hetero(_gidx, 'mul', 'sum', len(X), *tuple(Y + dZ))
else: # lhs_target == 'e'
if op in ['add', 'copy_lhs']:
dX = dZ
else: # mul, dot
num_etype = g._graph.number_of_etypes()
dX = gsddmm_hetero(g, 'mul', num_etype, 'e', rhs_target, *tuple(dZ + Y))
num_etype = gidx.number_of_etypes()
dX = gsddmm_hetero(gidx, 'mul', num_etype, 'e', rhs_target, *tuple(dZ + Y))
dX = tuple([_reduce_grad(dX[i], X_shape[i]) if X[i] is not None else None
for i in range(len(X))])
else:
dX = tuple([None] * len(X))
if op != 'copy_lhs' and any([y is not None for y in Y]):
if rhs_target in ['u', 'v']:
_g = g if rhs_target == 'v' else g.reverse()
_gidx = gidx if rhs_target == 'v' else gidx.reverse()
tpl_of_None = tuple([None] * len(X))
if op in ['add', 'copy_rhs']:
dY = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ)))
dY = gspmm_hetero(_gidx, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ)))
else: # mul, dot
if lhs_target == rhs_target:
dY = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ))) * X
dY = gspmm_hetero(_gidx, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ))) * X
elif lhs_target == 'e':
dZ_mul_X = tuple([dZ[i] * X[i] if dZ[i] is not None else None
for i in range(len(X))])
dY = gspmm_hetero(_g, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ_mul_X)))
dY = gspmm_hetero(_gidx, 'copy_rhs', 'sum', len(X), *(tuple(tpl_of_None + dZ_mul_X)))
else: # rhs_target = !lhs_target
dY = gspmm_hetero(_g, 'mul', 'sum', len(X), *tuple(X + dZ))
dY = gspmm_hetero(_gidx, 'mul', 'sum', len(X), *tuple(X + dZ))
else:
if op in ['add', 'copy_rhs']:
dY = tuple([dZ[i] if dZ[i] is not None else None
for i in range(len(dZ))])
else: # mul, dot
num_etype = g._graph.number_of_etypes()
dY = gsddmm_hetero(g, 'mul', num_etype, 'e', lhs_target, *tuple(dZ + X))
num_etype = gidx.number_of_etypes()
dY = gsddmm_hetero(gidx, 'mul', num_etype, 'e', lhs_target, *tuple(dZ + X))
dY = tuple([_reduce_grad(dY[i], Y_shape[i]) if Y[i] is not None else None
for i in range(len(Y))])
else:
Expand Down
2 changes: 1 addition & 1 deletion python/dgl/ops/sddmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
# different dimensions, and different etypes may need different broadcasting
# dims for the same node.
lhs_and_rhs_tuple = tuple(list(lhs_data) + list(rhs_data))
return gsddmm_internal_hetero(g, op, len(lhs_data), lhs_target,
return gsddmm_internal_hetero(g._graph, op, len(lhs_data), lhs_target,
rhs_target, *lhs_and_rhs_tuple)

def _gen_sddmm_func(lhs_target, rhs_target, binary_op):
Expand Down
2 changes: 1 addition & 1 deletion python/dgl/ops/spmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def gspmm(g, op, reduce_op, lhs_data, rhs_data):
rhs_data = [None] * g._graph.number_of_etypes() if rhs_data is None else rhs_data
# TODO (Israt): Call reshape func
lhs_and_rhs_tuple = tuple(list(lhs_data) + list(rhs_data))
ret = gspmm_internal_hetero(g, op,
ret = gspmm_internal_hetero(g._graph, op,
'sum' if reduce_op == 'mean' else reduce_op,
len(lhs_data), *lhs_and_rhs_tuple)
# TODO (Israt): Add support for 'max', 'min', 'mean' in heterograph
Expand Down
33 changes: 13 additions & 20 deletions python/dgl/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,13 @@ def to_dgl_nd_for_write(x):
return nd.NULL['int64'] if x is None else F.zerocopy_to_dgl_ndarray_for_write(x)


def get_typeid_by_target(gidx, etid, target):
    """Find the src/dst/etype id based on the target 'u', 'v' or 'e'.

    Parameters
    ----------
    gidx : graph index object
        Graph index whose ``metagraph.find_edge(etid)`` returns the
        (source type id, destination type id) pair for an edge type.
        # NOTE(review): presumably a DGL HeteroGraphIndex — confirm at call sites.
    etid : int
        Edge type id of the relation being queried.
    target : int or str
        0 or 'u' selects the source node type id; 2 or 'v' selects the
        destination node type id; anything else (1 or 'e') selects the
        edge type itself.

    Returns
    -------
    int
        The requested type id: src type, dst type, or ``etid``.
    """
    src_id, dst_id = gidx.metagraph.find_edge(etid)
    if target in [0, 'u']:
        return src_id
    if target in [2, 'v']:
        return dst_id
    # Fall through: target refers to the edge ('e' / 1), so the edge
    # type id itself is the answer.
    return etid


Expand Down Expand Up @@ -190,11 +189,10 @@ def _gspmm(gidx, op, reduce_op, u, e):
return v, (arg_u, arg_e)


def _gspmm_hetero(g, op, reduce_op, u_len, u_and_e_tuple):
def _gspmm_hetero(gidx, op, reduce_op, u_len, u_and_e_tuple):
r""" Generalized Sparse Matrix Multiplication interface.
"""
u_tuple, e_tuple = u_and_e_tuple[:u_len], u_and_e_tuple[u_len:]
gidx = g._graph
use_u = op != 'copy_rhs'
use_e = op != 'copy_lhs'
# TODO (Israt): Add check - F.dtype(u) != F.dtype(e):
Expand All @@ -205,11 +203,8 @@ def _gspmm_hetero(g, op, reduce_op, u_len, u_and_e_tuple):
list_v = [None] * gidx.number_of_ntypes()
list_e = [None] * gidx.number_of_etypes()

for rel in g.canonical_etypes:
srctype, _, dsttype = rel
etid = g.get_etype_id(rel)
src_id = g.get_ntype_id(srctype)
dst_id = g.get_ntype_id(dsttype)
for etid in range(gidx.number_of_etypes()):
src_id, dst_id = gidx.metagraph.find_edge(etid)
u = u_tuple[src_id] if use_u else None
e = e_tuple[etid] if use_e else None
if use_u:
Expand Down Expand Up @@ -346,10 +341,9 @@ def _gsddmm(gidx, op, lhs, rhs, lhs_target='u', rhs_target='v'):
return out


def _gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rhs_tuple=None):
def _gsddmm_hetero(gidx, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rhs_tuple=None):
r""" Generalized Sampled-Dense-Dense Matrix Multiplication interface.
"""
gidx = g._graph
lhs_tuple, rhs_tuple = lhs_and_rhs_tuple[:lhs_len], lhs_and_rhs_tuple[lhs_len:]

use_lhs = op != 'copy_rhs'
Expand All @@ -358,19 +352,18 @@ def _gsddmm_hetero(g, op, lhs_len, lhs_target='u', rhs_target='v', lhs_and_rhs_t
# TODO (Israt): Add check - F.dtype(u) != F.dtype(e):
# deal with scalar features.
expand_lhs, expand_rhs = False, False
num_ntype = g._graph.number_of_ntypes()
num_etype = g._graph.number_of_etypes()
num_ntype = gidx.number_of_ntypes()
num_etype = gidx.number_of_etypes()
lhs_list = [None] * num_ntype if lhs_target in ['u', 'v'] else [None] * num_etype
rhs_list = [None] * num_ntype if rhs_target in ['u', 'v'] else [None] * num_etype
out_list = [None] * gidx.number_of_etypes()

lhs_target = target_mapping[lhs_target]
rhs_target = target_mapping[rhs_target]

for rel in g.canonical_etypes:
etid = g.get_etype_id(rel)
lhs_id = get_typeid_by_target(g, rel, lhs_target)
rhs_id = get_typeid_by_target(g, rel, rhs_target)
for etid in range(gidx.number_of_etypes()):
lhs_id = get_typeid_by_target(gidx, etid, lhs_target)
rhs_id = get_typeid_by_target(gidx, etid, rhs_target)
lhs = lhs_tuple[lhs_id]
rhs = rhs_tuple[rhs_id]
if use_lhs:
Expand Down

0 comments on commit 532eaa8

Please sign in to comment.