Skip to content

Commit

Permalink
[DistGB] sample with graphbolt on homograph via DistNodeDataLoader (d…
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhett-Ying authored Feb 9, 2024
1 parent 7f7967b commit 3ebdee7
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 13 deletions.
6 changes: 4 additions & 2 deletions python/dgl/dataloading/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,11 @@ def sample_blocks(self, g, seed_nodes, exclude_eids=None):
output_device=self.output_device,
exclude_edges=exclude_eids,
)
eid = frontier.edata[EID]
block = to_block(frontier, seed_nodes)
block.edata[EID] = eid
# If sampled from graphbolt-backed DistGraph, `EID` may not be in
# the block.
if EID in frontier.edata.keys():
block.edata[EID] = frontier.edata[EID]
seed_nodes = block.srcdata[NID]
blocks.insert(0, block)

Expand Down
7 changes: 6 additions & 1 deletion python/dgl/distributed/dist_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,7 +1406,12 @@ def sample_neighbors(
)
else:
frontier = graph_services.sample_neighbors(
self, seed_nodes, fanout, replace=replace, prob=prob
self,
seed_nodes,
fanout,
replace=replace,
prob=prob,
use_graphbolt=self._use_graphbolt,
)
return frontier

Expand Down
73 changes: 63 additions & 10 deletions tests/distributed/test_mp_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def check_neg_dataloader(g, num_server, num_workers):


@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("num_workers", [0, 4])
@pytest.mark.parametrize("num_workers", [0, 1])
@pytest.mark.parametrize("drop_last", [False, True])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
Expand Down Expand Up @@ -429,6 +429,8 @@ def start_node_dataloader(
orig_nid,
orig_eid,
groundtruth_g,
use_graphbolt=False,
return_eids=False,
):
dgl.distributed.initialize(ip_config)
gpb = None
Expand All @@ -437,7 +439,12 @@ def start_node_dataloader(
_, _, _, gpb, _, _, _ = load_partition(part_config, rank)
num_nodes_to_sample = 202
batch_size = 32
dist_graph = DistGraph("test_mp", gpb=gpb, part_config=part_config)
dist_graph = DistGraph(
"test_sampling",
gpb=gpb,
part_config=part_config,
use_graphbolt=use_graphbolt,
)
assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes)
assert len(dist_graph.etypes) == len(groundtruth_g.etypes)
if len(dist_graph.etypes) == 1:
Expand All @@ -459,6 +466,9 @@ def start_node_dataloader(
]
) # test int for hetero

# Enable sanity check in distributed sampling.
os.environ["DGL_DIST_DEBUG"] = "1"

# We need to test creating DistDataLoader multiple times.
for i in range(2):
# Create DataLoader for constructing blocks
Expand All @@ -472,7 +482,7 @@ def start_node_dataloader(
num_workers=num_workers,
)

for epoch in range(2):
for _ in range(2):
for idx, (_, _, blocks) in zip(
range(0, num_nodes_to_sample, batch_size), dataloader
):
Expand All @@ -487,6 +497,16 @@ def start_node_dataloader(
src_nodes_id, dst_nodes_id, etype=etype
)
assert np.all(F.asnumpy(has_edges))

if use_graphbolt and not return_eids:
continue
eids = orig_eid[etype][block.edata[dgl.EID]]
expected_eids = groundtruth_g.edge_ids(
src_nodes_id, dst_nodes_id
)
assert th.equal(
eids, expected_eids
), f"{eids} != {expected_eids}"
del dataloader
# This is needed since there are two tests here in one process.
dgl.distributed.exit_client()
Expand All @@ -509,7 +529,7 @@ def start_edge_dataloader(
_, _, _, gpb, _, _, _ = load_partition(part_config, rank)
num_edges_to_sample = 202
batch_size = 32
dist_graph = DistGraph("test_mp", gpb=gpb, part_config=part_config)
dist_graph = DistGraph("test_sampling", gpb=gpb, part_config=part_config)
assert len(dist_graph.ntypes) == len(groundtruth_g.ntypes)
assert len(dist_graph.etypes) == len(groundtruth_g.etypes)
if len(dist_graph.etypes) == 1:
Expand Down Expand Up @@ -561,7 +581,14 @@ def start_edge_dataloader(
dgl.distributed.exit_client()


def check_dataloader(g, num_server, num_workers, dataloader_type):
def check_dataloader(
g,
num_server,
num_workers,
dataloader_type,
use_graphbolt=False,
return_eids=False,
):
with tempfile.TemporaryDirectory() as test_dir:
ip_config = "ip_config.txt"
generate_ip_config(ip_config, num_server, num_server)
Expand All @@ -576,6 +603,8 @@ def check_dataloader(g, num_server, num_workers, dataloader_type):
num_hops=num_hops,
part_method="metis",
return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
)
part_config = os.path.join(test_dir, "test_sampling.json")
if not isinstance(orig_nid, dict):
Expand All @@ -594,6 +623,7 @@ def check_dataloader(g, num_server, num_workers, dataloader_type):
part_config,
num_server > 1,
num_workers + 1,
use_graphbolt,
),
)
p.start()
Expand All @@ -615,6 +645,8 @@ def check_dataloader(g, num_server, num_workers, dataloader_type):
orig_nid,
orig_eid,
g,
use_graphbolt,
return_eids,
),
)
p.start()
Expand Down Expand Up @@ -663,14 +695,35 @@ def create_random_hetero():
return g


@unittest.skip(reason="Skip due to glitch in CI")
@pytest.mark.parametrize("num_server", [3])
@pytest.mark.parametrize("num_workers", [0, 4])
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("num_workers", [0, 1])
@pytest.mark.parametrize("dataloader_type", ["node", "edge"])
def test_dataloader(num_server, num_workers, dataloader_type):
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_dataloader_homograph(
num_server, num_workers, dataloader_type, use_graphbolt, return_eids
):
if dataloader_type == "edge" and use_graphbolt:
# GraphBolt does not support edge dataloader.
return
reset_envs()
g = CitationGraphDataset("cora")[0]
check_dataloader(g, num_server, num_workers, dataloader_type)
check_dataloader(
g,
num_server,
num_workers,
dataloader_type,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)


@unittest.skip(reason="Skip due to glitch in CI")
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("num_workers", [0, 1])
@pytest.mark.parametrize("dataloader_type", ["node", "edge"])
def test_dataloader_heterograph(num_server, num_workers, dataloader_type):
    """Run the distributed dataloader check on a random heterograph.

    The test is parametrized over the number of servers, the number of
    sampler workers and the dataloader type ("node" or "edge"). It is
    currently skipped unconditionally by the ``unittest.skip`` decorator.
    """
    # Reset the environment first, before the graph is created.
    reset_envs()
    hetero_graph = create_random_hetero()
    check_dataloader(
        g=hetero_graph,
        num_server=num_server,
        num_workers=num_workers,
        dataloader_type=dataloader_type,
    )

Expand Down

0 comments on commit 3ebdee7

Please sign in to comment.