Skip to content

Commit

Permalink
[NN] TransE and TransR (dmlc#3530)
Browse files Browse the repository at this point in the history
* Update

* Update

* Update

* Update

* Update

* CI

* CI

* CI

Co-authored-by: Jinjing Zhou <[email protected]>
  • Loading branch information
mufeili and VoVAllen authored Dec 7, 2021
1 parent c3103b6 commit d6eecf9
Show file tree
Hide file tree
Showing 7 changed files with 250 additions and 4 deletions.
14 changes: 14 additions & 0 deletions docs/source/api/python/nn.pytorch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,20 @@ EdgePredictor
:members: forward, reset_parameters
:show-inheritance:

TransE
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.pytorch.link.TransE
:members: rel_emb, forward, reset_parameters
:show-inheritance:

TransR
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.pytorch.link.TransR
:members: rel_emb, rel_project, forward, reset_parameters
:show-inheritance:

Heterogeneous Graph Convolution Module
----------------------------------------

Expand Down
2 changes: 1 addition & 1 deletion python/dgl/nn/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Package for pytorch-specific NN modules."""
from .conv import *
from .explain import *
from .link import *
from .glob import *
from .softmax import *
from .factory import *
from .hetero import *
from .utils import Sequential, WeightBasis, JumpingKnowledge
from .sparse_emb import NodeEmbedding
from .link import *
5 changes: 5 additions & 0 deletions python/dgl/nn/pytorch/link/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Torch modules for link prediction/knowledge graph completion."""

from .edgepred import EdgePredictor
from .transe import TransE
from .transr import TransR
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
"""Torch modules for link prediction."""
"""Predictor for edges in homogeneous graphs."""
# pylint: disable= no-member, arguments-differ, invalid-name, W0235
import torch
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['EdgePredictor']

class EdgePredictor(nn.Module):
r"""
Expand Down
100 changes: 100 additions & 0 deletions python/dgl/nn/pytorch/link/transe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""TransE."""
# pylint: disable= no-member, arguments-differ, invalid-name, W0235
import torch
import torch.nn as nn

class TransE(nn.Module):
r"""
Description
-----------
Similarity measure introduced in `Translating Embeddings for Modeling Multi-relational Data
<https://papers.nips.cc/paper/2013/hash/1cecc7a77928ca8133fa24680a88d2f9-Abstract.html>`__.
Mathematically, it is defined as follows:
.. math::
- {\| h + r - t \|}_p
where :math:`h` is the head embedding, :math:`r` is the relation embedding, and
:math:`t` is the tail embedding.
Parameters
----------
num_rels : int
Number of relation types.
feats : int
Embedding size.
p : int, optional
The p to use for Lp norm, which can be 1 or 2.
Attributes
----------
rel_emb : torch.nn.Embedding
The learnable relation type embedding.
Examples
--------
>>> import dgl
>>> import torch as th
>>> from dgl.nn import TransE
>>> # input features
>>> num_nodes = 10
>>> num_edges = 30
>>> num_rels = 3
>>> feats = 4
>>> scorer = TransE(num_rels=num_rels, feats=feats)
>>> g = dgl.rand_graph(num_nodes=num_nodes, num_edges=num_edges)
>>> src, dst = g.edges()
>>> h = th.randn(num_nodes, feats)
>>> h_head = h[src]
>>> h_tail = h[dst]
>>> # Randomly initialize edge relation types for demonstration
>>> rels = th.randint(low=0, high=num_rels, size=(num_edges,))
>>> scorer(h_head, h_tail, rels).shape
torch.Size([30])
"""
def __init__(self, num_rels, feats, p=1):
super(TransE, self).__init__()

self.rel_emb = nn.Embedding(num_rels, feats)
self.p = p

def reset_parameters(self):
r"""
Description
-----------
Reinitialize learnable parameters.
"""
self.rel_emb.reset_parameters()

def forward(self, h_head, h_tail, rels):
r"""
Description
-----------
Score triples.
Parameters
----------
h_head : torch.Tensor
Head entity features. The tensor is of shape :math:`(E, D)`, where
:math:`E` is the number of triples, and :math:`D` is the feature size.
h_tail : torch.Tensor
Tail entity features. The tensor is of shape :math:`(E, D)`, where
:math:`E` is the number of triples, and :math:`D` is the feature size.
rels : torch.Tensor
Relation types. It is a LongTensor of shape :math:`(E)`, where
:math:`E` is the number of triples.
Returns
-------
torch.Tensor
The triple scores. The tensor is of shape :math:`(E)`.
"""
h_rel = self.rel_emb(rels)

return - torch.norm(h_head + h_rel - h_tail, p=self.p, dim=-1)
109 changes: 109 additions & 0 deletions python/dgl/nn/pytorch/link/transr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""TransR."""
# pylint: disable= no-member, arguments-differ, invalid-name, W0235
import torch
import torch.nn as nn

class TransR(nn.Module):
r"""
Description
-----------
Similarity measure introduced in
`Learning entity and relation embeddings for knowledge graph completion
<https://ojs.aaai.org/index.php/AAAI/article/view/9491>`__. Mathematically,
it is defined as follows:
.. math::
- {\| M_r h + r - M_r t \|}_p
where :math:`M_r` is a relation-specific projection matrix, :math:`h` is the
head embedding, :math:`r` is the relation embedding, and :math:`t` is the tail embedding.
Parameters
----------
num_rels : int
Number of relation types.
rfeats : int
Relation embedding size.
nfeats : int
Entity embedding size.
p : int, optional
The p to use for Lp norm, which can be 1 or 2.
Attributes
----------
rel_emb : torch.nn.Embedding
The learnable relation type embedding.
rel_project : torch.nn.Embedding
The learnable relation-type-specific projection.
Examples
--------
>>> import dgl
>>> import torch as th
>>> from dgl.nn import TransR
>>> # input features
>>> num_nodes = 10
>>> num_edges = 30
>>> num_rels = 3
>>> feats = 4
>>> scorer = TransE(num_rels=num_rels, rfeats=2, nfeats=feats)
>>> g = dgl.rand_graph(num_nodes=num_nodes, num_edges=num_edges)
>>> src, dst = g.edges()
>>> h = th.randn(num_nodes, feats)
>>> h_head = h[src]
>>> h_tail = h[dst]
>>> # Randomly initialize edge relation types for demonstration
>>> rels = th.randint(low=0, high=num_rels, size=(num_edges,))
>>> scorer(h_head, h_tail, rels).shape
torch.Size([30])
"""
def __init__(self, num_rels, rfeats, nfeats, p=1):
super(TransR, self).__init__()

self.rel_emb = nn.Embedding(num_rels, rfeats)
self.rel_project = nn.Embedding(num_rels, nfeats * rfeats)
self.rfeats = rfeats
self.nfeats = nfeats
self.p = p

def reset_parameters(self):
r"""
Description
-----------
Reinitialize learnable parameters.
"""
self.rel_emb.reset_parameters()
self.rel_project.reset_parameters()

def forward(self, h_head, h_tail, rels):
r"""
Score triples.
Parameters
----------
h_head : torch.Tensor
Head entity features. The tensor is of shape :math:`(E, D)`, where
:math:`E` is the number of triples, and :math:`D` is the feature size.
h_tail : torch.Tensor
Tail entity features. The tensor is of shape :math:`(E, D)`, where
:math:`E` is the number of triples, and :math:`D` is the feature size.
rels : torch.Tensor
Relation types. It is a LongTensor of shape :math:`(E)`, where
:math:`E` is the number of triples.
Returns
-------
torch.Tensor
The triple scores. The tensor is of shape :math:`(E)`.
"""
h_rel = self.rel_emb(rels)
proj_rel = self.rel_project(rels).reshape(-1, self.nfeats, self.rfeats)
h_head = (h_head.unsqueeze(1) @ proj_rel).squeeze(1)
h_tail = (h_tail.unsqueeze(1) @ proj_rel).squeeze(1)

return - torch.norm(h_head + h_rel - h_tail, p=self.p, dim=-1)
20 changes: 20 additions & 0 deletions tests/pytorch/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,26 @@ def test_edge_predictor(op):
pred = nn.EdgePredictor(op, in_feats, out_feats, bias=True).to(ctx)
assert pred(h_src, h_dst).shape == (num_pairs, out_feats)


def test_ke_score_funcs():
ctx = F.ctx()
num_edges = 30
num_rels = 3
nfeats = 4

h_src = th.randn((num_edges, nfeats)).to(ctx)
h_dst = th.randn((num_edges, nfeats)).to(ctx)
rels = th.randint(low=0, high=num_rels, size=(num_edges,)).to(ctx)

score_func = nn.TransE(num_rels=num_rels, feats=nfeats).to(ctx)
score_func.reset_parameters()
score_func(h_src, h_dst, rels).shape == (num_edges)

score_func = nn.TransR(num_rels=num_rels, rfeats=nfeats - 1, nfeats=nfeats).to(ctx)
score_func.reset_parameters()
score_func(h_src, h_dst, rels).shape == (num_edges)


def test_twirls():
g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
feat = th.ones(6, 10)
Expand Down

0 comments on commit d6eecf9

Please sign in to comment.