Skip to content

Commit

Permalink
[DistGB] add testcases for DistDGL local sampling on multiple partiti…
Browse files Browse the repository at this point in the history
…ons (dmlc#7451)
  • Loading branch information
Rhett-Ying authored Jun 7, 2024
1 parent e366260 commit d281e42
Show file tree
Hide file tree
Showing 2 changed files with 195 additions and 18 deletions.
173 changes: 173 additions & 0 deletions tests/distributed/test_distributed_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
sample_etype_neighbors,
sample_neighbors,
)

from dgl.distributed.graph_partition_book import _etype_tuple_to_str

from scipy import sparse as spsp
from utils import generate_ip_config, reset_envs

Expand Down Expand Up @@ -1685,6 +1688,176 @@ def test_standalone_etype_sampling():
check_standalone_etype_sampling(Path(tmpdirname))


@pytest.mark.parametrize("num_parts", [1, 4])
@pytest.mark.parametrize("use_graphbolt", [False])
@pytest.mark.parametrize("prob_or_mask", ["prob", "mask"])
def test_local_sampling_homograph(num_parts, use_graphbolt, prob_or_mask):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as test_dir:
g = CitationGraphDataset("cora")[0]
prob = torch.rand(g.num_edges())
mask = prob > 0.2
prob[torch.randperm(len(prob))[: int(len(prob) * 0.5)]] = 0.0
g.edata["prob"] = prob
g.edata["mask"] = mask
graph_name = "test_local_sampling"

_, orig_eids = partition_graph(
g,
graph_name,
num_parts,
test_dir,
num_hops=1,
part_method="metis",
return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=True,
store_inner_node=True,
store_inner_edge=True,
)

part_config = os.path.join(test_dir, f"{graph_name}.json")
for part_id in range(num_parts):
local_g, _, edge_feats, gpb, _, _, _ = load_partition(
part_config,
part_id,
load_feats=True,
use_graphbolt=use_graphbolt,
)
inner_global_nids = gpb.partid2nids(part_id)
inner_global_eids = gpb.partid2eids(part_id)
inner_node_data = (
local_g.node_attributes["inner_node"]
if use_graphbolt
else local_g.ndata["inner_node"]
)
inner_edge_data = (
local_g.edge_attributes["inner_edge"]
if use_graphbolt
else local_g.edata["inner_edge"]
)
assert len(inner_global_nids) == inner_node_data.sum()
assert len(inner_global_eids) == inner_edge_data.sum()

c_etype = gpb.canonical_etypes[0]
_prob = []
prob = edge_feats[_etype_tuple_to_str(c_etype) + "/" + prob_or_mask]
assert len(prob) == len(inner_global_eids)
assert len(prob) <= inner_edge_data.shape[0]
_prob.append(prob)

sampled_g = dgl.distributed.graph_services._sample_neighbors(
use_graphbolt,
local_g,
gpb,
inner_global_nids,
5,
prob=_prob,
)
sampled_homo_eids = sampled_g.global_eids
sampled_orig_eids = orig_eids[sampled_homo_eids]
assert torch.all(g.edata[prob_or_mask][sampled_orig_eids] > 0)


@pytest.mark.parametrize("num_parts", [1, 4])
@pytest.mark.parametrize("use_graphbolt", [False])
@pytest.mark.parametrize("prob_or_mask", ["prob", "mask"])
def test_local_sampling_heterograph(num_parts, use_graphbolt, prob_or_mask):
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as test_dir:
g = create_random_hetero()
for c_etype in g.canonical_etypes:
prob = torch.rand(g.num_edges(c_etype))
mask = prob > 0.2
prob[torch.randperm(len(prob))[: int(len(prob) * 0.5)]] = 0.0
g.edges[c_etype].data["prob"] = prob
g.edges[c_etype].data["mask"] = mask
graph_name = "test_local_sampling"

_, orig_eids = partition_graph(
g,
graph_name,
num_parts,
test_dir,
num_hops=1,
part_method="metis",
return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=True,
store_inner_node=True,
store_inner_edge=True,
)

