Skip to content

Commit

Permalink
[Bugfix] Consistent order for cross-type stack reducer (dmlc#1267)
Browse files Browse the repository at this point in the history
* WIP

* fix shape of stack reducer

* apply merge order by edge id for stack reducer
  • Loading branch information
jermainewang authored Feb 18, 2020
1 parent 537d37c commit c7f6cf6
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 21 deletions.
46 changes: 31 additions & 15 deletions python/dgl/heterograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2653,6 +2653,7 @@ def multi_recv(self, v, reducer_dict, cross_reducer, apply_node_func=None, inpla
# TODO(minjie): currently loop over each edge type and reuse the old schedule.
# Should replace it with fused kernel.
all_out = []
merge_order = []
with ir.prog() as prog:
for ety, args in reducer_dict.items():
outframe = FrameRef(frame_like(self._node_frames[ntid]._frame))
Expand All @@ -2667,9 +2668,10 @@ def multi_recv(self, v, reducer_dict, cross_reducer, apply_node_func=None, inpla
v, rfunc, afunc,
inplace=inplace, outframe=outframe)
all_out.append(outframe)
merge_order.append(etid) # use edge type id as merge order hint
Runtime.run(prog)
# merge by cross_reducer
self._node_frames[ntid].update(merge_frames(all_out, cross_reducer))
self._node_frames[ntid].update(merge_frames(all_out, cross_reducer, merge_order))
# apply
if apply_node_func is not None:
self.apply_nodes(apply_node_func, v, ntype, inplace)
Expand Down Expand Up @@ -2855,6 +2857,7 @@ def multi_send_and_recv(self, etype_dict, cross_reducer, apply_node_func=None, i
# Should replace it with fused kernel.
all_out = []
all_vs = []
merge_order = []
with ir.prog() as prog:
for etype, args in etype_dict.items():
etid = self.get_etype_id(etype)
Expand Down Expand Up @@ -2883,9 +2886,10 @@ def multi_send_and_recv(self, etype_dict, cross_reducer, apply_node_func=None, i
mfunc, rfunc, afunc,
inplace=inplace, outframe=outframe)
all_out.append(outframe)
merge_order.append(etid) # use edge type id as merge order hint
Runtime.run(prog)
# merge by cross_reducer
self._node_frames[dtid].update(merge_frames(all_out, cross_reducer))
self._node_frames[dtid].update(merge_frames(all_out, cross_reducer, merge_order))
# apply
if apply_node_func is not None:
dstnodes = F.unique(F.cat([x.tousertensor() for x in all_vs], 0))
Expand Down Expand Up @@ -3043,6 +3047,7 @@ def multi_pull(self, v, etype_dict, cross_reducer, apply_node_func=None, inplace
# TODO(minjie): currently loop over each edge type and reuse the old schedule.
# Should replace it with fused kernel.
all_out = []
merge_order = []
with ir.prog() as prog:
for etype, args in etype_dict.items():
etid = self.get_etype_id(etype)
Expand All @@ -3058,9 +3063,10 @@ def multi_pull(self, v, etype_dict, cross_reducer, apply_node_func=None, inplace
mfunc, rfunc, afunc,
inplace=inplace, outframe=outframe)
all_out.append(outframe)
merge_order.append(etid) # use edge type id as merge order hint
Runtime.run(prog)
# merge by cross_reducer
self._node_frames[dtid].update(merge_frames(all_out, cross_reducer))
self._node_frames[dtid].update(merge_frames(all_out, cross_reducer, merge_order))
# apply
if apply_node_func is not None:
self.apply_nodes(apply_node_func, v, ntype, inplace)
Expand Down Expand Up @@ -3263,6 +3269,7 @@ def multi_update_all(self, etype_dict, cross_reducer, apply_node_func=None):
# TODO(minjie): currently loop over each edge type and reuse the old schedule.
# Should replace it with fused kernel.
all_out = defaultdict(list)
merge_order = defaultdict(list)
with ir.prog() as prog:
for etype, args in etype_dict.items():
etid = self.get_etype_id(etype)
Expand All @@ -3277,10 +3284,12 @@ def multi_update_all(self, etype_dict, cross_reducer, apply_node_func=None):
mfunc, rfunc, afunc,
outframe=outframe)
all_out[dtid].append(outframe)
merge_order[dtid].append(etid) # use edge type id as merge order hint
Runtime.run(prog)
for dtid, frames in all_out.items():
# merge by cross_reducer
self._node_frames[dtid].update(merge_frames(frames, cross_reducer))
self._node_frames[dtid].update(
merge_frames(frames, cross_reducer, merge_order[dtid]))
# apply
if apply_node_func is not None:
self.apply_nodes(apply_node_func, ALL, self.ntypes[dtid], inplace=False)
Expand Down Expand Up @@ -3813,36 +3822,46 @@ def pad_tuple(tup, length, pad_val=None):
else:
return tup + (pad_val,) * (length - len(tup))

def merge_frames(frames, reducer):
def merge_frames(frames, reducer, order=None):
"""Merge input frames into one. Resolve conflict fields using reducer.
Parameters
----------
frames : list of FrameRef
frames : list[FrameRef]
Input frames
reducer : str
One of "sum", "max", "min", "mean", "stack"
order : list[Int], optional
Merge order hint. Useful for "stack" reducer.
If provided, each integer indicates the relative order
of the ``frames`` list. Frames are sorted according to this list
in ascending order. Tie is not handled so make sure the order values
are distinct.
Returns
-------
FrameRef
Merged frame
"""
if len(frames) == 1:
if len(frames) == 1 and reducer != 'stack':
# Directly return the only one input. Stack reducer requires
# modifying tensor shape.
return frames[0]
if reducer == 'stack':
# TODO(minjie): Stack order does not matter. However, it must
# be consistent! Need to enforce one type of order.
# Stack order does not matter. However, it must be consistent!
if order:
assert len(order) == len(frames)
sorted_with_key = sorted(zip(frames, order), key=lambda x: x[1])
frames = list(zip(*sorted_with_key))[0]
def merger(flist):
flist = [F.unsqueeze(f, 1) for f in flist]
return F.stack(flist, 1)
else:
redfn = getattr(F, reducer, None)
if redfn is None:
raise DGLError('Invalid cross type reducer. Must be one of '
'"sum", "max", "min", "mean" or "stack".')
def merger(flist):
return redfn(F.stack(flist, 0), 0)
return redfn(F.stack(flist, 0), 0) if len(flist) > 1 else flist[0]
ret = FrameRef(frame_like(frames[0]._frame))
keys = set()
for frm in frames:
Expand All @@ -3852,10 +3871,7 @@ def merger(flist):
for frm in frames:
if k in frm:
flist.append(frm[k])
if len(flist) > 1:
ret[k] = merger(flist)
else:
ret[k] = flist[0]
ret[k] = merger(flist)
return ret

def combine_frames(frames, ids):
Expand Down
37 changes: 31 additions & 6 deletions tests/compute/test_heterograph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,12 +1249,10 @@ def afunc(nodes):
g['wishes'].update_all(mfunc, rfunc2)
y2 = g.nodes['game'].data['y']
if cred == 'stack':
# stack has two both correct outcomes
yy1 = F.stack([F.unsqueeze(y1, 1), F.unsqueeze(y2, 1)], 1)
yy1 = yy1 + 1 # final afunc
yy2 = F.stack([F.unsqueeze(y2, 1), F.unsqueeze(y1, 1)], 1)
yy2 = yy2 + 1 # final afunc
assert F.array_equal(y, yy1) or F.array_equal(y, yy2)
# stack has an internal order by edge type id
yy = F.stack([y1, y2], 1)
yy = yy + 1 # final afunc
assert F.array_equal(y, yy)
else:
yy = get_redfn(cred)(F.stack([y1, y2], 0), 0)
yy = yy + 1 # final afunc
Expand Down Expand Up @@ -1469,6 +1467,32 @@ def filter_edges2(edges):
g.filter_nodes(filter_nodes2, ntype='game')
g.filter_edges(filter_edges2)

def test_stack_reduce():
#edges = {
# 'follows': ([0, 1], [1, 2]),
# 'plays': ([0, 1, 2, 1], [0, 0, 1, 1]),
# 'wishes': ([0, 2], [1, 0]),
# 'develops': ([0, 1], [0, 1]),
#}
g = create_test_heterograph()
g.nodes['user'].data['h'] = F.randn((3, 200))
def rfunc(nodes):
return {'y': F.sum(nodes.mailbox['m'], 1)}
def rfunc2(nodes):
return {'y': F.max(nodes.mailbox['m'], 1)}
def mfunc(edges):
return {'m': edges.src['h']}
g.multi_update_all(
{'plays' : (mfunc, rfunc),
'wishes': (mfunc, rfunc2)},
'stack')
assert g.nodes['game'].data['y'].shape == (g.number_of_nodes('game'), 2, 200)
# only one type-wise update_all, stack still adds one dimension
g.multi_update_all(
{'plays' : (mfunc, rfunc)},
'stack')
assert g.nodes['game'].data['y'].shape == (g.number_of_nodes('game'), 1, 200)

if __name__ == '__main__':
test_create()
test_query()
Expand All @@ -1491,3 +1515,4 @@ def filter_edges2(edges):
test_empty_heterograph()
test_compact()
test_types_in_function()
test_stack_reduce()

0 comments on commit c7f6cf6

Please sign in to comment.