Skip to content

Commit

Permalink
[DistGB] remove unnecessary testcases (dmlc#7444)
Browse files Browse the repository at this point in the history
  • Loading branch information
Rhett-Ying authored Jun 3, 2024
1 parent 0fee9be commit 3edc195
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions tests/distributed/test_mp_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,9 @@ def check_neg_dataloader(g, num_server, num_workers):
def test_dist_dataloader(
num_server, num_workers, drop_last, use_graphbolt, return_eids
):
if not use_graphbolt and return_eids:
# return_eids is not supported in non-GraphBolt mode.
return
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
os.environ["DGL_NUM_SAMPLER"] = str(num_workers)
Expand Down Expand Up @@ -820,6 +823,9 @@ def create_random_hetero():
def test_dataloader_homograph(
num_server, num_workers, dataloader_type, use_graphbolt, return_eids
):
if not use_graphbolt and return_eids:
# return_eids is not supported in non-GraphBolt mode.
return
reset_envs()
g = CitationGraphDataset("cora")[0]
check_dataloader(
Expand All @@ -832,7 +838,7 @@ def test_dataloader_homograph(
)


@pytest.mark.parametrize("num_workers", [0, 1])
@pytest.mark.parametrize("num_workers", [0])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("exclude", [None, "self", "reverse_id"])
@pytest.mark.parametrize("negative", [False, True])
Expand Down Expand Up @@ -881,6 +887,9 @@ def test_edge_dataloader_homograph(
def test_dataloader_heterograph(
num_server, num_workers, dataloader_type, use_graphbolt, return_eids
):
if not use_graphbolt and return_eids:
# return_eids is not supported in non-GraphBolt mode.
return
reset_envs()
g = create_random_hetero()
check_dataloader(
Expand All @@ -893,7 +902,7 @@ def test_dataloader_heterograph(
)


@pytest.mark.parametrize("num_workers", [0, 1])
@pytest.mark.parametrize("num_workers", [0])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("exclude", [None, "self", "reverse_types"])
@pytest.mark.parametrize("negative", [False, True])
Expand Down Expand Up @@ -973,17 +982,13 @@ def start_multiple_dataloaders(
dgl.distributed.exit_client()


@pytest.mark.parametrize("num_dataloaders", [1, 4])
@pytest.mark.parametrize("num_workers", [0, 1])
@pytest.mark.parametrize("num_dataloaders", [4])
@pytest.mark.parametrize("num_workers", [0])
@pytest.mark.parametrize("dataloader_type", ["node", "edge"])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_multiple_dist_dataloaders(
num_dataloaders, num_workers, dataloader_type, use_graphbolt, return_eids
num_dataloaders, num_workers, dataloader_type, use_graphbolt
):
if dataloader_type == "edge" and use_graphbolt:
# GraphBolt does not support edge dataloader.
return
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
os.environ["DGL_NUM_SAMPLER"] = str(num_workers)
Expand All @@ -1001,7 +1006,6 @@ def test_multiple_dist_dataloaders(
num_parts,
test_dir,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
)
part_config = os.path.join(test_dir, f"{graph_name}.json")

Expand Down

0 comments on commit 3edc195

Please sign in to comment.