Skip to content

Commit

Permalink
[DistGB] return global eids from GB sampling on homograph (dmlc#7085)
Browse files (browse the repository at this point in the history)
  • Loading branch information
Rhett-Ying authored Feb 5, 2024
1 parent badeaf1 commit 4ee0a8b
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 7 deletions.
9 changes: 6 additions & 3 deletions python/dgl/distributed/graph_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ def _sample_neighbors_graphbolt(
# [Rui][TODO] Support multiple fanouts.
assert fanout.numel() == 1, "Expect a single fanout."

- subgraph = g._sample_neighbors(nodes, fanout)
+ return_eids = g.edge_attributes is not None and EID in g.edge_attributes
+ subgraph = g._sample_neighbors(nodes, fanout, return_eids=return_eids)

# 3. Map local node IDs to global node IDs.
local_src = subgraph.indices
Expand All @@ -156,9 +157,11 @@ def _sample_neighbors_graphbolt(
global_src = global_nid_mapping[local_src]
global_dst = global_nid_mapping[local_dst]

- # [Rui][TODO] edge IDs are not supported yet.
+ global_eids = None
+ if return_eids:
+ global_eids = g.edge_attributes[EID][subgraph.original_edge_ids]
return LocalSampledGraph(
- global_src, global_dst, None, subgraph.type_per_edge
+ global_src, global_dst, global_eids, subgraph.type_per_edge
)


Expand Down
15 changes: 11 additions & 4 deletions tests/distributed/test_distributed_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def start_sample_client_shuffle(
orig_nid,
orig_eid,
use_graphbolt=False,
+ return_eids=False,
):
os.environ["DGL_GROUP_ID"] = str(group_id)
gpb = None
Expand All @@ -95,7 +96,7 @@ def start_sample_client_shuffle(
dst = orig_nid[dst]
assert sampled_graph.num_nodes() == g.num_nodes()
assert np.all(F.asnumpy(g.has_edges_between(src, dst)))
- if use_graphbolt:
+ if use_graphbolt and not return_eids:
assert (
dgl.EID not in sampled_graph.edata
), "EID should not be in sampled graph if use_graphbolt=True."
Expand Down Expand Up @@ -391,7 +392,7 @@ def test_rpc_sampling():


def check_rpc_sampling_shuffle(
- tmpdir, num_server, num_groups=1, use_graphbolt=False
+ tmpdir, num_server, num_groups=1, use_graphbolt=False, return_eids=False
):
generate_ip_config("rpc_ip_config.txt", num_server, num_server)

Expand All @@ -408,6 +409,7 @@ def check_rpc_sampling_shuffle(
part_method="metis",
return_mapping=True,
use_graphbolt=use_graphbolt,
+ store_eids=return_eids,
)

pserver_list = []
Expand Down Expand Up @@ -444,6 +446,7 @@ def check_rpc_sampling_shuffle(
orig_nids,
orig_eids,
use_graphbolt,
+ return_eids,
),
)
p.start()
Expand Down Expand Up @@ -1015,12 +1018,16 @@ def check_rpc_bipartite_etype_sampling_shuffle(tmpdir, num_server):

@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("use_graphbolt", [False, True])
- def test_rpc_sampling_shuffle(num_server, use_graphbolt):
+ @pytest.mark.parametrize("return_eids", [False, True])
+ def test_rpc_sampling_shuffle(num_server, use_graphbolt, return_eids):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
check_rpc_sampling_shuffle(
- Path(tmpdirname), num_server, use_graphbolt=use_graphbolt
+ Path(tmpdirname),
+ num_server,
+ use_graphbolt=use_graphbolt,
+ return_eids=return_eids,
)


Expand Down

0 comments on commit 4ee0a8b

Please sign in to comment.