Skip to content

Commit

Permalink
[DistGB] enable GB sampling on heterograph (dmlc#7087)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhett-Ying authored Feb 6, 2024
1 parent a2e1c79 commit ee8b7b3
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 40 deletions.
36 changes: 30 additions & 6 deletions python/dgl/distributed/graph_services.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""A set of graph services of getting subgraphs from DistGraph"""
import os
from collections import namedtuple

import numpy as np
Expand Down Expand Up @@ -708,24 +709,47 @@ def _frontier_to_heterogeneous_graph(g, frontier, gpb):
idtype=g.idtype,
)

etype_ids, frontier.edata[EID] = gpb.map_to_per_etype(frontier.edata[EID])
src, dst = frontier.edges()
# DGL partitions always store the global edge IDs in `edata[EID]`. GraphBolt
# partitions always store the edge type IDs in `edata[ETYPE]`, but store the
# edge IDs only if the graph was partitioned with `store_eids=True`. When the
# edge IDs are absent, the stored edge type IDs are used directly and the
# type-wise edge IDs are left as None.
etype_ids, type_wise_eids = (
gpb.map_to_per_etype(frontier.edata[EID])
if EID in frontier.edata
else (frontier.edata[ETYPE], None)
)
etype_ids, idx = F.sort_1d(etype_ids)
if type_wise_eids is not None:
type_wise_eids = F.gather_row(type_wise_eids, idx)

# Sort the edges by their edge types.
src, dst = frontier.edges()
src, dst = F.gather_row(src, idx), F.gather_row(dst, idx)
eid = F.gather_row(frontier.edata[EID], idx)
_, src = gpb.map_to_per_ntype(src)
_, dst = gpb.map_to_per_ntype(dst)
src_ntype_ids, src = gpb.map_to_per_ntype(src)
dst_ntype_ids, dst = gpb.map_to_per_ntype(dst)

data_dict = dict()
edge_ids = {}
for etid, etype in enumerate(g.canonical_etypes):
src_ntype, _, dst_ntype = etype
src_ntype_id = g.get_ntype_id(src_ntype)
dst_ntype_id = g.get_ntype_id(dst_ntype)
type_idx = etype_ids == etid
if F.sum(type_idx, 0) > 0:
data_dict[etype] = (
F.boolean_mask(src, type_idx),
F.boolean_mask(dst, type_idx),
)
edge_ids[etype] = F.boolean_mask(eid, type_idx)
if "DGL_DIST_DEBUG" in os.environ:
assert torch.all(
src_ntype_id == src_ntype_ids[type_idx]
), "source ntype is is not expected."
assert torch.all(
dst_ntype_id == dst_ntype_ids[type_idx]
), "destination ntype is is not expected."
if type_wise_eids is not None:
edge_ids[etype] = F.boolean_mask(type_wise_eids, type_idx)
hg = heterograph(
data_dict,
{ntype: g.num_nodes(ntype) for ntype in g.ntypes},
Expand Down
Loading

0 comments on commit ee8b7b3

Please sign in to comment.