From a2410e375062f628967c9b6522bde2492d0e8ee6 Mon Sep 17 00:00:00 2001 From: Minjie Wang Date: Thu, 26 Dec 2019 09:08:16 +0000 Subject: [PATCH] fix nodeflow bug in apply block --- python/dgl/runtime/scheduler.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/dgl/runtime/scheduler.py b/python/dgl/runtime/scheduler.py index 620b5c3e58a4..194fcc153603 100644 --- a/python/dgl/runtime/scheduler.py +++ b/python/dgl/runtime/scheduler.py @@ -398,7 +398,7 @@ def schedule_nodeflow_apply_edges(graph, block_id, name='out_nf') var_ef = var.FEAT_DICT(graph._get_edge_frame(block_id), name='ef') var_out = _gen_send(graph, u, v, eid, apply_func, in_var_nf, out_var_nf, - var_ef) + var_ef, block_id=block_id) var_eid = var.IDX(eid) if inplace: ir.WRITE_ROW_INPLACE_(var_ef, var_eid, var_out) @@ -951,7 +951,7 @@ def _mfunc_wrapper(src_data, edge_data, dst_data): msg = ir.EDGE_UDF(_mfunc_wrapper, fdsrc, fdedge, fddst) return msg -def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef): +def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef, block_id=None): """Internal function to generate send schedule""" mfunc = _standardize_func_usage(mfunc, 'message') mfunc_is_list = utils.is_iterable(mfunc) @@ -961,7 +961,10 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef): var_eid = var.IDX(eid) if mfunc_is_list: - if eid.is_slice(0, graph.num_edges()): + if not hasattr(graph, 'num_edges'): + # XXX(minjie): a temporary hack to detect Nodeflow object + res = spmv.build_gidx_and_mapping_block(graph, block_id) + elif eid.is_slice(0, graph.num_edges()): # full graph case res = spmv.build_gidx_and_mapping_graph(graph) else: @@ -969,7 +972,7 @@ def _gen_send(graph, u, v, eid, mfunc, var_src_nf, var_dst_nf, var_ef): (u, v, eid), graph.num_src(), graph.num_dst()) adj, edge_map, _ = res # create a tmp message frame - tmp_mfr = FrameRef(frame_like(graph.edgeframe._frame, len(eid))) + tmp_mfr = FrameRef(frame_like(var_ef.data._frame, len(eid))) var_out = var.FEAT_DICT(data=tmp_mfr) spmv.gen_v2e_spmv_schedule(graph=adj, mfunc=mfunc,