part_config = os.path.join(test_dir, f"{graph_name}.json")
for part_id in range(num_parts):
local_g, _, edge_feats, gpb, _, _, _ = load_partition(
part_config,
part_id,
load_feats=True,
use_graphbolt=use_graphbolt,
)
inner_global_nids = [
gpb.map_to_homo_nid(gpb.partid2nids(part_id, ntype), ntype)
for ntype in gpb.ntypes
]
inner_global_nids = torch.cat(inner_global_nids)
inner_global_eids = {
c_etype: gpb.partid2eids(part_id, c_etype)
for c_etype in gpb.canonical_etypes
}
inner_node_data = (
local_g.node_attributes["inner_node"]
if use_graphbolt
else local_g.ndata["inner_node"]
)
inner_edge_data = (
local_g.edge_attributes["inner_edge"]
if use_graphbolt
else local_g.edata["inner_edge"]
)
assert len(inner_global_nids) == inner_node_data.sum()
num_inner_global_eids = sum(
[len(eids) for eids in inner_global_eids.values()]
)
assert num_inner_global_eids == inner_edge_data.sum()

_prob = []
for i, c_etype in enumerate(gpb.canonical_etypes):
prob = edge_feats[
_etype_tuple_to_str(c_etype) + "/" + prob_or_mask
]
assert len(prob) == len(inner_global_eids[c_etype])
assert (
len(prob)
== gpb.local_etype_offset[i + 1] - gpb.local_etype_offset[i]
)
assert len(prob) <= inner_edge_data.shape[0]
_prob.append(prob)

sampled_g = dgl.distributed.graph_services._sample_etype_neighbors(
use_graphbolt,
local_g,
gpb,
inner_global_nids,
torch.full((len(g.canonical_etypes),), 5, dtype=torch.int64),
prob=_prob,
etype_offset=gpb.local_etype_offset,
)
sampled_homo_eids = sampled_g.global_eids
sampled_etype_ids, sampled_per_etype_eids = gpb.map_to_per_etype(
sampled_homo_eids
)
for etype_id, c_etype in enumerate(gpb.canonical_etypes):
indices = torch.nonzero(sampled_etype_ids == etype_id).squeeze()
sampled_eids = sampled_per_etype_eids[indices]
sampled_orig_eids = orig_eids[c_etype][sampled_eids]
assert torch.all(
g.edges[c_etype].data[prob_or_mask][sampled_orig_eids] > 0
)


if __name__ == "__main__":
import tempfile

Expand Down
40 changes: 22 additions & 18 deletions tests/distributed/test_mp_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,27 @@
from utils import generate_ip_config, reset_envs


def _unique_rand_graph(num_nodes=1000, num_edges=10 * 1000):
edges_set = set()
while len(edges_set) < num_edges:
src = np.random.randint(0, num_nodes - 1)
dst = np.random.randint(0, num_nodes - 1)
if (
src != dst
and (src, dst) not in edges_set
and (dst, src) not in edges_set
):
edges_set.add((src, dst))
src_list, dst_list = zip(*edges_set)

src = th.tensor(src_list, dtype=th.long)
dst = th.tensor(dst_list, dtype=th.long)
g = dgl.graph((th.cat([src, dst]), th.cat([dst, src])))
E = len(src)
reverse_eids = th.cat([th.arange(E, 2 * E), th.arange(0, E)])
return g, reverse_eids


class NeighborSampler(object):
def __init__(
self,
Expand Down Expand Up @@ -889,24 +910,7 @@ def test_edge_dataloader_homograph(
num_server = 1
dataloader_type = "edge"
reset_envs()
g = CitationGraphDataset("cora")[0]
src, dst = g.edges()
# Remove reverse edges.
visited = th.zeros_like(src, dtype=th.bool)
remove_mask = th.zeros_like(src, dtype=th.bool)
for i, (src_id, dst_id) in enumerate(zip(src, dst)):
if visited[i]:
continue
if g.has_edges_between(dst_id, src_id):
eid = g.edge_ids(dst_id, src_id)
visited[eid] = True
remove_mask[i] = True
visited[i] = True
src = src[~remove_mask]
dst = dst[~remove_mask]
g = dgl.graph((th.cat([src, dst]), th.cat([dst, src])))
E = len(src)
reverse_eids = th.cat([th.arange(E, 2 * E), th.arange(0, E)])
g, reverse_eids = _unique_rand_graph()
check_dataloader(
g,
num_server,
Expand Down

0 comments on commit d281e42

Please sign in to comment.