[NN] Add HeteroGraphConv module for cleaner module definition (dmlc#1385)

* Add HeteroGraphConv

* add custom aggregator; some docstring

* debugging

* rm print

* fix some acc bugs

* fix initialization problem in weight basis

* passed tests

* lint

* fix graphconv flag; add error message

* add mxnet heteroconv

* more fix for mx

* lint

* fix torch cuda test

* fix mx test_nn

* add exhaust test for graphconv

* add tf heteroconv

* fix comment
jermainewang authored Mar 27, 2020
1 parent bbfff8c commit 3efb5d8
Showing 20 changed files with 1,513 additions and 527 deletions.
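The headline change wraps per-edge-type submodules behind a single HeteroGraphConv interface, replacing hand-rolled heterograph convolutions like the RelGraphConvHetero class removed below. A minimal usage sketch, assuming the wrapper matches the dgl.nn.pytorch.HeteroGraphConv interface introduced here (a dict of per-edge-type modules plus an aggregate mode; the toy graph and feature sizes are placeholders):

    import torch as th
    import dgl
    import dgl.nn.pytorch as dglnn

    # Toy heterograph; the edge-list form matches the DGL 0.4.x API of this era.
    g = dgl.heterograph({
        ('user', 'follows', 'user'): [(0, 1), (1, 2)],
        ('user', 'plays', 'game'): [(0, 0), (2, 1)]})

    # One submodule per edge type; results landing on the same node type
    # are combined by the aggregate mode ('sum' here).
    conv = dglnn.HeteroGraphConv(
        {'follows': dglnn.GraphConv(8, 4),
         'plays': dglnn.GraphConv(8, 4)},
        aggregate='sum')

    feats = {ntype: th.randn(g.number_of_nodes(ntype), 8) for ntype in g.ntypes}
    out = conv(g, feats)   # dict: node type -> updated (num_nodes, 4) features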
225 changes: 1 addition & 224 deletions examples/pytorch/rgcn-hetero/entity_classify.py
@@ -8,232 +8,9 @@
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

import dgl.function as fn
from dgl.data.rdf import AIFB, MUTAG, BGS, AM

class RelGraphConvHetero(nn.Module):
r"""Relational graph convolution layer.
Parameters
----------
in_feat : int
Input feature size.
out_feat : int
Output feature size.
    rel_names : list of str
Relation names.
regularizer : str
        Which weight regularizer to use, "basis" or "bdd". Only "basis"
        is implemented in this example.
num_bases : int, optional
        Number of bases. If None, use the number of relations. Default: None.
bias : bool, optional
True if bias is added. Default: True
activation : callable, optional
Activation function. Default: None
self_loop : bool, optional
True to include self loop message. Default: False
use_weight : bool, optional
If True, multiply the input node feature with a learnable weight matrix
        before message passing. Default: True.
dropout : float, optional
Dropout rate. Default: 0.0
"""
def __init__(self,
in_feat,
out_feat,
rel_names,
regularizer="basis",
num_bases=None,
bias=True,
activation=None,
self_loop=False,
use_weight=True,
dropout=0.0):
super(RelGraphConvHetero, self).__init__()
self.in_feat = in_feat
self.out_feat = out_feat
self.rel_names = rel_names
self.num_rels = len(rel_names)
self.regularizer = regularizer
self.num_bases = num_bases
        if self.num_bases is None or self.num_bases > self.num_rels or self.num_bases <= 0:
self.num_bases = self.num_rels
self.bias = bias
self.activation = activation
self.self_loop = self_loop

self.use_weight = use_weight
if use_weight:
if regularizer == "basis":
# add basis weights
self.weight = nn.Parameter(th.Tensor(self.num_bases, self.in_feat, self.out_feat))
if self.num_bases < self.num_rels:
# linear combination coefficients
self.w_comp = nn.Parameter(th.Tensor(self.num_rels, self.num_bases))
nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain('relu'))
if self.num_bases < self.num_rels:
nn.init.xavier_uniform_(self.w_comp,
gain=nn.init.calculate_gain('relu'))
else:
raise ValueError("Only basis regularizer is supported.")

# bias
if self.bias:
self.h_bias = nn.Parameter(th.Tensor(out_feat))
nn.init.zeros_(self.h_bias)

# weight for self loop
if self.self_loop:
self.loop_weight = nn.Parameter(th.Tensor(in_feat, out_feat))
nn.init.xavier_uniform_(self.loop_weight,
gain=nn.init.calculate_gain('relu'))

self.dropout = nn.Dropout(dropout)

def basis_weight(self):
"""Message function for basis regularizer"""
if self.num_bases < self.num_rels:
# generate all weights from bases
weight = self.weight.view(self.num_bases,
self.in_feat * self.out_feat)
weight = th.matmul(self.w_comp, weight).view(
self.num_rels, self.in_feat, self.out_feat)
else:
weight = self.weight
return {self.rel_names[i] : w.squeeze(0) for i, w in enumerate(th.split(weight, 1, dim=0))}
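    # A worked shape example for basis_weight (sizes are assumptions): each
    # per-relation weight W_r is a learned combination of shared bases,
    # W_r = sum_b w_comp[r, b] * weight[b]. With num_rels=4, num_bases=2 and
    # in_feat=out_feat=8, weight is (2, 8, 8) and w_comp is (4, 2), so the
    # matmul yields one (8, 8) matrix per relation, keyed by rel_names.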

def forward(self, g, xs):
"""Forward computation
Parameters
----------
g : DGLHeteroGraph
Input graph.
xs : dict[str, torch.Tensor]
Node feature for each node type.
Returns
-------
list of torch.Tensor
New node features for each node type.
"""
g = g.local_var()
for ntype in g.ntypes:
g.nodes[ntype].data['x'] = xs[ntype]
if self.use_weight:
ws = self.basis_weight()
funcs = {}
for i, (srctype, etype, dsttype) in enumerate(g.canonical_etypes):
g.nodes[srctype].data['h%d' % i] = th.matmul(
g.nodes[srctype].data['x'], ws[etype])
funcs[(srctype, etype, dsttype)] = (fn.copy_u('h%d' % i, 'm'), fn.mean('m', 'h'))
else:
funcs = {}
for i, (srctype, etype, dsttype) in enumerate(g.canonical_etypes):
g.nodes[srctype].data['h%d' % i] = g.nodes[srctype].data['x']
funcs[(srctype, etype, dsttype)] = (fn.copy_u('h%d' % i, 'm'), fn.mean('m', 'h'))
# message passing
g.multi_update_all(funcs, 'sum')

hs = {ntype : g.nodes[ntype].data['h'] for ntype in g.ntypes}
new_hs = {}
for ntype, h in hs.items():
# apply bias and activation
if self.self_loop:
h = h + th.matmul(xs[ntype], self.loop_weight)
if self.bias:
h = h + self.h_bias
if self.activation:
h = self.activation(h)
h = self.dropout(h)
new_hs[ntype] = h
return new_hs
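# A minimal usage sketch for RelGraphConvHetero (the toy graph and sizes are
# assumptions; the edge-list form of dgl.heterograph matches DGL 0.4.x):
#   g = dgl.heterograph({
#       ('user', 'follows', 'user'): [(0, 1), (1, 2)],
#       ('user', 'plays', 'game'): [(0, 0), (2, 1)]})
#   conv = RelGraphConvHetero(8, 4, sorted(set(g.etypes)))
#   xs = {nt: th.randn(g.number_of_nodes(nt), 8) for nt in g.ntypes}
#   hs = conv(g, xs)   # dict: node type -> (num_nodes, 4) tensor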

class RelGraphEmbed(nn.Module):
r"""Embedding layer for featureless heterograph."""
def __init__(self,
g,
embed_size,
embed_name='embed',
activation=None,
dropout=0.0):
super(RelGraphEmbed, self).__init__()
self.g = g
self.embed_size = embed_size
self.embed_name = embed_name
self.activation = activation
self.dropout = nn.Dropout(dropout)

        # create a learnable embedding matrix for each node type
self.embeds = nn.ParameterDict()
for ntype in g.ntypes:
embed = nn.Parameter(th.Tensor(g.number_of_nodes(ntype), self.embed_size))
nn.init.xavier_uniform_(embed, gain=nn.init.calculate_gain('relu'))
self.embeds[ntype] = embed


def forward(self, block=None):
"""Forward computation
Parameters
----------
block : DGLHeteroGraph, optional
If not specified, directly return the full graph with embeddings stored in
:attr:`embed_name`. Otherwise, extract and store the embeddings to the block
graph and return.
Returns
-------
DGLHeteroGraph
The block graph fed with embeddings.
"""
return self.embeds
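# Usage sketch (embed_size is an assumption): the layer owns one learnable
# matrix per node type, and forward() simply returns the dict.
#   embed_layer = RelGraphEmbed(g, embed_size=16)
#   h = embed_layer()   # {'user': Parameter of shape (num_users, 16), ...}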

class EntityClassify(nn.Module):
def __init__(self,
g,
h_dim, out_dim,
num_bases,
num_hidden_layers=1,
dropout=0,
use_self_loop=False):
super(EntityClassify, self).__init__()
self.g = g
self.h_dim = h_dim
self.out_dim = out_dim
self.rel_names = list(set(g.etypes))
self.rel_names.sort()
self.num_bases = None if num_bases < 0 else num_bases
self.num_hidden_layers = num_hidden_layers
self.dropout = dropout
self.use_self_loop = use_self_loop

self.embed_layer = RelGraphEmbed(g, self.h_dim)
self.layers = nn.ModuleList()
# i2h
self.layers.append(RelGraphConvHetero(
self.h_dim, self.h_dim, self.rel_names, "basis",
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout, use_weight=False))
# h2h
for i in range(self.num_hidden_layers):
self.layers.append(RelGraphConvHetero(
self.h_dim, self.h_dim, self.rel_names, "basis",
self.num_bases, activation=F.relu, self_loop=self.use_self_loop,
dropout=self.dropout))
# h2o
self.layers.append(RelGraphConvHetero(
self.h_dim, self.out_dim, self.rel_names, "basis",
self.num_bases, activation=None,
self_loop=self.use_self_loop))

def forward(self):
h = self.embed_layer()
for layer in self.layers:
h = layer(self.g, h)
return h
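# This in-script model definition is removed by the commit; the script now
# imports the refactored EntityClassify (presumably rebuilt on the new
# HeteroGraphConv wrapper) from a sibling model module, via the one added
# line below. A sketch of the call, with assumed arguments; `category`
# stands for the target node type being classified:
#   model = EntityClassify(g, h_dim=16, out_dim=num_classes, num_bases=-1)
#   logits = model()[category]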
from model import EntityClassify

def main(args):
# load graph data