forked from lukecavabarrett/pna
Commit e48a3ba (0 parents) — showing 13 changed files with 1,021 additions and 0 deletions.
New file: GAT layer — models/gat/layer.py (path inferred from the GAT training script's import)
import torch
import torch.nn as nn
import torch.nn.functional as F


class GATHead(nn.Module):
    """
    Single attention head; GATLayer concatenates the outputs of several of these.
    """

    def __init__(self, in_features, out_features, alpha, activation=True, device='cpu'):
        super(GATHead, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.activation = activation

        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features), device=device))
        self.a = nn.Parameter(torch.zeros(size=(2 * out_features, 1), device=device))
        self.leakyrelu = nn.LeakyReLU(alpha)

        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.W.data, gain=0.1414)
        nn.init.xavier_uniform_(self.a.data, gain=0.1414)

    def forward(self, input, adj):

        h = torch.matmul(input, self.W)
        (B, N, _) = adj.shape
        # a_input[b, i, j] = [h_i || h_j] for every pair of nodes
        a_input = torch.cat([h.repeat(1, 1, N).view(B, N * N, -1), h.repeat(1, N, 1)], dim=2).view(B, N, N, 2 * self.out_features)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(3))

        # mask non-edges before normalising over each node's neighbourhood
        zero_vec = -9e15 * torch.ones_like(e)

        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=2)
        h_prime = torch.matmul(attention, h)

        if self.activation:
            return F.elu(h_prime)
        else:
            return h_prime

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class GATLayer(nn.Module):
    """
    Graph Attention Layer, GAT paper at https://arxiv.org/abs/1710.10903
    Implementation inspired by https://github.com/Diego999/pyGAT
    """

    def __init__(self, in_features, out_features, alpha, nheads=1, activation=True, device='cpu'):
        """
        :param in_features: size of the input per node
        :param out_features: size of the output per node
        :param alpha: negative slope of the LeakyReLU
        :param nheads: number of attention heads
        :param activation: whether to apply a non-linearity
        :param device: device used for computation
        """
        super(GATLayer, self).__init__()
        assert (out_features % nheads == 0)

        self.in_features = in_features
        self.out_features = out_features
        self.input_head = in_features
        self.output_head = out_features // nheads

        self.heads = nn.ModuleList()
        for _ in range(nheads):
            self.heads.append(GATHead(in_features=self.input_head, out_features=self.output_head, alpha=alpha,
                                      activation=activation, device=device))

    def forward(self, input, adj):
        # each head outputs out_features // nheads features; concatenate along the feature dimension
        y = torch.cat([head(input, adj) for head in self.heads], dim=2)
        return y

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
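
A minimal usage sketch (not part of the commit), showing how GATLayer splits out_features across its heads and concatenates the results; the import path is inferred from the training script below, and the toy shapes are illustrative.

import torch
from models.gat.layer import GATLayer  # path inferred from the GAT training script's import

B, N = 2, 5
x = torch.randn(B, N, 16)                                       # node features: (batch, nodes, in_features)
adj = (torch.rand(B, N, N) > 0.5).float()
adj = ((adj + adj.transpose(1, 2) + torch.eye(N)) > 0).float()  # undirected, with self-loops

layer = GATLayer(in_features=16, out_features=32, alpha=0.2, nheads=4)
out = layer(x, adj)
print(out.shape)  # torch.Size([2, 5, 32]): four heads of 8 features each, concatenated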
New file: GAT training script
from __future__ import division
from __future__ import print_function

from models.gat.layer import GATLayer
from util.train import execute_train, build_arg_parser

# Training settings
parser = build_arg_parser()
parser.add_argument('--nheads', type=int, default=4, help='Number of attention heads.')
parser.add_argument('--alpha', type=float, default=0.2, help='Negative slope of the LeakyReLU.')
args = parser.parse_args()

execute_train(gnn_args=dict(nfeat=None,
                            nhid=args.hidden,
                            nodes_out=None,
                            graph_out=None,
                            dropout=args.dropout,
                            device=None,
                            first_conv_descr=dict(layer_type=GATLayer,
                                                  args=dict(
                                                      nheads=args.nheads,
                                                      alpha=args.alpha
                                                  )),
                            middle_conv_descr=dict(layer_type=GATLayer,
                                                   args=dict(
                                                       nheads=args.nheads,
                                                       alpha=args.alpha
                                                   )),
                            fc_layers=args.fc_layers,
                            conv_layers=args.conv_layers,
                            skip=args.skip,
                            gru=args.gru,
                            fixed=args.fixed,
                            variable=args.variable), args=args)
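
A note, not part of the commit: the GNN framework below sets each convolution's out_features to nhid, and GATLayer asserts out_features % nheads == 0, so the hidden size chosen through the parser must be divisible by --nheads. A quick illustration:

from models.gat.layer import GATLayer

GATLayer(in_features=16, out_features=64, alpha=0.2, nheads=4)    # fine: 64 % 4 == 0
# GATLayer(in_features=16, out_features=50, alpha=0.2, nheads=4)  # would trip the assert in __init__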
New file: GCN layer — models/gcn/layer.py (path inferred from the GCN training script's import)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class GCNLayer(nn.Module):
    """
    GCN layer, similar to https://arxiv.org/abs/1609.02907
    Implementation inspired by https://github.com/tkipf/pygcn
    """

    def __init__(self, in_features, out_features, bias=True, device='cpu'):
        """
        :param in_features: size of the input per node
        :param out_features: size of the output per node
        :param bias: whether to add a learnable bias before the activation
        :param device: device used for computation
        """
        super(GCNLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.device = device
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features), device=device))
        if bias:
            self.b = nn.Parameter(torch.zeros(out_features, device=device))
        else:
            self.register_parameter('b', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.W.size(1))
        self.W.data.uniform_(-stdv, stdv)
        if self.b is not None:
            self.b.data.uniform_(-stdv, stdv)

    def forward(self, X, adj):
        (B, N, _) = adj.shape

        # linear transformation
        XW = torch.matmul(X, self.W)

        # normalised mean aggregation
        adj = adj + torch.eye(N, device=self.device).unsqueeze(0)  # add self-loops: A' = A + I
        rD = torch.mul(torch.pow(torch.sum(adj, -1, keepdim=True), -0.5), torch.eye(N, device=self.device).unsqueeze(0))  # diagonal D^{-1/2}
        adj = torch.matmul(torch.matmul(rD, adj), rD)  # D^{-1/2} A' D^{-1/2}
        y = torch.bmm(adj, XW)

        if self.b is not None:
            y = y + self.b
        return F.leaky_relu(y)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
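
A small standalone sketch (not part of the commit) reproducing the symmetric normalisation D^{-1/2}(A + I)D^{-1/2} that GCNLayer.forward builds from the batched adjacency; the 3-node path graph is illustrative.

import torch

N = 3
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]]).unsqueeze(0)              # (1, N, N) path graph
A_hat = A + torch.eye(N).unsqueeze(0)                      # add self-loops
deg = torch.sum(A_hat, -1, keepdim=True)                   # degrees: 2, 3, 2
rD = torch.pow(deg, -0.5) * torch.eye(N).unsqueeze(0)      # diagonal D^{-1/2}, as in the layer
norm_adj = torch.matmul(torch.matmul(rD, A_hat), rD)       # D^{-1/2} (A + I) D^{-1/2}
print(norm_adj[0])                                         # entry (i, j) is 1 / sqrt(d_i * d_j) on edges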
New file: GCN training script
from __future__ import division
from __future__ import print_function

from models.gnn_framework import GNN
from models.gcn.layer import GCNLayer
from util.train import execute_train, build_arg_parser

# Training settings
parser = build_arg_parser()
args = parser.parse_args()

execute_train(gnn_args=dict(nfeat=None,
                            nhid=args.hidden,
                            nodes_out=None,
                            graph_out=None,
                            dropout=args.dropout,
                            device=None,
                            first_conv_descr=dict(layer_type=GCNLayer, args=dict()),
                            middle_conv_descr=dict(layer_type=GCNLayer, args=dict()),
                            fc_layers=args.fc_layers,
                            conv_layers=args.conv_layers,
                            skip=args.skip,
                            gru=args.gru,
                            fixed=args.fixed,
                            variable=args.variable), args=args)
New file: GIN layer — models/gin/layer.py (path inferred from the GIN training script's import)
import torch
import torch.nn as nn
from util.layers import FCLayer, MLP


class GINLayer(nn.Module):
    """
    Graph Isomorphism Network layer, similar to https://arxiv.org/abs/1810.00826
    """

    def __init__(self, in_features, out_features, fc_layers=2, device='cpu'):
        """
        :param in_features: size of the input per node
        :param out_features: size of the output per node
        :param fc_layers: number of fully connected layers after the sum aggregator
        :param device: device used for computation
        """
        super(GINLayer, self).__init__()

        self.device = device
        self.in_features = in_features
        self.out_features = out_features
        self.epsilon = nn.Parameter(torch.zeros(size=(1,), device=device))
        self.post_transformation = MLP(in_size=in_features, hidden_size=max(in_features, out_features), out_size=out_features,
                                       layers=fc_layers, mid_activation='relu', last_activation='relu', mid_b_norm=True,
                                       last_b_norm=False, device=device)
        self.reset_parameters()

    def reset_parameters(self):
        self.epsilon.data.fill_(0.1)

    def forward(self, input, adj):
        (B, N, _) = adj.shape

        # sum aggregation
        mod_adj = adj + torch.eye(N, device=self.device).unsqueeze(0) * (1 + self.epsilon)
        support = torch.matmul(mod_adj, input)

        # post-aggregation transformation
        return self.post_transformation(support)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
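
A quick check (not part of the commit) that the modified adjacency built in GINLayer.forward implements the GIN aggregation (1 + ε)·x_i + Σ_{j∈N(i)} x_j that is then fed to the MLP; the graph and features are illustrative, and ε is fixed at the layer's initial value.

import torch

eps = 0.1                                    # GINLayer.reset_parameters initialises epsilon to 0.1
x = torch.randn(1, 4, 3)                     # (batch, nodes, features)
adj = torch.tensor([[0., 1., 1., 0.],
                    [1., 0., 0., 1.],
                    [1., 0., 0., 0.],
                    [0., 1., 0., 0.]]).unsqueeze(0)

mod_adj = adj + torch.eye(4).unsqueeze(0) * (1 + eps)      # what GINLayer.forward builds
agg = torch.matmul(mod_adj, x)                             # input to the post-aggregation MLP

# the same update written out for node 0, whose neighbours are nodes 1 and 2
manual = (1 + eps) * x[0, 0] + x[0, 1] + x[0, 2]
print(torch.allclose(agg[0, 0], manual))                   # True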
New file: GIN training script
from __future__ import division
from __future__ import print_function

from models.gin.layer import GINLayer
from util.train import execute_train, build_arg_parser

# Training settings
parser = build_arg_parser()
parser.add_argument('--gin_fc_layers', type=int, default=2, help='Number of fully connected layers after the aggregation.')
args = parser.parse_args()

execute_train(gnn_args=dict(nfeat=None,
                            nhid=args.hidden,
                            nodes_out=None,
                            graph_out=None,
                            dropout=args.dropout,
                            device=None,
                            first_conv_descr=dict(layer_type=GINLayer, args=dict(fc_layers=args.gin_fc_layers)),
                            middle_conv_descr=dict(layer_type=GINLayer, args=dict(fc_layers=args.gin_fc_layers)),
                            fc_layers=args.fc_layers,
                            conv_layers=args.conv_layers,
                            skip=args.skip,
                            gru=args.gru,
                            fixed=args.fixed,
                            variable=args.variable), args=args)
New file: GNN framework — models/gnn_framework.py (path inferred from the GCN training script's import)
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
from util.layers import GRU, S2SReadout, MLP


class GNN(nn.Module):
    def __init__(self, nfeat, nhid, nodes_out, graph_out, dropout, conv_layers=2, fc_layers=3, first_conv_descr=None,
                 middle_conv_descr=None, final_activation='LeakyReLU', skip=False, gru=False, fixed=False, variable=False,
                 device='cpu'):
        """
        :param nfeat: number of input features per node
        :param nhid: number of hidden features per node
        :param nodes_out: number of output labels per node
        :param graph_out: number of output labels per graph
        :param dropout: dropout value
        :param conv_layers: if variable, conv_layers should be a function : adj -> int, otherwise an int
        :param fc_layers: number of fully connected layers before the labels
        :param first_conv_descr: dict or SimpleNamespace: "layer_type" -> type of layer, "args" -> dict of calling args
        :param middle_conv_descr: dict or SimpleNamespace: "layer_type" -> type of layer, "args" -> dict of calling args
        :param final_activation: activation to be used on the last fc layer before the labels
        :param skip: whether to use skip connections feeding to the readout
        :param gru: whether to use a shared GRU after each convolution
        :param fixed: whether to reuse the same middle convolutional layer multiple times
        :param variable: whether the number of convolutional layers is variable or fixed
        :param device: device used for computation
        """
        super(GNN, self).__init__()
        if variable:
            assert callable(conv_layers), "conv_layers should be a function from adjacency matrix to int"
            assert fixed, "a variable number of layers requires fixed (shared) middle layers"
            assert not skip, "cannot use skip connections with a variable number of layers"
        else:
            assert type(conv_layers) == int, "conv_layers should be an int"
            assert conv_layers > 0, "conv_layers should be greater than 0"

        if type(first_conv_descr) == dict:
            first_conv_descr = types.SimpleNamespace(**first_conv_descr)
        assert type(first_conv_descr) == types.SimpleNamespace, "first_conv_descr should be either a dict or a SimpleNamespace"
        if type(first_conv_descr.args) == dict:
            first_conv_descr.args = types.SimpleNamespace(**first_conv_descr.args)
        assert type(first_conv_descr.args) == types.SimpleNamespace, "first_conv_descr.args should be either a dict or a SimpleNamespace"

        if type(middle_conv_descr) == dict:
            middle_conv_descr = types.SimpleNamespace(**middle_conv_descr)
        assert type(middle_conv_descr) == types.SimpleNamespace, "middle_conv_descr should be either a dict or a SimpleNamespace"
        if type(middle_conv_descr.args) == dict:
            middle_conv_descr.args = types.SimpleNamespace(**middle_conv_descr.args)
        assert type(middle_conv_descr.args) == types.SimpleNamespace, "middle_conv_descr.args should be either a dict or a SimpleNamespace"

        self.dropout = dropout
        self.conv_layers = nn.ModuleList()
        self.skip = skip
        self.fixed = fixed
        self.variable = variable
        self.n_fixed_conv = conv_layers
        self.gru = GRU(input_size=nhid, hidden_size=nhid, device=device) if gru else None

        # first graph convolution
        first_conv_descr.args.in_features = nfeat
        first_conv_descr.args.out_features = nhid
        first_conv_descr.args.device = device
        self.conv_layers.append(first_conv_descr.layer_type(**vars(first_conv_descr.args)))

        # middle graph convolutions
        middle_conv_descr.args.in_features = nhid
        middle_conv_descr.args.out_features = nhid
        middle_conv_descr.args.device = device
        for l in range(1 if fixed else conv_layers - 1):
            self.conv_layers.append(
                middle_conv_descr.layer_type(**vars(middle_conv_descr.args)))

        n_conv_out = nfeat + conv_layers * nhid if skip else nhid

        # nodes output: fully connected layers
        self.nodes_read_out = MLP(in_size=n_conv_out, hidden_size=n_conv_out, out_size=nodes_out, layers=fc_layers,
                                  mid_activation="LeakyReLU", last_activation=final_activation, device=device)

        # graph output: S2S readout
        self.graph_read_out = S2SReadout(n_conv_out, n_conv_out, graph_out, fc_layers=fc_layers, device=device,
                                         final_activation=final_activation)

    def forward(self, x, adj):
        # graph convolutions
        skip_connections = [x] if self.skip else None

        n_layers = self.n_fixed_conv(adj) if self.variable else self.n_fixed_conv
        conv_layers = ([self.conv_layers[0]] + [self.conv_layers[1]] * (n_layers - 1)) if self.fixed else self.conv_layers

        for layer, conv in enumerate(conv_layers):
            y = conv(x, adj)
            x = y if self.gru is None else self.gru(x, y)

            if self.skip:
                skip_connections.append(x)

            # dropout at all layers but the last
            if layer != n_layers - 1:
                x = F.dropout(x, self.dropout, training=self.training)

        if self.skip:
            x = torch.cat(skip_connections, dim=2)

        # readout output
        return (self.nodes_read_out(x), self.graph_read_out(x))
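
A sketch (not part of the commit) constructing GNN directly with GCNLayer descriptors, mirroring the gnn_args dicts the training scripts pass to execute_train; the hyperparameters and shapes are illustrative, the import paths are inferred from the training scripts, and the exact graph-readout shape depends on S2SReadout, which is not shown in this diff.

import torch
from models.gnn_framework import GNN
from models.gcn.layer import GCNLayer

model = GNN(nfeat=8, nhid=32, nodes_out=2, graph_out=1, dropout=0.5,
            conv_layers=3, fc_layers=2,
            first_conv_descr=dict(layer_type=GCNLayer, args=dict()),
            middle_conv_descr=dict(layer_type=GCNLayer, args=dict()),
            skip=True, gru=False, fixed=False, variable=False)

x = torch.randn(4, 10, 8)                  # (batch, nodes, features)
adj = torch.ones(4, 10, 10)                # fully connected toy graphs
nodes_pred, graph_pred = model(x, adj)
print(nodes_pred.shape)                    # expected (4, 10, 2); with skip=True the readout sees 8 + 3*32 features
print(graph_pred.shape)                    # likely (4, 1), depending on S2SReadout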