Skip to content

Commit

Permalink
[Performance][Feature] Add src_nodes parameter to to_block() to av…
Browse files Browse the repository at this point in the history
…oid cost running unique() when available. (dmlc#2973)

* Add lhs_nodes as parameter to to_block

* Update unit test

* Switch to simplified node conversion

* Switch lhs_nodes to be in/out parameter

* Update docs

Co-authored-by: Da Zheng <[email protected]>
Co-authored-by: Jinjing Zhou <[email protected]>
Co-authored-by: Quan (Andy) Gan <[email protected]>
Co-authored-by: Minjie Wang <[email protected]>
  • Loading branch information
5 people authored Sep 16, 2021
1 parent ad61a9a commit 2647afc
Show file tree
Hide file tree
Showing 5 changed files with 349 additions and 118 deletions.
50 changes: 46 additions & 4 deletions python/dgl/transform.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
##
# Copyright 2019-2021 Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Module for graph transformation utilities."""

from collections.abc import Iterable, Mapping
Expand Down Expand Up @@ -2053,7 +2068,7 @@ def compact_graphs(graphs, always_preserve=None, copy_ndata=True, copy_edata=Tru

return new_graphs

def to_block(g, dst_nodes=None, include_dst_in_src=True):
def to_block(g, dst_nodes=None, include_dst_in_src=True, src_nodes=None):
"""Convert a graph into a bipartite-structured *block* for message passing.
A block is a graph consisting of two sets of nodes: the
Expand Down Expand Up @@ -2089,6 +2104,12 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
(Default: True)
src_nodes : Tensor or dict[str, Tensor], optional
The list of source nodes (and prefixed by destination nodes if
`include_dst_in_src` is True).
If a tensor is given, the graph must have only one node type.
Returns
-------
DGLBlock
Expand Down Expand Up @@ -2215,15 +2236,36 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
if g._graph.ctx != d.ctx:
raise ValueError('g and dst_nodes need to have the same context.')

new_graph_index, src_nodes_nd, induced_edges_nd = _CAPI_DGLToBlock(
g._graph, dst_node_ids_nd, include_dst_in_src)
src_node_ids = None
src_node_ids_nd = None
if src_nodes is not None and not isinstance(src_nodes, Mapping):
# src_nodes is a Tensor, check if the g has only one type.
if len(g.ntypes) > 1:
raise DGLError(
'Graph has more than one node type; please specify a dict for src_nodes.')
src_nodes = {g.ntypes[0]: src_nodes}
src_node_ids = [
F.copy_to(F.tensor(src_nodes.get(ntype, []), dtype=g._idtype_str), \
F.to_backend_ctx(g._graph.ctx)) \
for ntype in g.ntypes]
src_node_ids_nd = [F.to_dgl_nd(nodes) for nodes in src_node_ids]

for d in src_node_ids_nd:
if g._graph.ctx != d.ctx:
raise ValueError('g and src_nodes need to have the same context.')
else:
# use an empty list to signal we need to generate it
src_node_ids_nd = []

new_graph_index, src_nodes_ids_nd, induced_edges_nd = _CAPI_DGLToBlock(
g._graph, dst_node_ids_nd, include_dst_in_src, src_node_ids_nd)

# The new graph duplicates the original node types to SRC and DST sets.
new_ntypes = (g.ntypes, g.ntypes)
new_graph = DGLBlock(new_graph_index, new_ntypes, g.etypes)
assert new_graph.is_unibipartite # sanity check

src_node_ids = [F.from_dgl_nd(src) for src in src_nodes_nd]
src_node_ids = [F.from_dgl_nd(src) for src in src_nodes_ids_nd]
edge_ids = [F.from_dgl_nd(eid) for eid in induced_edges_nd]

node_frames = utils.extract_node_subframes_for_block(g, src_node_ids, dst_node_ids)
Expand Down
Loading

0 comments on commit 2647afc

Please sign in to comment.