Added models
gcorso committed Apr 5, 2020
0 parents commit e48a3ba
Showing 13 changed files with 1,021 additions and 0 deletions.
78 changes: 78 additions & 0 deletions models/gat/layer.py
@@ -0,0 +1,78 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class GATHead(nn.Module):

def __init__(self, in_features, out_features, alpha, activation=True, device='cpu'):
super(GATHead, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.activation = activation

self.W = nn.Parameter(torch.zeros(size=(in_features, out_features), device=device))
self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1), device=device))
self.leakyrelu = nn.LeakyReLU(alpha)

self.reset_parameters()

def reset_parameters(self):
nn.init.xavier_uniform_(self.W.data, gain=0.1414)
nn.init.xavier_uniform_(self.a.data, gain=0.1414)

    def forward(self, input, adj):
        # input: (B, N, in_features), adj: (B, N, N) dense adjacency matrices
        h = torch.matmul(input, self.W)
        (B, N, _) = adj.shape

        # all pairwise concatenations [h_i || h_j], shape (B, N, N, 2 * out_features)
        a_input = torch.cat([h.repeat(1, 1, N).view(B, N * N, -1), h.repeat(1, N, 1)], dim=2).view(B, N, N, 2 * self.out_features)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(3))

        # mask non-edges with a large negative value so they vanish in the softmax
        zero_vec = -9e15 * torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)

        # normalise the attention coefficients over each node's neighbours (last dimension)
        attention = F.softmax(attention, dim=2)
        h_prime = torch.matmul(attention, h)

if self.activation:
return F.elu(h_prime)
else:
return h_prime

def __repr__(self):
return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class GATLayer(nn.Module):
"""
Graph Attention Layer, GAT paper at https://arxiv.org/abs/1710.10903
Implementation inspired by https://github.com/Diego999/pyGAT
"""

def __init__(self, in_features, out_features, alpha, nheads=1, activation=True, device='cpu'):
"""
:param in_features: size of the input per node
:param out_features: size of the output per node
:param alpha: slope of the leaky relu
:param nheads: number of attention heads
:param activation: whether to apply a non-linearity
:param device: device used for computation
"""
super(GATLayer, self).__init__()
        assert out_features % nheads == 0, "out_features must be divisible by nheads"

        # store the full input/output sizes so that __repr__ can report them
        self.in_features = in_features
        self.out_features = out_features
        self.input_head = in_features
        self.output_head = out_features // nheads

self.heads = nn.ModuleList()
for _ in range(nheads):
self.heads.append(GATHead(in_features=self.input_head, out_features=self.output_head, alpha=alpha,
activation=activation, device=device))

def forward(self, input, adj):
y = torch.cat([head(input, adj) for head in self.heads], dim=2)
return y

def __repr__(self):
return self.__class__.__name__ + ' (' \
+ str(self.in_features) + ' -> ' \
+ str(self.out_features) + ')'
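A minimal usage sketch for the layer above (not part of the commit): the batch size, node count, and feature sizes are illustrative assumptions, and the import path assumes the repository root is on PYTHONPATH.

import torch
from models.gat.layer import GATLayer

# toy batch: 2 graphs, 5 nodes each, 16 input features per node (illustrative values)
x = torch.rand(2, 5, 16)
adj = torch.randint(0, 2, (2, 5, 5)).float()   # dense adjacency matrices

# 4 heads of size 32 // 4 = 8 each, concatenated back to 32 output features per node
layer = GATLayer(in_features=16, out_features=32, alpha=0.2, nheads=4)
out = layer(x, adj)
print(out.shape)   # expected: torch.Size([2, 5, 32])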
34 changes: 34 additions & 0 deletions models/gat/train.py
@@ -0,0 +1,34 @@
from __future__ import division
from __future__ import print_function

from models.gat.layer import GATLayer
from util.train import execute_train, build_arg_parser

# Training settings
parser = build_arg_parser()
parser.add_argument('--nheads', type=int, default=4, help='Number of attention heads.')
parser.add_argument('--alpha', type=float, default=0.2, help='Alpha for the leaky_relu.')
args = parser.parse_args()

execute_train(gnn_args=dict(nfeat=None,
nhid=args.hidden,
nodes_out=None,
graph_out=None,
dropout=args.dropout,
device=None,
first_conv_descr=dict(layer_type=GATLayer,
args=dict(
nheads=args.nheads,
alpha=args.alpha
)),
middle_conv_descr=dict(layer_type=GATLayer,
args=dict(
nheads=args.nheads,
alpha=args.alpha
)),
fc_layers=args.fc_layers,
conv_layers=args.conv_layers,
skip=args.skip,
gru=args.gru,
fixed=args.fixed,
variable=args.variable), args=args)
56 changes: 56 additions & 0 deletions models/gcn/layer.py
@@ -0,0 +1,56 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class GCNLayer(nn.Module):
"""
GCN layer, similar to https://arxiv.org/abs/1609.02907
Implementation inspired by https://github.com/tkipf/pygcn
"""

def __init__(self, in_features, out_features, bias=True, device='cpu'):
"""
:param in_features: size of the input per node
:param out_features: size of the output per node
:param bias: whether to add a learnable bias before the activation
:param device: device used for computation
"""
super(GCNLayer, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.device = device
self.W = nn.Parameter(torch.zeros(size=(in_features, out_features), device=device))
if bias:
self.b = nn.Parameter(torch.zeros(out_features, device=device))
else:
self.register_parameter('b', None)
self.reset_parameters()

def reset_parameters(self):
stdv = 1. / math.sqrt(self.W.size(1))
self.W.data.uniform_(-stdv, stdv)
if self.b is not None:
self.b.data.uniform_(-stdv, stdv)

def forward(self, X, adj):
(B, N, _) = adj.shape

# linear transformation
XW = torch.matmul(X, self.W)

        # symmetric normalisation: D^{-1/2} (A + I) D^{-1/2}
        adj = adj + torch.eye(N, device=self.device).unsqueeze(0)
        rD = torch.mul(torch.pow(torch.sum(adj, -1, keepdim=True), -0.5), torch.eye(N, device=self.device).unsqueeze(0))  # diag(D^{-1/2})
        adj = torch.matmul(torch.matmul(rD, adj), rD)
y = torch.bmm(adj, XW)

if self.b is not None:
y = y + self.b
return F.leaky_relu(y)

def __repr__(self):
return self.__class__.__name__ + ' (' \
+ str(self.in_features) + ' -> ' \
+ str(self.out_features) + ')'
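For reference, the propagation rule implemented by the forward pass above, with A the input adjacency matrix, I the identity (self-loops) and D the degree matrix of A + I; note that this implementation applies a LeakyReLU rather than the ReLU of the original paper:

\hat{A} = D^{-1/2} (A + I) \, D^{-1/2}, \qquad
H' = \mathrm{LeakyReLU}\left( \hat{A} X W + b \right)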
25 changes: 25 additions & 0 deletions models/gcn/train.py
@@ -0,0 +1,25 @@
from __future__ import division
from __future__ import print_function

from models.gnn_framework import GNN
from models.gcn.layer import GCNLayer
from util.train import execute_train, build_arg_parser

# Training settings
parser = build_arg_parser()
args = parser.parse_args()

execute_train(gnn_args=dict(nfeat=None,
nhid=args.hidden,
nodes_out=None,
graph_out=None,
dropout=args.dropout,
device=None,
first_conv_descr=dict(layer_type=GCNLayer, args=dict()),
middle_conv_descr=dict(layer_type=GCNLayer, args=dict()),
fc_layers=args.fc_layers,
conv_layers=args.conv_layers,
skip=args.skip,
gru=args.gru,
fixed=args.fixed,
variable=args.variable), args=args)
45 changes: 45 additions & 0 deletions models/gin/layer.py
@@ -0,0 +1,45 @@
import torch
import torch.nn as nn
from util.layers import FCLayer, MLP


class GINLayer(nn.Module):
"""
Graph Isomorphism Network layer, similar to https://arxiv.org/abs/1810.00826
"""

def __init__(self, in_features, out_features, fc_layers=2, device='cpu'):
"""
:param in_features: size of the input per node
:param out_features: size of the output per node
:param fc_layers: number of fully connected layers after the sum aggregator
:param device: device used for computation
"""
super(GINLayer, self).__init__()

self.device = device
self.in_features = in_features
self.out_features = out_features
self.epsilon = nn.Parameter(torch.zeros(size=(1,), device=device))
self.post_transformation = MLP(in_size=in_features, hidden_size=max(in_features, out_features), out_size=out_features,
layers=fc_layers, mid_activation='relu', last_activation='relu', mid_b_norm=True,
last_b_norm=False, device=device)
self.reset_parameters()

def reset_parameters(self):
self.epsilon.data.fill_(0.1)

def forward(self, input, adj):
(B, N, _) = adj.shape

        # sum aggregation with (1 + eps)-weighted self-loops: (1 + eps) * x_i + sum_j A_ij x_j
        mod_adj = adj + torch.eye(N, device=self.device).unsqueeze(0) * (1 + self.epsilon)
        support = torch.matmul(mod_adj, input)

# post-aggregation transformation
return self.post_transformation(support)

def __repr__(self):
return self.__class__.__name__ + ' (' \
+ str(self.in_features) + ' -> ' \
+ str(self.out_features) + ')'
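A minimal usage sketch for the GIN layer above (not part of the commit): the shapes are illustrative, and it assumes util.layers.MLP, as imported by the layer, accepts batched (B, N, F) node features the way the forward pass presupposes.

import torch
from models.gin.layer import GINLayer

x = torch.rand(2, 5, 16)                        # (batch, nodes, features), illustrative
adj = torch.randint(0, 2, (2, 5, 5)).float()    # dense adjacency matrices

layer = GINLayer(in_features=16, out_features=32, fc_layers=2)
out = layer(x, adj)                             # MLP((1 + eps) * x_i + sum_j A_ij x_j)
print(out.shape)                                # expected: torch.Size([2, 5, 32])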
25 changes: 25 additions & 0 deletions models/gin/train.py
@@ -0,0 +1,25 @@
from __future__ import division
from __future__ import print_function

from models.gin.layer import GINLayer
from util.train import execute_train, build_arg_parser

# Training settings
parser = build_arg_parser()
parser.add_argument('--gin_fc_layers', type=int, default=2, help='Number of fully connected layers after the aggregation.')
args = parser.parse_args()

execute_train(gnn_args=dict(nfeat=None,
nhid=args.hidden,
nodes_out=None,
graph_out=None,
dropout=args.dropout,
device=None,
first_conv_descr=dict(layer_type=GINLayer, args=dict(fc_layers=args.gin_fc_layers)),
middle_conv_descr=dict(layer_type=GINLayer, args=dict(fc_layers=args.gin_fc_layers)),
fc_layers=args.fc_layers,
conv_layers=args.conv_layers,
skip=args.skip,
gru=args.gru,
fixed=args.fixed,
variable=args.variable), args=args)
106 changes: 106 additions & 0 deletions models/gnn_framework.py
@@ -0,0 +1,106 @@
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
from util.layers import GRU, S2SReadout, MLP


class GNN(nn.Module):
def __init__(self, nfeat, nhid, nodes_out, graph_out, dropout, conv_layers=2, fc_layers=3, first_conv_descr=None,
middle_conv_descr=None, final_activation='LeakyReLU', skip=False, gru=False, fixed=False, variable=False,
device='cpu'):
"""
:param nfeat: number of input features per node
:param nhid: number of hidden features per node
:param nodes_out: number of nodes' labels
:param graph_out: number of graph labels
:param dropout: dropout value
:param conv_layers: if variable, conv_layers should be a function : adj -> int, otherwise an int
:param fc_layers: number of fully connected layers before the labels
        :param first_conv_descr: dict or SimpleNamespace: "layer_type" -> type of layer, "args" -> dict of constructor args
        :param middle_conv_descr: dict or SimpleNamespace: "layer_type" -> type of layer, "args" -> dict of constructor args
:param final_activation: activation to be used on the last fc layer before the labels
:param skip: whether to use skip connections feeding to the readout
:param gru: whether to use a shared GRU after each convolution
:param fixed: whether to reuse the same middle convolutional layer multiple times
:param variable: whether the number of convolutional layers is variable or fixed
:param device: device used for computation
"""
super(GNN, self).__init__()
if variable:
assert callable(conv_layers), "conv_layers should be a function from adjacency matrix to int"
            assert fixed, "with a variable number of layers, the middle convolution must be fixed (shared)"
            assert not skip, "cannot use skip connections with a variable number of layers"
else:
assert type(conv_layers) == int, "conv_layers should be an int"
assert conv_layers > 0, "conv_layers should be greater than 0"

if type(first_conv_descr) == dict:
first_conv_descr = types.SimpleNamespace(**first_conv_descr)
assert type(first_conv_descr) == types.SimpleNamespace, "first_conv_descr should be either a dict or a SimpleNamespace"
if type(first_conv_descr.args) == dict:
first_conv_descr.args = types.SimpleNamespace(**first_conv_descr.args)
assert type(first_conv_descr.args) == types.SimpleNamespace, "first_conv_descr.args should be either a dict or a SimpleNamespace"

if type(middle_conv_descr) == dict:
middle_conv_descr = types.SimpleNamespace(**middle_conv_descr)
assert type(middle_conv_descr) == types.SimpleNamespace, "middle_conv_descr should be either a dict or a SimpleNamespace"
if type(middle_conv_descr.args) == dict:
middle_conv_descr.args = types.SimpleNamespace(**middle_conv_descr.args)
assert type(middle_conv_descr.args) == types.SimpleNamespace, "middle_conv_descr.args should be either a dict or a SimpleNamespace"

self.dropout = dropout
self.conv_layers = nn.ModuleList()
self.skip = skip
self.fixed = fixed
self.variable = variable
self.n_fixed_conv = conv_layers
self.gru = GRU(input_size=nhid, hidden_size=nhid, device=device) if gru else None

# first graph convolution
first_conv_descr.args.in_features = nfeat
first_conv_descr.args.out_features = nhid
first_conv_descr.args.device = device
self.conv_layers.append(first_conv_descr.layer_type(**vars(first_conv_descr.args)))

# middle graph convolutions
middle_conv_descr.args.in_features = nhid
middle_conv_descr.args.out_features = nhid
middle_conv_descr.args.device = device
for l in range(1 if fixed else conv_layers - 1):
self.conv_layers.append(
middle_conv_descr.layer_type(**vars(middle_conv_descr.args)))

n_conv_out = nfeat + conv_layers * nhid if skip else nhid

# nodes output: fully connected layers
self.nodes_read_out = MLP(in_size=n_conv_out, hidden_size=n_conv_out, out_size=nodes_out, layers=fc_layers,
mid_activation="LeakyReLU", last_activation=final_activation, device=device)

# graph output: S2S readout
self.graph_read_out = S2SReadout(n_conv_out, n_conv_out, graph_out, fc_layers=fc_layers, device=device,
final_activation=final_activation)

def forward(self, x, adj):
# graph convolutions
skip_connections = [x] if self.skip else None

        n_layers = self.n_fixed_conv(adj) if self.variable else self.n_fixed_conv
        # when the middle convolution is shared ("fixed"), reuse it n_layers - 1 times
        conv_layers = ([self.conv_layers[0]] + [self.conv_layers[1]] * (n_layers - 1)) if self.fixed else self.conv_layers

for layer, conv in enumerate(conv_layers):
y = conv(x, adj)
x = y if self.gru is None else self.gru(x, y)

if self.skip:
skip_connections.append(x)

# dropout at all layers but the last
if layer != n_layers - 1:
x = F.dropout(x, self.dropout, training=self.training)

if self.skip:
x = torch.cat(skip_connections, dim=2)

# readout output
return (self.nodes_read_out(x), self.graph_read_out(x))
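A sketch of wiring the framework above together with the GCN layer from this commit: all sizes are made-up, and it assumes the GRU, S2SReadout and MLP helpers from util.layers behave as their call sites in this file suggest.

import torch
from models.gnn_framework import GNN
from models.gcn.layer import GCNLayer

model = GNN(nfeat=16, nhid=64, nodes_out=3, graph_out=2, dropout=0.2,
            conv_layers=4, fc_layers=3,
            first_conv_descr=dict(layer_type=GCNLayer, args=dict()),
            middle_conv_descr=dict(layer_type=GCNLayer, args=dict()),
            skip=True)

x = torch.rand(8, 10, 16)                       # 8 graphs, 10 nodes, 16 features each
adj = torch.randint(0, 2, (8, 10, 10)).float()
node_pred, graph_pred = model(x, adj)           # expected: (8, 10, 3) and (8, 2)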