Skip to content

Commit

Permalink
[Feature] Update edge features (apply_edge) for heterogeneous graph (d…
Browse files Browse the repository at this point in the history
…mlc#3250)

* rgcn with new heterograph API

* apply_edge() forward for multi relation

* undoing changes from rgcn-hetero

* backward apply_edge(copy_u) added

* unittest for apply_edge(copy_e)

* Compatible with new PRs

* resolving conflict with master

* Bringing back change after resolving conflict

* minor

* minor

Co-authored-by: Israt Nisa <[email protected]>
  • Loading branch information
isratnisa and Israt Nisa authored Sep 9, 2021
1 parent 129f67c commit d17b299
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 4 deletions.
27 changes: 23 additions & 4 deletions python/dgl/heterograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -4410,9 +4410,17 @@ def apply_edges(self, func, edges=ALL, etype=None, inplace=False):
"""
if inplace:
raise DGLError('The `inplace` option is removed in v0.5.')
etid = self.get_etype_id(etype)
etype = self.canonical_etypes[etid]
g = self if etype is None else self[etype]
# Graph with one relation type
if self._graph.number_of_etypes() == 1 or etype is not None:
etid = self.get_etype_id(etype)
etype = self.canonical_etypes[etid]
g = self if etype is None else self[etype]
else: # heterogeneous graph with number of relation types > 1
if not core.is_builtin(func):
raise DGLError("User defined functions are not yet "
"supported in apply_edges for heterogeneous graphs. "
"Please use (apply_edges(func), etype = rel) instead.")
g = self
if is_all(edges):
eid = ALL
else:
Expand All @@ -4423,7 +4431,18 @@ def apply_edges(self, func, edges=ALL, etype=None, inplace=False):
edata = core.invoke_gsddmm(g, func)
else:
edata = core.invoke_edge_udf(g, eid, etype, func)
self._set_e_repr(etid, eid, edata)

if self._graph.number_of_etypes() == 1 or etype is not None:
self._set_e_repr(etid, eid, edata)
else:
edata_tensor = {}
key = list(edata.keys())[0]
out_tensor_tuples = edata[key]
for etid in range(self._graph.number_of_etypes()):
# TODO (Israt): Check the logic why some output tensor is None
if out_tensor_tuples[etid] is not None:
edata_tensor[key] = out_tensor_tuples[etid]
self._set_e_repr(etid, eid, edata_tensor)

def send_and_recv(self,
edges,
Expand Down
4 changes: 4 additions & 0 deletions python/dgl/ops/sddmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
return gsddmm_internal(
g._graph, op, lhs_data, rhs_data, lhs_target, rhs_target)
else:
if op == 'copy_lhs':
rhs_data = [None] * g._graph.number_of_etypes()
elif op == 'copy_rhs':
lhs_data = [None] * g._graph.number_of_ntypes()
# TODO (Israt): Call reshape_lhs_rhs() on lhs and rhs data to match their dimension
# and avoid broadcasting issue. Handle the case where different nodes have
# different dimensions, and different etypes may need different broadcasting
Expand Down
162 changes: 162 additions & 0 deletions tests/compute/test_apply_edges_hetero.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import dgl
import dgl.function as fn
from collections import Counter
import numpy as np
import scipy.sparse as ssp
import itertools
import backend as F
import networkx as nx
import unittest, pytest
from dgl import DGLError
import test_utils
from test_utils import parametrize_dtype, get_cases
from scipy.sparse import rand

rfuncs = {'sum': fn.sum, 'max': fn.max, 'min': fn.min, 'mean': fn.mean}
fill_value = {'sum': 0, 'max': float("-inf")}
feat_size = 2

@unittest.skipIf(dgl.backend.backend_name != 'pytorch', reason='Only support PyTorch for now')

def create_test_heterograph(idtype):
# test heterograph from the docstring, plus a user -- wishes -- game relation
# 3 users, 2 games, 2 developers
# metagraph:
# ('user', 'follows', 'user'),
# ('user', 'plays', 'game'),
# ('user', 'wishes', 'game'),
# ('developer', 'develops', 'game')])

g = dgl.heterograph({
('user', 'follows', 'user'): ([0, 1, 2, 1], [0, 0, 1, 1]),
('user', 'plays', 'game'): ([0, 1, 2, 1], [0, 0, 1, 1]),
('user', 'wishes', 'game'): ([0, 1, 1], [0, 0, 1]),
('developer', 'develops', 'game'): ([0, 1, 0], [0, 1, 1]),
}, idtype=idtype, device=F.ctx())
assert g.idtype == idtype
assert g.device == F.ctx()
return g


@parametrize_dtype
def test_unary_copy_u(idtype):
def _test(mfunc, rfunc):

g = create_test_heterograph(idtype)

x1 = F.randn((g.num_nodes('user'), feat_size))
x2 = F.randn((g.num_nodes('developer'), feat_size))

F.attach_grad(x1)
F.attach_grad(x2)
g.nodes['user'].data['h'] = x1
g.nodes['developer'].data['h'] = x2

#################################################################
# apply_edges() is called for each etype in a loop
#################################################################

with F.record_grad():
[g.apply_edges(fn.copy_u('h', 'm'), etype = rel)
for rel in g.canonical_etypes]
r1 = g['plays'].edata['m']
F.backward(r1, F.ones(r1.shape))
n_grad1 = F.grad(g.ndata['h']['user'])
# TODO (Israt): clear not working
g.edata['m'].clear()

#################################################################
# apply_edges() is called for all etypes at once
#################################################################

g.apply_edges(fn.copy_u('h', 'm'))
r2 = g['plays'].edata['m']
F.backward(r2, F.ones(r2.shape))
n_grad2 = F.grad(g.nodes['user'].data['h'])

# correctness check
def _print_error(a, b):
for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
if not np.allclose(x, y):
print('@{} {} v.s. {}'.format(i, x, y))

if not F.allclose(r1, r2):
_print_error(r1, r2)
assert F.allclose(r1, r2)
if not F.allclose(n_grad1, n_grad2):
print('node grad')
_print_error(n_grad1, n_grad2)
assert(F.allclose(n_grad1, n_grad2))

_test(fn.copy_u, fn.sum)
# TODO(Israt) :Add reduce func to suport the following reduce op
# _test('copy_u', 'max')
# _test('copy_u', 'min')
# _test('copy_u', 'mean')


@parametrize_dtype
def test_unary_copy_e(idtype):
def _test(mfunc, rfunc):

g = create_test_heterograph(idtype)
feat_size = 2

x1 = F.randn((4,feat_size))
x2 = F.randn((4,feat_size))
x3 = F.randn((3,feat_size))
x4 = F.randn((3,feat_size))
F.attach_grad(x1)
F.attach_grad(x2)
F.attach_grad(x3)
F.attach_grad(x4)
g['plays'].edata['eid'] = x1
g['follows'].edata['eid'] = x2
g['develops'].edata['eid'] = x3
g['wishes'].edata['eid'] = x4

#################################################################
# apply_edges() is called for each etype in a loop
#################################################################
with F.record_grad():
[g.apply_edges(fn.copy_e('eid', 'm'), etype = rel)
for rel in g.canonical_etypes]
r1 = g['develops'].edata['m']
F.backward(r1, F.ones(r1.shape))
e_grad1 = F.grad(g['develops'].edata['eid'])

#################################################################
# apply_edges() is called for all etypes at the same time
#################################################################

g.apply_edges(fn.copy_e('eid', 'm'))
r2 = g['develops'].edata['m']
F.backward(r2, F.ones(r2.shape))
e_grad2 = F.grad(g['develops'].edata['eid'])

# # correctness check
def _print_error(a, b):
for i, (x, y) in enumerate(zip(F.asnumpy(a).flatten(), F.asnumpy(b).flatten())):
if not np.allclose(x, y):
print('@{} {} v.s. {}'.format(i, x, y))

if not F.allclose(r1, r2):
_print_error(r1, r2)
assert F.allclose(r1, r2)
if not F.allclose(e_grad1, e_grad2):
print('edge grad')
_print_error(e_grad1, e_grad2)
assert(F.allclose(e_grad1, e_grad2))

_test(fn.copy_e, fn.sum)
# TODO(Israt) :Add reduce func to suport the following reduce op
# _test('copy_e', 'max')
# _test('copy_e', 'min')
# _test('copy_e', 'mean')


if __name__ == '__main__':
test_unary_copy_u()
test_unary_copy_e()


0 comments on commit d17b299

Please sign in to comment.