Skip to content

Commit

Permalink
Cleanup some code in graph_services.py (dmlc#3238)
Browse files Browse the repository at this point in the history
* Fix bug

* Fix

* Fix

* upd

* Merge some code

* lint
classicsong authored Aug 15, 2021

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 1b350e9 commit 6f36dd6
Showing 1 changed file with 30 additions and 51 deletions.
81 changes: 30 additions & 51 deletions python/dgl/distributed/graph_services.py
Original file line number Diff line number Diff line change
@@ -382,6 +382,34 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
sampled_graph = merge_graphs(res_list, g.number_of_nodes())
return sampled_graph

def _frontier_to_heterogeneous_graph(g, frontier, gpb):
etype_ids, frontier.edata[EID] = gpb.map_to_per_etype(frontier.edata[EID])
src, dst = frontier.edges()
etype_ids, idx = F.sort_1d(etype_ids)
src, dst = F.gather_row(src, idx), F.gather_row(dst, idx)
eid = F.gather_row(frontier.edata[EID], idx)
assert len(eid) > 0
_, src = gpb.map_to_per_ntype(src)
_, dst = gpb.map_to_per_ntype(dst)

data_dict = dict()
edge_ids = {}
for etid in range(len(g.etypes)):
etype = g.etypes[etid]
canonical_etype = g.canonical_etypes[etid]
type_idx = etype_ids == etid
if F.sum(type_idx, 0) > 0:
data_dict[canonical_etype] = (F.boolean_mask(src, type_idx), \
F.boolean_mask(dst, type_idx))
edge_ids[etype] = F.boolean_mask(eid, type_idx)
hg = heterograph(data_dict,
{ntype: g.number_of_nodes(ntype) for ntype in g.ntypes},
idtype=g.idtype)

for etype in edge_ids:
hg.edges[etype].data[EID] = edge_ids[etype]
return hg

def sample_etype_neighbors(g, nodes, etype_field, fanout, edge_dir='in', prob=None, replace=False):
"""Sample from the neighbors of the given nodes from a distributed graph.
@@ -460,31 +488,7 @@ def local_access(local_g, partition_book, local_nids):
etype_field, fanout, edge_dir, prob, replace)
frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
if len(gpb.etypes) > 1:
etype_ids, frontier.edata[EID] = gpb.map_to_per_etype(frontier.edata[EID])
src, dst = frontier.edges()
etype_ids, idx = F.sort_1d(etype_ids)
src, dst = F.gather_row(src, idx), F.gather_row(dst, idx)
eid = F.gather_row(frontier.edata[EID], idx)
_, src = gpb.map_to_per_ntype(src)
_, dst = gpb.map_to_per_ntype(dst)

data_dict = dict()
edge_ids = {}
for etid in range(len(g.etypes)):
etype = g.etypes[etid]
canonical_etype = g.canonical_etypes[etid]
type_idx = etype_ids == etid
if F.sum(type_idx, 0) > 0:
data_dict[canonical_etype] = (F.boolean_mask(src, type_idx), \
F.boolean_mask(dst, type_idx))
edge_ids[etype] = F.boolean_mask(eid, type_idx)
hg = heterograph(data_dict,
{ntype: g.number_of_nodes(ntype) for ntype in g.ntypes},
idtype=g.idtype)

for etype in edge_ids:
hg.edges[etype].data[EID] = edge_ids[etype]
return hg
return _frontier_to_heterogeneous_graph(g, frontier, gpb)
else:
return frontier

@@ -561,32 +565,7 @@ def local_access(local_g, partition_book, local_nids):
fanout, edge_dir, prob, replace)
frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
if len(gpb.etypes) > 1:
etype_ids, frontier.edata[EID] = gpb.map_to_per_etype(frontier.edata[EID])
src, dst = frontier.edges()
etype_ids, idx = F.sort_1d(etype_ids)
src, dst = F.gather_row(src, idx), F.gather_row(dst, idx)
eid = F.gather_row(frontier.edata[EID], idx)
assert len(eid) > 0
_, src = gpb.map_to_per_ntype(src)
_, dst = gpb.map_to_per_ntype(dst)

data_dict = dict()
edge_ids = {}
for etid in range(len(g.etypes)):
etype = g.etypes[etid]
canonical_etype = g.canonical_etypes[etid]
type_idx = etype_ids == etid
if F.sum(type_idx, 0) > 0:
data_dict[canonical_etype] = (F.boolean_mask(src, type_idx), \
F.boolean_mask(dst, type_idx))
edge_ids[etype] = F.boolean_mask(eid, type_idx)
hg = heterograph(data_dict,
{ntype: g.number_of_nodes(ntype) for ntype in g.ntypes},
idtype=g.idtype)

for etype in edge_ids:
hg.edges[etype].data[EID] = edge_ids[etype]
return hg
return _frontier_to_heterogeneous_graph(g, frontier, gpb)
else:
return frontier

0 comments on commit 6f36dd6

Please sign in to comment.