Add Graph Transformer Layer (Dense Computation) (dmlc#4959)
* Add GraphTransformerLayer (dense)

* beautify the python code with black

* refine according to mufei's comments

* fix AttributeError in unit test

* rename module as GraphormerLayer

* fix name issue

Co-authored-by: Mufei Li <[email protected]>
ZHITENGLI and mufeili authored Dec 7, 2022
1 parent a80bd9e commit 279e2e8
Showing 3 changed files with 243 additions and 98 deletions.
1 change: 1 addition & 0 deletions docs/source/api/python/nn-pytorch.rst
@@ -118,6 +118,7 @@ Utility Modules
~dgl.nn.pytorch.graph_transformer.DegreeEncoder
~dgl.nn.pytorch.utils.LaplacianPosEnc
~dgl.nn.pytorch.graph_transformer.BiasedMultiheadAttention
~dgl.nn.pytorch.graph_transformer.GraphormerLayer
~dgl.nn.pytorch.graph_transformer.PathEncoder

Network Embedding Modules
315 changes: 218 additions & 97 deletions python/dgl/nn/pytorch/graph_transformer.py
@@ -6,10 +6,12 @@
from ...batch import unbatch
from ...transforms import shortest_dist

__all__ = ["DegreeEncoder",
"BiasedMultiheadAttention",
"PathEncoder"]

__all__ = [
"DegreeEncoder",
"PathEncoder",
"BiasedMultiheadAttention",
"GraphormerLayer"
]

class DegreeEncoder(nn.Module):
    r"""Degree Encoder, as introduced in
@@ -92,6 +94,118 @@ def forward(self, g):
        return degree_embedding


class PathEncoder(nn.Module):
    r"""Path Encoder, as introduced in Edge Encoding of
    `Do Transformers Really Perform Bad for Graph Representation?
    <https://proceedings.neurips.cc/paper/2021/file/f1c1592588411002af340cbaedd6fc33-Paper.pdf>`__

    This module is a learnable path embedding that encodes the shortest
    path between each pair of nodes as an attention bias.

    Parameters
    ----------
    max_len : int
        Maximum number of edges in each path to be encoded.
        Longer paths are truncated, i.e. edges at position
        :attr:`max_len` or later in a path are dropped.
    feat_dim : int
        Dimension of edge features in the input graph.
    num_heads : int, optional
        Number of attention heads if multi-head attention is applied.
        Default : 1.

    Examples
    --------
    >>> import torch as th
    >>> import dgl

    >>> u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
    >>> v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
    >>> g = dgl.graph((u, v))
    >>> edata = th.rand(8, 16)
    >>> path_encoder = dgl.nn.PathEncoder(2, 16, num_heads=8)
    >>> out = path_encoder(g, edata)
    """

    def __init__(self, max_len, feat_dim, num_heads=1):
        super().__init__()
        self.max_len = max_len
        self.feat_dim = feat_dim
        self.num_heads = num_heads
        self.embedding_table = nn.Embedding(max_len * num_heads, feat_dim)

    def forward(self, g, edge_feat):
        """
        Parameters
        ----------
        g : DGLGraph
            A DGLGraph to be encoded, which must be a homogeneous one.
        edge_feat : torch.Tensor
            The input edge feature of shape :math:`(E, feat_dim)`,
            where :math:`E` is the number of edges in the input graph.

        Returns
        -------
        torch.Tensor
            Return attention bias as path encoding,
            of shape :math:`(batch_size, N, N, num_heads)`,
            where :math:`N` is the maximum number of nodes
            and batch_size is the batch size of the input graph.
        """
        g_list = unbatch(g)
        sum_num_edges = 0
        max_num_nodes = th.max(g.batch_num_nodes())
        path_encoding = []

        for ubg in g_list:
            num_nodes = ubg.num_nodes()
            num_edges = ubg.num_edges()
            edata = edge_feat[sum_num_edges: (sum_num_edges + num_edges)]
            sum_num_edges = sum_num_edges + num_edges
            edata = th.cat(
                (edata, th.zeros(1, self.feat_dim).to(edata.device)),
                dim=0
            )
            _, path = shortest_dist(ubg, root=None, return_paths=True)
            path_len = min(self.max_len, path.size(dim=2))

            # shape: [n, n, l], n = num_nodes, l = path_len
            shortest_path = path[:, :, 0: path_len]
            # shape: [n, n]
            shortest_distance = th.clamp(
                shortest_dist(ubg, root=None, return_paths=False),
                min=1,
                max=path_len
            )
            # shape: [n, n, l, d], d = feat_dim
            path_data = edata[shortest_path]
            # shape: [l, h], h = num_heads
            embedding_idx = th.reshape(
                th.arange(self.num_heads * path_len),
                (path_len, self.num_heads)
            ).to(next(self.embedding_table.parameters()).device)
            # shape: [d, l, h]
            edge_embedding = th.permute(
                self.embedding_table(embedding_idx), (2, 0, 1)
            )

            # [n, n, l, d] einsum [d, l, h] -> [n, n, h]
            # [n, n, h] -> [N, N, h], N = max_num_nodes, padded with -inf
            sub_encoding = th.full(
                (max_num_nodes, max_num_nodes, self.num_heads),
                float('-inf')
            )
            sub_encoding[0: num_nodes, 0: num_nodes] = th.div(
                th.einsum(
                    'xyld,dlh->xyh', path_data, edge_embedding
                ).permute(2, 0, 1),
                shortest_distance
            ).permute(1, 2, 0)
            path_encoding.append(sub_encoding)

        return th.stack(path_encoding, dim=0)

class BiasedMultiheadAttention(nn.Module):
    r"""Dense Multi-Head Attention Module with Graph Attention Bias.
@@ -226,113 +340,120 @@ def forward(self, ndata, attn_bias=None, attn_mask=None):
        return attn


class GraphormerLayer(nn.Module):
    r"""Graphormer Layer with Dense Multi-Head Attention, as introduced
    in `Do Transformers Really Perform Bad for Graph Representation?
    <https://arxiv.org/pdf/2106.05234>`__

    Parameters
    ----------
    feat_size : int
        Feature size.
    hidden_size : int
        Hidden size of feedforward layers.
    num_heads : int
        Number of attention heads, by which :attr:`feat_size` is divisible.
    attn_bias_type : str, optional
        The type of attention bias used for modifying attention. Selected from
        'add' or 'mul'. Default: 'add'.

        * 'add' is for additive attention bias.
        * 'mul' is for multiplicative attention bias.
    norm_first : bool, optional
        If True, it performs layer normalization before attention and
        feedforward operations. Otherwise, it applies layer normalization
        afterwards. Default: False.
    dropout : float, optional
        Dropout probability. Default: 0.1.
    activation : callable activation layer, optional
        Activation function. Default: nn.ReLU().

    Examples
    --------
    >>> import torch as th
    >>> import dgl
    >>> from dgl.nn import GraphormerLayer

    >>> batch_size = 16
    >>> num_nodes = 100
    >>> feat_size = 512
    >>> num_heads = 8
    >>> nfeat = th.rand(batch_size, num_nodes, feat_size)
    >>> bias = th.rand(batch_size, num_nodes, num_nodes, num_heads)
    >>> net = GraphormerLayer(
    ...     feat_size=feat_size,
    ...     hidden_size=2048,
    ...     num_heads=num_heads
    ... )
    >>> out = net(nfeat, bias)
    """

    def __init__(
        self,
        feat_size,
        hidden_size,
        num_heads,
        attn_bias_type='add',
        norm_first=False,
        dropout=0.1,
        activation=nn.ReLU()
    ):
        super().__init__()

        self.norm_first = norm_first

        self.attn = BiasedMultiheadAttention(
            feat_size=feat_size,
            num_heads=num_heads,
            attn_bias_type=attn_bias_type,
            attn_drop=dropout
        )
        self.ffn = nn.Sequential(
            nn.Linear(feat_size, hidden_size),
            activation,
            nn.Dropout(p=dropout),
            nn.Linear(hidden_size, feat_size),
            nn.Dropout(p=dropout)
        )

        self.dropout = nn.Dropout(p=dropout)
        self.attn_layer_norm = nn.LayerNorm(feat_size)
        self.ffn_layer_norm = nn.LayerNorm(feat_size)

    def forward(self, nfeat, attn_bias=None, attn_mask=None):
        """Forward computation.

        Parameters
        ----------
        nfeat : torch.Tensor
            A 3D input tensor. Shape: (batch_size, N, :attr:`feat_size`), where
            N is the maximum number of nodes.
        attn_bias : torch.Tensor, optional
            The attention bias used for attention modification. Shape:
            (batch_size, N, N, :attr:`num_heads`).
        attn_mask : torch.Tensor, optional
            The attention mask used for avoiding computation on invalid
            positions. Shape: (batch_size, N, N).

        Returns
        -------
        y : torch.Tensor
            The output tensor. Shape: (batch_size, N, :attr:`feat_size`)
        """
        residual = nfeat
        if self.norm_first:
            nfeat = self.attn_layer_norm(nfeat)
        nfeat = self.attn(nfeat, attn_bias, attn_mask)
        nfeat = self.dropout(nfeat)
        nfeat = residual + nfeat
        if not self.norm_first:
            nfeat = self.attn_layer_norm(nfeat)

        residual = nfeat
        if self.norm_first:
            nfeat = self.ffn_layer_norm(nfeat)
        nfeat = self.ffn(nfeat)
        nfeat = residual + nfeat
        if not self.norm_first:
            nfeat = self.ffn_layer_norm(nfeat)

        return nfeat
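
Taken together, PathEncoder produces the shortest-path attention bias and GraphormerLayer consumes it as a dense attention block. Below is a minimal sketch of wiring the two together; it is not part of this commit, the graph topology, feature sizes, and hidden_size are arbitrary illustration values, and it assumes a DGL build containing this commit so that both modules are importable from dgl.nn as the docstrings above suggest.

import torch as th
import dgl
from dgl.nn import GraphormerLayer, PathEncoder

# Two copies of the 4-node graph from the PathEncoder docstring, batched
# so that every graph has the same number of nodes (no padding needed).
u = th.tensor([0, 0, 0, 1, 1, 2, 3, 3])
v = th.tensor([1, 2, 3, 0, 3, 0, 0, 1])
bg = dgl.batch([dgl.graph((u, v)), dgl.graph((u, v))])

feat_size, edge_dim, num_heads = 32, 16, 8
edge_feat = th.rand(bg.num_edges(), edge_dim)

# Shortest-path attention bias: (batch_size, N, N, num_heads).
path_encoder = PathEncoder(max_len=2, feat_dim=edge_dim, num_heads=num_heads)
attn_bias = path_encoder(bg, edge_feat)

# Dense node features: (batch_size, N, feat_size); both graphs have 4 nodes.
nfeat = th.rand(bg.batch_size, 4, feat_size)

layer = GraphormerLayer(feat_size=feat_size, hidden_size=64, num_heads=num_heads)
out = layer(nfeat, attn_bias)
print(out.shape)  # torch.Size([2, 4, 32])

Note that feat_size must be divisible by num_heads, per the GraphormerLayer parameter description above.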
25 changes: 24 additions & 1 deletion tests/pytorch/test_nn.py
@@ -1802,6 +1802,30 @@ def test_BiasedMultiheadAttention(feat_size, num_heads, bias, attn_bias_type, attn_drop):

    assert out.shape == (16, 100, feat_size)

@pytest.mark.parametrize('attn_bias_type', ['add', 'mul'])
@pytest.mark.parametrize('norm_first', [True, False])
def test_GraphormerLayer(attn_bias_type, norm_first):
    batch_size = 16
    num_nodes = 100
    feat_size = 512
    num_heads = 8
    nfeat = th.rand(batch_size, num_nodes, feat_size)
    attn_bias = th.rand(batch_size, num_nodes, num_nodes, num_heads)
    attn_mask = th.rand(batch_size, num_nodes, num_nodes) < 0.5

    net = nn.GraphormerLayer(
        feat_size=feat_size,
        hidden_size=2048,
        num_heads=num_heads,
        attn_bias_type=attn_bias_type,
        norm_first=norm_first,
        dropout=0.1,
        activation=th.nn.ReLU()
    )
    out = net(nfeat, attn_bias, attn_mask)

    assert out.shape == (batch_size, num_nodes, feat_size)

@pytest.mark.parametrize('max_len', [1, 4])
@pytest.mark.parametrize('feat_dim', [16])
@pytest.mark.parametrize('num_heads', [1, 8])
@@ -1820,4 +1844,3 @@ def test_PathEncoder(max_len, feat_dim, num_heads):
    model = nn.PathEncoder(max_len, feat_dim, num_heads=num_heads).to(dev)
    bias = model(bg, edge_feat)
    assert bias.shape == (2, 6, 6, num_heads)
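
The unit test above exercises a single layer; in a Graphormer-style encoder several such layers are typically stacked, reusing one attention-bias tensor across layers. The sketch below illustrates that under the same assumptions as before (TinyGraphormerEncoder is a hypothetical name, not part of DGL or this commit; shapes mirror test_GraphormerLayer).

import torch as th
import torch.nn as tnn
from dgl.nn import GraphormerLayer

class TinyGraphormerEncoder(tnn.Module):
    """Hypothetical stack of pre-norm GraphormerLayer blocks."""

    def __init__(self, feat_size, hidden_size, num_heads, num_layers=3):
        super().__init__()
        # ModuleList rather than Sequential, because each layer takes
        # (nfeat, attn_bias, attn_mask) instead of a single tensor.
        self.layers = tnn.ModuleList([
            GraphormerLayer(
                feat_size=feat_size,
                hidden_size=hidden_size,
                num_heads=num_heads,
                norm_first=True,  # pre-LN: normalize before attention/FFN
            )
            for _ in range(num_layers)
        ])

    def forward(self, nfeat, attn_bias=None, attn_mask=None):
        # The same bias and mask are reused by every layer.
        for layer in self.layers:
            nfeat = layer(nfeat, attn_bias, attn_mask)
        return nfeat

# Shapes mirror test_GraphormerLayer above.
nfeat = th.rand(16, 100, 512)
attn_bias = th.rand(16, 100, 100, 8)
encoder = TinyGraphormerEncoder(feat_size=512, hidden_size=2048, num_heads=8)
out = encoder(nfeat, attn_bias)
assert out.shape == (16, 100, 512)

Here norm_first=True selects the pre-LN ordering described in the GraphormerLayer docstring; omitting it keeps the default post-LN behavior.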
