Skip to content

Commit

Permalink
[Dist] etype is not guaranteed to be sorted (dmlc#4156)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhett-Ying authored Jun 23, 2022
1 parent 4d3c01d commit ab1b281
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions python/dgl/distributed/graph_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _sample_neighbors(local_g, partition_book, seed_nodes, fan_out, edge_dir, pr
return global_src, global_dst, global_eids

def _sample_etype_neighbors(local_g, partition_book, seed_nodes, etype_field,
fan_out, edge_dir, prob, replace):
fan_out, edge_dir, prob, replace, etype_sorted=False):
""" Sample from local partition.
The input nodes use global IDs. We need to map the global node IDs to local node IDs,
Expand All @@ -80,13 +80,10 @@ def _sample_etype_neighbors(local_g, partition_book, seed_nodes, etype_field,
"""
local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
local_ids = F.astype(local_ids, local_g.idtype)
# local_ids = self.seed_nodes

# DistGraph's edges are sorted by default according to
# graph partition mechanism.
sampled_graph = local_sample_etype_neighbors(
local_g, local_ids, etype_field, fan_out, edge_dir, prob, replace,
etype_sorted=True, _dist_training=True)
etype_sorted=etype_sorted, _dist_training=True)
global_nid_mapping = local_g.ndata[NID]
src, dst = sampled_graph.edges()
global_src, global_dst = F.gather_row(global_nid_mapping, src), \
Expand Down

0 comments on commit ab1b281

Please sign in to comment.