Skip to content

Commit

Permalink
[GraphBolt] Remove old version exclude edges. (dmlc#7299)
Browse files Browse the repository at this point in the history
Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
yxy235 and Ubuntu authored Apr 12, 2024
1 parent 279ebb1 commit 20e5e26
Show file tree
Hide file tree
Showing 6 changed files with 75 additions and 220 deletions.
2 changes: 1 addition & 1 deletion python/dgl/graphbolt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ def load_graphbolt():
unique_and_compact,
unique_and_compact_csc_formats,
)
from .utils import add_reverse_edges, add_reverse_edges_2, exclude_seed_edges
from .utils import add_reverse_edges, exclude_seed_edges
64 changes: 14 additions & 50 deletions python/dgl/graphbolt/sampled_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,7 @@ def original_edge_ids(self) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
def exclude_edges(
self,
edges: Union[
Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Dict[str, torch.Tensor],
Tuple[torch.Tensor, torch.Tensor],
torch.Tensor,
],
assume_num_node_within_int32: bool = True,
Expand All @@ -133,10 +131,9 @@ def exclude_edges(
----------
self : SampledSubgraph
The sampled subgraph.
edges : Union[Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]]]
edges : Union[torch.Tensor, Dict[str, torch.Tensor]]
Edges to exclude. If sampled subgraph is homogeneous, then `edges`
should be a pair of tensors representing the edges to exclude. If
should be an N*2 tensor representing the edges to exclude. If
sampled subgraph is heterogeneous, then `edges` should be a
dictionary of edge types and the corresponding edges to exclude.
assume_num_node_within_int32: bool
Expand Down Expand Up @@ -165,8 +162,7 @@ def exclude_edges(
... original_row_node_ids=original_row_node_ids,
... original_edge_ids=original_edge_ids
... )
>>> edges_to_exclude = {"A:relation:B": (torch.tensor([14, 15]),
... torch.tensor([11, 12]))}
>>> edges_to_exclude = {"A:relation:B": torch.tensor([[14, 11], [15, 12]])}
>>> result = subgraph.exclude_edges(edges_to_exclude)
>>> print(result.sampled_csc)
{'A:relation:B': CSCFormatBase(indptr=tensor([0, 1, 1, 1]),
Expand All @@ -183,9 +179,9 @@ def exclude_edges(
assert (
assume_num_node_within_int32
), "Values > int32 are not supported yet."
assert (
isinstance(self.sampled_csc, (CSCFormatBase, tuple))
) == isinstance(edges, (tuple, torch.Tensor)), (
assert (isinstance(self.sampled_csc, CSCFormatBase)) == isinstance(
edges, torch.Tensor
), (
"The sampled subgraph and the edges to exclude should be both "
"homogeneous or both heterogeneous."
)
Expand All @@ -202,14 +198,9 @@ def exclude_edges(
self.original_row_node_ids,
self.original_column_node_ids,
)
if isinstance(edges, torch.Tensor):
index = _exclude_homo_edges_2(
reverse_edges, edges, assume_num_node_within_int32
)
else:
index = _exclude_homo_edges(
reverse_edges, edges, assume_num_node_within_int32
)
index = _exclude_homo_edges(
reverse_edges, edges, assume_num_node_within_int32
)
return calling_class(*_slice_subgraph(self, index))
else:
index = {}
Expand All @@ -234,18 +225,11 @@ def exclude_edges(
original_row_node_ids,
original_column_node_ids,
)
if isinstance(edges[etype], torch.Tensor):
index[etype] = _exclude_homo_edges_2(
reverse_edges,
edges[etype],
assume_num_node_within_int32,
)
else:
index[etype] = _exclude_homo_edges(
reverse_edges,
edges[etype],
assume_num_node_within_int32,
)
index[etype] = _exclude_homo_edges(
reverse_edges,
edges[etype],
assume_num_node_within_int32,
)
return calling_class(*_slice_subgraph(self, index))

def to(self, device: torch.device) -> None: # pylint: disable=invalid-name
Expand Down Expand Up @@ -286,26 +270,6 @@ def _relabel_two_arrays(lhs_array, rhs_array):


def _exclude_homo_edges(
    edges: Tuple[torch.Tensor, torch.Tensor],
    edges_to_exclude: Tuple[torch.Tensor, torch.Tensor],
    assume_num_node_within_int32: bool,
):
    """Return the indices of the entries of `edges` that should be kept.

    Each edge is given by the values at the same position in the first and
    second tensor of a pair. Both values are packed into one 64-bit key
    (first value in the upper 32 bits, second value in the lower bits) so
    that membership in `edges_to_exclude` can be tested with a single
    `isin` call.

    Parameters
    ----------
    edges : Tuple[torch.Tensor, torch.Tensor]
        The pair of tensors describing the candidate edges.
    edges_to_exclude : Tuple[torch.Tensor, torch.Tensor]
        The pair of tensors describing the edges to drop.
    assume_num_node_within_int32 : bool
        Must be True; the packing below relies on it.

    Returns
    -------
    torch.Tensor
        Positions in `edges` whose packed key is absent from
        `edges_to_exclude`.

    Raises
    ------
    NotImplementedError
        If `assume_num_node_within_int32` is False.
    """
    if not assume_num_node_within_int32:
        # TODO: Add support for value > int32.
        raise NotImplementedError(
            "Values out of range int32 are not supported yet"
        )

    # `<<` binds tighter than `|`; the parentheses only make that explicit.
    packed = (edges[0].long() << 32) | edges[1].long()
    packed_to_exclude = (edges_to_exclude[0].long() << 32) | edges_to_exclude[
        1
    ].long()

    keep_mask = ~isin(packed, packed_to_exclude)
    kept_positions = torch.nonzero(keep_mask, as_tuple=True)
    return kept_positions[0]


def _exclude_homo_edges_2(
edges: Tuple[torch.Tensor, torch.Tensor],
edges_to_exclude: torch.Tensor,
assume_num_node_within_int32: bool,
Expand Down
85 changes: 6 additions & 79 deletions python/dgl/graphbolt/utils.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,13 @@
"""Utility functions for external use."""

from typing import Dict, Tuple, Union
from typing import Dict, Union

import torch

from .minibatch import MiniBatch


def add_reverse_edges(
    edges: Union[
        Dict[str, Tuple[torch.Tensor, torch.Tensor]],
        Tuple[torch.Tensor, torch.Tensor],
    ],
    reverse_etypes_mapping: Dict[str, str] = None,
):
    r"""Return `edges` combined with their reverse edges.

    In a homogeneous graph, reverse edges have inverted source and
    destination node IDs. In a heterogeneous graph, reversing also moves the
    swapped node IDs to the reverse edge type given by
    `reverse_etypes_mapping`. This function could be used before
    `exclude_edges` to help find the targeted edges.

    Note: The reverse edges that are found may not really exist in the
    original graph, and duplicate edges could be added because the reverse
    edges may already exist in `edges`.

    Parameters
    ----------
    edges : Union[Dict[str, Tuple[torch.Tensor, torch.Tensor]],
                  Tuple[torch.Tensor, torch.Tensor]]
        - If the sampled subgraph is homogeneous, then `edges` should be a
          pair of tensors.
        - If the sampled subgraph is heterogeneous, then `edges` should be a
          dictionary of edge types and the corresponding edges to exclude.
    reverse_etypes_mapping : Dict[str, str], optional
        The mapping from the original edge types to their reverse edge
        types. Ignored for homogeneous input. For heterogeneous input,
        `None` is treated as an empty mapping, so no reverse edges are
        added.

    Returns
    -------
    Union[Dict[str, Tuple[torch.Tensor, torch.Tensor]],
          Tuple[torch.Tensor, torch.Tensor]]
        The node pairs containing both the original edges and their reverse
        counterparts. For heterogeneous input a new dictionary is returned;
        the input dictionary is not modified.

    Examples
    --------
    >>> edges = {"A:r:B": (torch.tensor([0, 1]), torch.tensor([1, 2]))}
    >>> print(gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A"}))
    {'A:r:B': (tensor([0, 1]), tensor([1, 2])),
    'B:rr:A': (tensor([1, 2]), tensor([0, 1]))}
    >>> edges = (torch.tensor([0, 1]), torch.tensor([2, 1]))
    >>> print(gb.add_reverse_edges(edges))
    (tensor([0, 1, 2, 1]), tensor([2, 1, 0, 1]))
    """
    if isinstance(edges, tuple):
        u, v = edges
        return (torch.cat([u, v]), torch.cat([v, u]))

    # Shallow copy: entries are only ever rebound to new tuples, so the
    # caller's dictionary stays untouched.
    combined_edges = edges.copy()
    # The parameter defaults to None; fall back to an empty mapping instead
    # of failing on `None.items()`.
    for etype, reverse_etype in (reverse_etypes_mapping or {}).items():
        if etype not in edges:
            continue
        # Always reverse the *original* edges of `etype`, never edges that
        # an earlier iteration appended to `combined_edges`.
        src, dst = edges[etype]
        if reverse_etype in combined_edges:
            u, v = combined_edges[reverse_etype]
            combined_edges[reverse_etype] = (
                torch.cat([u, dst]),
                torch.cat([v, src]),
            )
        else:
            combined_edges[reverse_etype] = (dst, src)
    return combined_edges


def add_reverse_edges_2(
edges: Union[Dict[str, torch.Tensor], torch.Tensor],
reverse_etypes_mapping: Dict[str, str] = None,
):
Expand Down Expand Up @@ -157,18 +91,11 @@ def exclude_seed_edges(
reverse_etypes_mapping : Dict[str, str] = None
The mapping from the original edge types to their reverse edge types.
"""
if minibatch.node_pairs is not None:
edges_to_exclude = minibatch.node_pairs
if include_reverse_edges:
edges_to_exclude = add_reverse_edges(
minibatch.node_pairs, reverse_etypes_mapping
)
else:
edges_to_exclude = minibatch.seeds
if include_reverse_edges:
edges_to_exclude = add_reverse_edges_2(
edges_to_exclude, reverse_etypes_mapping
)
edges_to_exclude = minibatch.seeds
if include_reverse_edges:
edges_to_exclude = add_reverse_edges(
edges_to_exclude, reverse_etypes_mapping
)
minibatch.sampled_subgraphs = [
subgraph.exclude_edges(edges_to_exclude)
for subgraph in minibatch.sampled_subgraphs
Expand Down
34 changes: 23 additions & 11 deletions tests/python/pytorch/graphbolt/impl/test_sampled_subgraph_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_exclude_edges_homo_deduplicated(reverse_row, reverse_column):
original_row_node_ids,
original_edge_ids,
)
edges_to_exclude = (src_to_exclude, dst_to_exclude)
edges_to_exclude = torch.cat((src_to_exclude, dst_to_exclude)).view(2, -1).T
result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 1, 2, 2, 2]), indices=torch.tensor([0, 3])
Expand Down Expand Up @@ -107,7 +107,7 @@ def test_exclude_edges_homo_duplicated(reverse_row, reverse_column):
original_row_node_ids,
original_edge_ids,
)
edges_to_exclude = (src_to_exclude, dst_to_exclude)
edges_to_exclude = torch.cat((src_to_exclude, dst_to_exclude)).view(2, -1).T
result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = gb.CSCFormatBase(
indptr=torch.tensor([0, 0, 1, 1, 1, 3]), indices=torch.tensor([0, 2, 2])
Expand Down Expand Up @@ -163,10 +163,14 @@ def test_exclude_edges_hetero_deduplicated(reverse_row, reverse_column):
)

edges_to_exclude = {
"A:relation:B": (
src_to_exclude,
dst_to_exclude,
"A:relation:B": torch.cat(
(
src_to_exclude,
dst_to_exclude,
)
)
.view(2, -1)
.T
}
result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = {
Expand Down Expand Up @@ -231,10 +235,14 @@ def test_exclude_edges_hetero_duplicated(reverse_row, reverse_column):
)

edges_to_exclude = {
"A:relation:B": (
src_to_exclude,
dst_to_exclude,
"A:relation:B": torch.cat(
(
src_to_exclude,
dst_to_exclude,
)
)
.view(2, -1)
.T
}
result = subgraph.exclude_edges(edges_to_exclude)
expected_csc_formats = {
Expand Down Expand Up @@ -525,10 +533,14 @@ def test_sampled_subgraph_to_device():
original_edge_ids=original_edge_ids,
)
edges_to_exclude = {
"A:relation:B": (
src_to_exclude,
dst_to_exclude,
"A:relation:B": torch.cat(
(
src_to_exclude,
dst_to_exclude,
)
)
.view(2, -1)
.T
}
graph = subgraph.exclude_edges(edges_to_exclude)

Expand Down
61 changes: 25 additions & 36 deletions tests/python/pytorch/graphbolt/test_graphbolt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,67 +5,56 @@


def test_find_reverse_edges_homo():
edges = (torch.tensor([1, 3, 5]), torch.tensor([2, 4, 5]))
edges = torch.tensor([[1, 3, 5], [2, 4, 5]]).T
edges = gb.add_reverse_edges(edges)
expected_edges = (
torch.tensor([1, 3, 5, 2, 4, 5]),
torch.tensor([2, 4, 5, 1, 3, 5]),
)
assert torch.equal(edges[0], expected_edges[0])
expected_edges = torch.tensor([[1, 3, 5, 2, 4, 5], [2, 4, 5, 1, 3, 5]]).T
assert torch.equal(edges, expected_edges)
assert torch.equal(edges[1], expected_edges[1])


def test_find_reverse_edges_hetero():
edges = {
"A:r:B": (torch.tensor([1, 5]), torch.tensor([2, 5])),
"B:rr:A": (torch.tensor([3]), torch.tensor([3])),
"A:r:B": torch.tensor([[1, 5], [2, 5]]).T,
"B:rr:A": torch.tensor([[3], [3]]).T,
}
edges = gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A"})
expected_edges = {
"A:r:B": (torch.tensor([1, 5]), torch.tensor([2, 5])),
"B:rr:A": (torch.tensor([3, 2, 5]), torch.tensor([3, 1, 5])),
"A:r:B": torch.tensor([[1, 5], [2, 5]]).T,
"B:rr:A": torch.tensor([[3, 2, 5], [3, 1, 5]]).T,
}
assert torch.equal(edges["A:r:B"][0], expected_edges["A:r:B"][0])
assert torch.equal(edges["A:r:B"][1], expected_edges["A:r:B"][1])
assert torch.equal(edges["B:rr:A"][0], expected_edges["B:rr:A"][0])
assert torch.equal(edges["B:rr:A"][1], expected_edges["B:rr:A"][1])
assert torch.equal(edges["A:r:B"], expected_edges["A:r:B"])
assert torch.equal(edges["B:rr:A"], expected_edges["B:rr:A"])


def test_find_reverse_edges_bi_reverse_types():
edges = {
"A:r:B": (torch.tensor([1, 5]), torch.tensor([2, 5])),
"B:rr:A": (torch.tensor([3]), torch.tensor([3])),
"A:r:B": torch.tensor([[1, 5], [2, 5]]).T,
"B:rr:A": torch.tensor([[3], [3]]).T,
}
edges = gb.add_reverse_edges(edges, {"A:r:B": "B:rr:A", "B:rr:A": "A:r:B"})
expected_edges = {
"A:r:B": (torch.tensor([1, 5, 3]), torch.tensor([2, 5, 3])),
"B:rr:A": (torch.tensor([3, 2, 5]), torch.tensor([3, 1, 5])),
"A:r:B": torch.tensor([[1, 5, 3], [2, 5, 3]]).T,
"B:rr:A": torch.tensor([[3, 2, 5], [3, 1, 5]]).T,
}
assert torch.equal(edges["A:r:B"][0], expected_edges["A:r:B"][0])
assert torch.equal(edges["A:r:B"][1], expected_edges["A:r:B"][1])
assert torch.equal(edges["B:rr:A"][0], expected_edges["B:rr:A"][0])
assert torch.equal(edges["B:rr:A"][1], expected_edges["B:rr:A"][1])
assert torch.equal(edges["A:r:B"], expected_edges["A:r:B"])
assert torch.equal(edges["B:rr:A"], expected_edges["B:rr:A"])


def test_find_reverse_edges_circual_reverse_types():
edges = {
"A:r1:B": (torch.tensor([1]), torch.tensor([1])),
"B:r2:C": (torch.tensor([2]), torch.tensor([2])),
"C:r3:A": (torch.tensor([3]), torch.tensor([3])),
"A:r1:B": torch.tensor([[1, 1]]),
"B:r2:C": torch.tensor([[2, 2]]),
"C:r3:A": torch.tensor([[3, 3]]),
}
edges = gb.add_reverse_edges(
edges, {"A:r1:B": "B:r2:C", "B:r2:C": "C:r3:A", "C:r3:A": "A:r1:B"}
)
expected_edges = {
"A:r1:B": (torch.tensor([1, 3]), torch.tensor([1, 3])),
"B:r2:C": (torch.tensor([2, 1]), torch.tensor([2, 1])),
"C:r3:A": (torch.tensor([3, 2]), torch.tensor([3, 2])),
"A:r1:B": torch.tensor([[1, 3], [1, 3]]).T,
"B:r2:C": torch.tensor([[2, 1], [2, 1]]).T,
"C:r3:A": torch.tensor([[3, 2], [3, 2]]).T,
}
assert torch.equal(edges["A:r1:B"][0], expected_edges["A:r1:B"][0])
assert torch.equal(edges["A:r1:B"][1], expected_edges["A:r1:B"][1])
assert torch.equal(edges["B:r2:C"][0], expected_edges["B:r2:C"][0])
assert torch.equal(edges["B:r2:C"][1], expected_edges["B:r2:C"][1])
assert torch.equal(edges["A:r1:B"][0], expected_edges["A:r1:B"][0])
assert torch.equal(edges["A:r1:B"][1], expected_edges["A:r1:B"][1])
assert torch.equal(edges["C:r3:A"][0], expected_edges["C:r3:A"][0])
assert torch.equal(edges["C:r3:A"][1], expected_edges["C:r3:A"][1])
assert torch.equal(edges["A:r1:B"], expected_edges["A:r1:B"])
assert torch.equal(edges["B:r2:C"], expected_edges["B:r2:C"])
assert torch.equal(edges["A:r1:B"], expected_edges["A:r1:B"])
assert torch.equal(edges["C:r3:A"], expected_edges["C:r3:A"])
Loading

0 comments on commit 20e5e26

Please sign in to comment.