Skip to content

Commit

Permalink
[Graphbolt] Convert a data block to dgl graphs (dmlc#6228)
Browse files Browse the repository at this point in the history
  • Loading branch information
peizhou001 authored Sep 4, 2023
1 parent 268f456 commit ac49220
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 0 deletions.
90 changes: 90 additions & 0 deletions python/dgl/graphbolt/minibatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

import torch

import dgl

from .base import etype_str_to_tuple
from .sampled_subgraph import SampledSubgraph

__all__ = ["MiniBatch"]
Expand Down Expand Up @@ -122,3 +125,90 @@ class MiniBatch:
Representation of compacted nodes corresponding to 'negative_tail', where
all node ids inside are compacted.
"""

def to_dgl_graphs(self):
"""Transforming a data graph into DGL graphs necessitates constructing a
graphical structure and assigning features to the nodes and edges within
the graphs.
"""
if not self.sampled_subgraphs:
return None

is_heterogeneous = isinstance(
self.sampled_subgraphs[0].node_pairs, Dict
)

if is_heterogeneous:
graphs = []
for subgraph in self.sampled_subgraphs:
graphs.append(
dgl.heterograph(
{
etype_str_to_tuple(etype): v
for etype, v in subgraph.node_pairs.items()
}
)
)
else:
graphs = [
dgl.graph(subgraph.node_pairs)
for subgraph in self.sampled_subgraphs
]

if is_heterogeneous:
# Assign node features to the outermost layer's nodes.
if self.node_features:
for (
node_type,
feature_name,
), feature in self.node_features.items():
graphs[0].nodes[node_type].data[feature_name] = feature
# Assign edge features.
if self.edge_features:
for graph, edge_feature in zip(graphs, self.edge_features):
for (
edge_type,
feature_name,
), feature in edge_feature.items():
graph.edges[etype_str_to_tuple(edge_type)].data[
feature_name
] = feature
# Assign reverse node ids to the outermost layer's nodes.
reverse_row_node_ids = self.sampled_subgraphs[
0
].reverse_row_node_ids
if reverse_row_node_ids:
for node_type, reverse_ids in reverse_row_node_ids.items():
graphs[0].nodes[node_type].data[dgl.NID] = reverse_ids
# Assign reverse edges ids.
for graph, subgraph in zip(graphs, self.sampled_subgraphs):
if subgraph.reverse_edge_ids:
for (
edge_type,
reverse_ids,
) in subgraph.reverse_edge_ids.items():
graph.edges[etype_str_to_tuple(edge_type)].data[
dgl.EID
] = reverse_ids
else:
# Assign node features to the outermost layer's nodes.
if self.node_features:
for feature_name, feature in self.node_features.items():
graphs[0].ndata[feature_name] = feature
# Assign edge features.
if self.edge_features:
for graph, edge_feature in zip(graphs, self.edge_features):
for feature_name, feature in edge_feature.items():
graph.edata[feature_name] = feature
# Assign reverse node ids.
reverse_row_node_ids = self.sampled_subgraphs[
0
].reverse_row_node_ids
if reverse_row_node_ids is not None:
graphs[0].ndata[dgl.NID] = reverse_row_node_ids
# Assign reverse edges ids.
for graph, subgraph in zip(graphs, self.sampled_subgraphs):
if subgraph.reverse_edge_ids is not None:
graph.edata[dgl.EID] = subgraph.reverse_edge_ids

return graphs
69 changes: 69 additions & 0 deletions tests/python/pytorch/graphbolt/impl/test_to_dgl_graphs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import dgl
import dgl.graphbolt as gb
import torch


def test_to_dgl_graphs_hetero():
relation = "A:relation:B"
node_pairs = {relation: (torch.tensor([0, 1, 2]), torch.tensor([0, 4, 5]))}
reverse_column_node_ids = {"B": torch.tensor([10, 11, 12, 13, 14, 16])}
reverse_row_node_ids = {
"A": torch.tensor([5, 9, 7]),
"B": torch.tensor([10, 11, 12, 13, 14, 16]),
}
reverse_edge_ids = {relation: torch.tensor([19, 20, 21])}
node_features = {
("A", "x"): torch.randint(0, 10, (3,)),
("B", "y"): torch.randint(0, 10, (6,)),
}
edge_features = {(relation, "x"): torch.randint(0, 10, (3,))}
subgraph = gb.SampledSubgraphImpl(
node_pairs=node_pairs,
reverse_column_node_ids=reverse_column_node_ids,
reverse_row_node_ids=reverse_row_node_ids,
reverse_edge_ids=reverse_edge_ids,
)
g = gb.MiniBatch(
sampled_subgraphs=[subgraph],
node_features=node_features,
edge_features=[edge_features],
).to_dgl_graphs()[0]

assert torch.equal(g.edges()[0], node_pairs[relation][0])
assert torch.equal(g.edges()[1], node_pairs[relation][1])
assert torch.equal(g.ndata[dgl.NID]["A"], reverse_row_node_ids["A"])
assert torch.equal(g.ndata[dgl.NID]["B"], reverse_row_node_ids["B"])
assert torch.equal(g.edata[dgl.EID], reverse_edge_ids[relation])
assert torch.equal(g.nodes["A"].data["x"], node_features[("A", "x")])
assert torch.equal(g.nodes["B"].data["y"], node_features[("B", "y")])
assert torch.equal(
g.edges[gb.etype_str_to_tuple(relation)].data["x"],
edge_features[(relation, "x")],
)


def test_to_dgl_graphs_homo():
node_pairs = (torch.tensor([0, 1, 2]), torch.tensor([0, 4, 5]))
reverse_column_node_ids = torch.tensor([10, 11, 12])
reverse_row_node_ids = torch.tensor([10, 11, 12, 13, 14, 16])
reverse_edge_ids = torch.tensor([19, 20, 21])
node_features = {"x": torch.randint(0, 10, (6,))}
edge_features = {"x": torch.randint(0, 10, (3,))}
subgraph = gb.SampledSubgraphImpl(
node_pairs=node_pairs,
reverse_column_node_ids=reverse_column_node_ids,
reverse_row_node_ids=reverse_row_node_ids,
reverse_edge_ids=reverse_edge_ids,
)
g = gb.MiniBatch(
sampled_subgraphs=[subgraph],
node_features=node_features,
edge_features=[edge_features],
).to_dgl_graphs()[0]

assert torch.equal(g.edges()[0], node_pairs[0])
assert torch.equal(g.edges()[1], node_pairs[1])
assert torch.equal(g.ndata[dgl.NID], reverse_row_node_ids)
assert torch.equal(g.edata[dgl.EID], reverse_edge_ids)
assert torch.equal(g.ndata["x"], node_features["x"])
assert torch.equal(g.edata["x"], edge_features["x"])

0 comments on commit ac49220

Please sign in to comment.