Skip to content

Commit

Permalink
[GraphBolt] Update tests to confirm FusedCSCSampling graph can solve …
Browse files Browse the repository at this point in the history
…indptr and indices with different dtype. (dmlc#7378)

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
yxy235 and Ubuntu authored May 8, 2024
1 parent 67a897f commit 2da713f
Showing 1 changed file with 22 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1613,10 +1613,14 @@ def test_csc_sampling_graph_to_pinned_memory():
is_graph_pinned(graph)


@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("labor", [False, True])
@pytest.mark.parametrize("is_pinned", [False, True])
@pytest.mark.parametrize("nodes", [None, True])
def test_sample_neighbors_homo(labor, is_pinned, nodes):
def test_sample_neighbors_homo(
indptr_dtype, indices_dtype, labor, is_pinned, nodes
):
if is_pinned and nodes is None:
pytest.skip("Optional nodes and is_pinned is not supported together.")
"""Original graph in COO:
Expand All @@ -1630,8 +1634,10 @@ def test_sample_neighbors_homo(labor, is_pinned, nodes):
pytest.skip("Pinning is not meaningful without a GPU.")
# Initialize data.
total_num_edges = 12
indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
indptr = torch.tensor([0, 3, 5, 7, 9, 12], dtype=indptr_dtype)
indices = torch.tensor(
[0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4], dtype=indices_dtype
)
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)

Expand All @@ -1642,7 +1648,7 @@ def test_sample_neighbors_homo(labor, is_pinned, nodes):

# Generate subgraph via sample neighbors.
if nodes:
nodes = torch.LongTensor([1, 3, 4]).to(F.ctx())
nodes = torch.tensor([1, 3, 4], dtype=indices_dtype).to(F.ctx())
elif F._default_context_str != "gpu":
pytest.skip("Optional nodes is supported only for the GPU.")
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
Expand All @@ -1662,8 +1668,10 @@ def test_sample_neighbors_homo(labor, is_pinned, nodes):
assert subgraph.original_edge_ids is None


@pytest.mark.parametrize("indptr_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("indices_dtype", [torch.int32, torch.int64])
@pytest.mark.parametrize("labor", [False, True])
def test_sample_neighbors_hetero(labor):
def test_sample_neighbors_hetero(indptr_dtype, indices_dtype, labor):
"""Original graph in COO:
"n1:e1:n2":[0, 0, 1, 1, 1], [0, 2, 0, 1, 2]
"n2:e2:n1":[0, 0, 1, 2], [0, 1, 1 ,0]
Expand All @@ -1677,10 +1685,12 @@ def test_sample_neighbors_hetero(labor):
ntypes = {"n1": 0, "n2": 1}
etypes = {"n1:e1:n2": 0, "n2:e2:n1": 1}
total_num_edges = 9
indptr = torch.LongTensor([0, 2, 4, 6, 7, 9])
indices = torch.LongTensor([2, 4, 2, 3, 0, 1, 1, 0, 1])
type_per_edge = torch.LongTensor([1, 1, 1, 1, 0, 0, 0, 0, 0])
node_type_offset = torch.LongTensor([0, 2, 5])
indptr = torch.tensor([0, 2, 4, 6, 7, 9], dtype=indptr_dtype)
indices = torch.tensor([2, 4, 2, 3, 0, 1, 1, 0, 1], dtype=indices_dtype)
type_per_edge = torch.tensor(
[1, 1, 1, 1, 0, 0, 0, 0, 0], dtype=indices_dtype
)
node_type_offset = torch.tensor([0, 2, 5], dtype=indices_dtype)
assert indptr[-1] == total_num_edges
assert indptr[-1] == len(indices)

Expand All @@ -1696,8 +1706,8 @@ def test_sample_neighbors_hetero(labor):

# Sample on both node types.
nodes = {
"n1": torch.tensor([0], device=F.ctx()),
"n2": torch.tensor([0], device=F.ctx()),
"n1": torch.tensor([0], dtype=indices_dtype, device=F.ctx()),
"n2": torch.tensor([0], dtype=indices_dtype, device=F.ctx()),
}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
Expand Down Expand Up @@ -1725,7 +1735,7 @@ def test_sample_neighbors_hetero(labor):
assert subgraph.original_edge_ids is None

# Sample on single node type.
nodes = {"n1": torch.tensor([0], device=F.ctx())}
nodes = {"n1": torch.tensor([0], dtype=indices_dtype, device=F.ctx())}
fanouts = torch.tensor([-1, -1])
sampler = graph.sample_layer_neighbors if labor else graph.sample_neighbors
subgraph = sampler(nodes, fanouts)
Expand Down

0 comments on commit 2da713f

Please sign in to comment.