Skip to content

Commit

Permalink
[Graphbolt] Remove redundant data in to_dgl (dmlc#6466)
Browse files Browse the repository at this point in the history
  • Loading branch information
peizhou001 authored Oct 19, 2023
1 parent 625f8a6 commit 72b3e07
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 11 deletions.
8 changes: 6 additions & 2 deletions python/dgl/graphbolt/minibatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,12 +314,16 @@ def to_dgl(self):
"""
minibatch = DGLMiniBatch(
blocks=self._to_dgl_blocks(),
input_nodes=self.input_nodes,
output_nodes=self.seed_nodes,
node_features=self.node_features,
edge_features=self.edge_features,
labels=self.labels,
)
# Need input nodes to fetch feature.
if self.node_features is None:
minibatch.input_nodes = self.input_nodes
# Need output nodes to fetch label.
if self.labels is None:
minibatch.output_nodes = self.seed_nodes
assert (
minibatch.blocks is not None
), "Sampled subgraphs for computation are missing."
Expand Down
32 changes: 29 additions & 3 deletions tests/python/pytorch/graphbolt/impl/test_minibatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def create_homo_minibatch():
sampled_subgraphs=subgraphs,
node_features=node_features,
edge_features=edge_features,
input_nodes=torch.tensor([10, 11, 12, 13]),
)


Expand Down Expand Up @@ -103,6 +104,10 @@ def create_hetero_minibatch():
sampled_subgraphs=subgraphs,
node_features=node_features,
edge_features=edge_features,
input_nodes={
"A": torch.tensor([5, 7, 9, 11]),
"B": torch.tensor([10, 11, 12]),
},
)


Expand Down Expand Up @@ -286,7 +291,7 @@ def test_dgl_minibatch_representation():
node_features={'x': tensor([7, 6, 2, 2])},
negative_node_pairs=(tensor([0, 1, 2]), tensor([6, 0, 0])),
labels=tensor([0., 1., 2.]),
input_nodes=tensor([8, 1, 6, 5, 9, 0, 2, 4]),
input_nodes=None,
edge_features=[{'x': tensor([[8],
[1],
[6]])},
Expand Down Expand Up @@ -354,6 +359,25 @@ def check_dgl_blocks_homo(minibatch, blocks):
assert torch.equal(blocks[0].srcdata[dgl.NID], original_row_node_ids[0])


def test_to_dgl_node_classification_without_feature():
# Arrange
minibatch = create_homo_minibatch()
minibatch.node_features = None
minibatch.labels = None
minibatch.seed_nodes = torch.tensor([10, 15])
# Act
dgl_minibatch = minibatch.to_dgl()

# Assert
assert len(dgl_minibatch.blocks) == 2
assert dgl_minibatch.node_features is None
assert minibatch.edge_features is dgl_minibatch.edge_features
assert dgl_minibatch.labels is None
assert minibatch.input_nodes is dgl_minibatch.input_nodes
assert minibatch.seed_nodes is dgl_minibatch.output_nodes
check_dgl_blocks_homo(minibatch, dgl_minibatch.blocks)


def test_to_dgl_node_classification_homo():
# Arrange
minibatch = create_homo_minibatch()
Expand All @@ -367,7 +391,8 @@ def test_to_dgl_node_classification_homo():
assert minibatch.node_features is dgl_minibatch.node_features
assert minibatch.edge_features is dgl_minibatch.edge_features
assert minibatch.labels is dgl_minibatch.labels
assert minibatch.seed_nodes is dgl_minibatch.output_nodes
assert dgl_minibatch.input_nodes is None
assert dgl_minibatch.output_nodes is None
check_dgl_blocks_homo(minibatch, dgl_minibatch.blocks)


Expand All @@ -382,7 +407,8 @@ def test_to_dgl_node_classification_hetero():
assert minibatch.node_features is dgl_minibatch.node_features
assert minibatch.edge_features is dgl_minibatch.edge_features
assert minibatch.labels is dgl_minibatch.labels
assert minibatch.seed_nodes is dgl_minibatch.output_nodes
assert dgl_minibatch.input_nodes is None
assert dgl_minibatch.output_nodes is None
check_dgl_blocks_hetero(minibatch, dgl_minibatch.blocks)


Expand Down
12 changes: 6 additions & 6 deletions tests/python/pytorch/graphbolt/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_integration_link_prediction():
[0.5503, 0.8223]])},
negative_node_pairs=(tensor([0, 1, 1, 1]), tensor([0, 3, 4, 5])),
labels=None,
input_nodes=tensor([5, 3, 1, 2, 0, 4]),
input_nodes=None,
edge_features=[{},
{}],
blocks=[Block(num_src_nodes=6,
Expand All @@ -92,7 +92,7 @@ def test_integration_link_prediction():
[0.6172, 0.7865]])},
negative_node_pairs=(tensor([0, 1, 1, 2]), tensor([1, 3, 4, 1])),
labels=None,
input_nodes=tensor([3, 4, 0, 5, 1]),
input_nodes=None,
edge_features=[{},
{}],
blocks=[Block(num_src_nodes=5,
Expand All @@ -112,7 +112,7 @@ def test_integration_link_prediction():
[0.9634, 0.2294]])},
negative_node_pairs=(tensor([0, 1]), tensor([1, 2])),
labels=None,
input_nodes=tensor([5, 4, 3, 0]),
input_nodes=None,
edge_features=[{},
{}],
blocks=[Block(num_src_nodes=4,
Expand Down Expand Up @@ -193,7 +193,7 @@ def test_integration_node_classification():
[0.9634, 0.2294]])},
negative_node_pairs=None,
labels=None,
input_nodes=tensor([5, 3, 1, 2, 4, 0]),
input_nodes=None,
edge_features=[{},
{}],
blocks=[Block(num_src_nodes=6,
Expand All @@ -212,7 +212,7 @@ def test_integration_node_classification():
[0.9634, 0.2294]])},
negative_node_pairs=None,
labels=None,
input_nodes=tensor([3, 4, 0]),
input_nodes=None,
edge_features=[{},
{}],
blocks=[Block(num_src_nodes=3,
Expand All @@ -231,7 +231,7 @@ def test_integration_node_classification():
[0.9634, 0.2294]])},
negative_node_pairs=None,
labels=None,
input_nodes=tensor([5, 4, 0]),
input_nodes=None,
edge_features=[{},
{}],
blocks=[Block(num_src_nodes=3,
Expand Down

0 comments on commit 72b3e07

Please sign in to comment.