Skip to content

Commit

Permalink
Added the model in PyTorch Geometric
Browse files Browse the repository at this point in the history
  • Loading branch information
gcorso committed Jun 7, 2020
1 parent 6985578 commit 6e1b80b
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 2 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ Principal Neighbourhood Aggregation for Graph Nets [arxiv.org/abs/2004.05718](ht

## Overview

We provide the implementation of the Principal Neighbourhood Aggregation (PNA) in both PyTorch and DGL frameworks, along with scripts to generate and run the multitask benchmarks, scripts for running real-world benchmarks, a flexible PyTorch GNN framework and implementations of the other models used for comparison. The repository is organised as follows:
We provide the implementation of the Principal Neighbourhood Aggregation (PNA) in PyTorch, DGL and PyTorch Geometric frameworks, along with scripts to generate and run the multitask benchmarks, scripts for running real-world benchmarks, a flexible PyTorch GNN framework and implementations of the other models used for comparison. The repository is organised as follows:

- `models` contains:
- `pytorch` contains the various GNN models implemented in PyTorch:
- the implementation of the aggregators, the scalers and the PNA layer (`pna`)
- the flexible GNN framework that can be used with any type of graph convolutions (`gnn_framework.py`)
- implementations of the other GNN models used for comparison in the paper, namely GCN, GAT, GIN and MPNN
- `dgl` contains the PNA model implemented via the [DGL library](https://www.dgl.ai/): PNA tower and layer, aggregators, scalers and readout.
- `dgl` contains the PNA model implemented via the [DGL library](https://www.dgl.ai/): aggregators, scalers, and layer.
- `pytorch_geometric` contains the PNA model implemented via the [PyTorch Geometric library](https://pytorch-geometric.readthedocs.io/): aggregators, scalers, and layer.
- `layers.py` contains general NN layers used by the various models
- `multi_task` contains various scripts to recreate the multi_task benchmark along with the files used to train the various models. In `multi_task/README.md` we detail the instructions for the generation and training hyperparameters tuned.
- `real_world` contains various scripts from [Benchmarking GNNs](https://github.com/graphdeeplearning/benchmarking-gnns) to download the real-world benchmarks and train the PNA on them. In `real_world/README.md` we provide instructions for the generation and training hyperparameters tuned.
Expand Down
37 changes: 37 additions & 0 deletions models/pytorch_geometric/aggregators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import torch
from torch_scatter import scatter_sum, scatter_mean, scatter_max, scatter_min

# Small constant added under the square root in aggregate_std so that its argument stays strictly positive
EPS = 1e-5


def aggregate_sum(src, index, dim, dim_size):
    """
    Sum-aggregate the entries of ``src`` that share the same ``index`` along ``dim``.

    :param src: tensor holding the values to aggregate
    :param index: tensor assigning each entry of ``src`` to an output slot
    :param dim: axis along which the scatter is performed
    :param dim_size: size of the output along ``dim``
    :return: the tensor produced by ``scatter_sum``
    """
    summed = scatter_sum(src=src, index=index, dim=dim, out=None, dim_size=dim_size)
    return summed


def aggregate_mean(src, index, dim, dim_size):
    """
    Mean-aggregate the entries of ``src`` that share the same ``index`` along ``dim``.

    :param src: tensor holding the values to aggregate
    :param index: tensor assigning each entry of ``src`` to an output slot
    :param dim: axis along which the scatter is performed
    :param dim_size: size of the output along ``dim``
    :return: the tensor produced by ``scatter_mean``
    """
    averaged = scatter_mean(src=src, index=index, dim=dim, out=None, dim_size=dim_size)
    return averaged


def aggregate_max(src, index, dim, dim_size):
    """
    Max-aggregate the entries of ``src`` that share the same ``index`` along ``dim``.

    :param src: tensor holding the values to aggregate
    :param index: tensor assigning each entry of ``src`` to an output slot
    :param dim: axis along which the scatter is performed
    :param dim_size: size of the output along ``dim``
    :return: the first element of the result of ``scatter_max`` (the second element is discarded)
    """
    scattered = scatter_max(src=src, index=index, dim=dim, out=None, dim_size=dim_size)
    # only the first element of the result is kept
    return scattered[0]


def aggregate_min(src, index, dim, dim_size):
    """
    Min-aggregate the entries of ``src`` that share the same ``index`` along ``dim``.

    :param src: tensor holding the values to aggregate
    :param index: tensor assigning each entry of ``src`` to an output slot
    :param dim: axis along which the scatter is performed
    :param dim_size: size of the output along ``dim``
    :return: the first element of the result of ``scatter_min`` (the second element is discarded)
    """
    scattered = scatter_min(src=src, index=index, dim=dim, out=None, dim_size=dim_size)
    # only the first element of the result is kept
    return scattered[0]


def aggregate_var(src, index, dim, dim_size):
    """
    Variance of the entries of ``src`` sharing the same ``index``, computed as
    mean of squares minus square of the mean.

    The result is not clamped, so rounding can make it slightly negative.

    :param src: tensor holding the values to aggregate
    :param index: tensor assigning each entry of ``src`` to an output slot
    :param dim: axis along which the scatter is performed
    :param dim_size: size of the output along ``dim``
    :return: tensor of per-slot variances
    """
    first_moment = aggregate_mean(src, index, dim, dim_size)
    second_moment = aggregate_mean(src * src, index, dim, dim_size)
    return second_moment - first_moment * first_moment


def aggregate_std(src, index, dim, dim_size):
    """
    Standard deviation of the entries of ``src`` sharing the same ``index``.

    Negative variances are clipped to zero and ``EPS`` is added before taking the square root.

    :param src: tensor holding the values to aggregate
    :param index: tensor assigning each entry of ``src`` to an output slot
    :param dim: axis along which the scatter is performed
    :param dim_size: size of the output along ``dim``
    :return: tensor of per-slot standard deviations
    """
    clipped_var = torch.relu(aggregate_var(src, index, dim, dim_size))
    return torch.sqrt(clipped_var + EPS)


# lookup table from aggregator identifier to the function implementing it
AGGREGATORS = {
    'mean': aggregate_mean,
    'sum': aggregate_sum,
    'max': aggregate_max,
    'min': aggregate_min,
    'std': aggregate_std,
    'var': aggregate_var,
}
109 changes: 109 additions & 0 deletions models/pytorch_geometric/pna.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import torch
from torch import nn
from torch_geometric.nn.conv import MessagePassing

from models.pytorch_geometric.aggregators import AGGREGATORS
from models.pytorch_geometric.scalers import get_degree, SCALERS
from models.layers import MLP, FCLayer

"""
PNA: Principal Neighbourhood Aggregation
Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic
https://arxiv.org/abs/2004.05718
"""


class PNAConv(MessagePassing):
    """
    Principal Neighbourhood Aggregation convolution built on the ``MessagePassing`` base class.

    For every tower, messages are produced by a pre-transformation MLP applied to the concatenation of
    target features, source features and encoded edge features. The messages are then aggregated with
    every configured aggregator, rescaled by every configured scaler, concatenated with the node's own
    features and passed through a post-transformation MLP. The tower outputs are finally concatenated
    and mixed by a fully connected layer.
    """

    def __init__(self, in_channels, out_channels, aggregators, scalers, avg_d, towers=1,
                 pretrans_layers=1, posttrans_layers=1, divide_input=False, **kwargs):
        """
        :param in_channels: size of the input per node
        :param out_channels: size of the output per node
        :param aggregators: set of aggregation function identifiers (keys of ``AGGREGATORS``)
        :param scalers: set of scaling functions identifiers (keys of ``SCALERS``)
        :param avg_d: average degree of nodes in the training set, used by scalers to normalize
        :param towers: number of towers to use
        :param pretrans_layers: number of layers in the transformation before the aggregation
        :param posttrans_layers: number of layers in the transformation after the aggregation
        :param divide_input: whether the input features should be split between towers or not
        :param kwargs: forwarded unchanged to the ``MessagePassing`` constructor
        """
        super(PNAConv, self).__init__(aggr=None, **kwargs)
        assert ((not divide_input) or in_channels % towers == 0), "if divide_input is set the number of towers has to divide in_features"
        assert (out_channels % towers == 0), "the number of towers has to divide the out_features"

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.towers = towers
        self.divide_input = divide_input
        # per-tower feature sizes: the input is either split among the towers or given whole to each one
        self.input_tower = self.in_channels // towers if divide_input else self.in_channels
        self.output_tower = self.out_channels // towers

        # retrieve the aggregators and scalers functions
        self.aggregators = [AGGREGATORS[aggr] for aggr in aggregators]
        self.scalers = [SCALERS[scale] for scale in scalers]
        self.avg_d = avg_d

        # edge features are read with in_channels entries and projected to the per-tower input size
        self.edge_encoder = FCLayer(in_size=in_channels, out_size=self.input_tower, activation='none')

        # build pre-transformations and post-transformation MLP for each tower
        self.pretrans = nn.ModuleList()
        self.posttrans = nn.ModuleList()
        for _ in range(towers):
            # input: target features, source features and encoded edge features concatenated
            self.pretrans.append(
                MLP(in_size=3 * self.input_tower, hidden_size=self.input_tower, out_size=self.input_tower,
                    layers=pretrans_layers, mid_activation='relu', last_activation='none'))
            # input: the node's own features plus one block per (aggregator, scaler) pair
            self.posttrans.append(
                MLP(in_size=(len(self.aggregators) * len(self.scalers) + 1) * self.input_tower,
                    hidden_size=self.output_tower, out_size=self.output_tower, layers=posttrans_layers,
                    mid_activation='relu', last_activation='none'))

        self.mixing_network = FCLayer(self.out_channels, self.out_channels, activation='LeakyReLU')

    def forward(self, x, edge_index, edge_attr):
        """
        Encode the edge features and run one round of message passing.

        :param x: node features
        :param edge_index: graph connectivity handed to ``propagate``
        :param edge_attr: raw edge features, fed to the edge encoder
        :return: the result of ``propagate``
        """
        edge_embedding = self.edge_encoder(edge_attr)
        return self.propagate(edge_index, x=x, edge_attr=edge_embedding)

    def message(self, x_i, x_j, edge_attr):
        """
        Compute the per-tower messages for every edge.

        :param x_i: features of one endpoint of every edge
        :param x_j: features of the other endpoint of every edge
        :param edge_attr: encoded edge features, with ``input_tower`` entries per edge
        :return: tensor with one ``input_tower``-sized message per edge and per tower
        """
        if self.divide_input:
            # divide the features among the towers
            x_i = x_i.view(-1, self.towers, self.input_tower)
            x_j = x_j.view(-1, self.towers, self.input_tower)
        else:
            # repeat the features over the towers
            x_i = x_i.view(-1, 1, self.input_tower).repeat(1, self.towers, 1)
            x_j = x_j.view(-1, 1, self.input_tower).repeat(1, self.towers, 1)
        # the encoded edge features are already of per-tower size, so they are always repeated
        edge_attr = edge_attr.view(-1, 1, self.input_tower).repeat(1, self.towers, 1)

        # pre-transformation
        h_cat = torch.cat([x_i, x_j, edge_attr], dim=-1)
        # stacking the tower outputs keeps their dtype and device instead of copying into a default buffer
        return torch.stack([trans(h_cat[:, tower, :]) for tower, trans in enumerate(self.pretrans)], dim=1)

    def aggregate(self, inputs, index, dim_size=None):
        """
        Aggregate the messages with every aggregator and rescale the result with every scaler.

        :param inputs: messages returned by ``message``
        :param index: for every message, the slot it is aggregated into
        :param dim_size: number of output slots
        :return: concatenation along the last axis of every scaler applied to every aggregation
        """
        D = get_degree(inputs, index, self.node_dim, dim_size)

        # aggregators
        inputs = torch.cat([aggregator(inputs, index, dim=self.node_dim, dim_size=dim_size)
                            for aggregator in self.aggregators], dim=-1)
        # scalers
        return torch.cat([scaler(inputs, D, self.avg_d) for scaler in self.scalers], dim=-1)

    def update(self, aggr_out, x):
        """
        Combine every node's own features with its aggregated messages.

        :param aggr_out: output of ``aggregate``
        :param x: node features
        :return: tensor with ``out_channels`` features per node
        """
        # post-transformation
        if self.divide_input:
            x = x.view(-1, self.towers, self.input_tower)
        else:
            x = x.view(-1, 1, self.input_tower).repeat(1, self.towers, 1)
        aggr_cat = torch.cat([x, aggr_out], dim=-1)
        # stacking the tower outputs keeps their dtype and device instead of copying into a default buffer
        y = torch.stack([trans(aggr_cat[:, tower, :]) for tower, trans in enumerate(self.posttrans)], dim=1)

        # concatenate and mix all the towers
        y = y.view(-1, self.towers * self.output_tower)
        y = self.mixing_network(y)

        return y

51 changes: 51 additions & 0 deletions models/pytorch_geometric/scalers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import torch
from torch_scatter import scatter_sum


def get_degree(src, index, dim, dim_size):
    """
    Count how many entries of ``index`` point at each output slot.

    :param src: tensor being aggregated; only its rank, dtype and device are used
    :param index: tensor assigning each entry to an output slot
    :param dim: axis of the scatter, negative values are resolved against the rank of ``src``
    :param dim_size: number of output slots
    :return: the counts, never smaller than 1, with two trailing singleton axes appended
    """
    # resolve a negative axis against src, then cap it to the last axis of index
    axis = dim + src.dim() if dim < 0 else dim
    axis = min(axis, index.dim() - 1)

    unit_weights = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    degree = scatter_sum(unit_weights, index, axis, None, dim_size)
    degree.clamp_(1)  # slots receiving nothing are counted as 1 so no 0s are returned
    return degree.unsqueeze(-1).unsqueeze(-1)


def scale_identity(src, D, avg_d=None):
    """
    Scaler that leaves the aggregated values untouched.

    :param src: aggregated values
    :param D: node degrees (unused)
    :param avg_d: average degree statistics (unused)
    :return: ``src`` itself, unchanged
    """
    return src


def scale_amplification(src, D, avg_d=None):
    """
    Multiply ``src`` by ``log(D + 1) / avg_d["log"]``.

    :param src: aggregated values
    :param D: node degrees
    :param avg_d: mapping whose ``"log"`` entry is the normalisation constant; it must be provided
        even though the parameter defaults to ``None``
    :return: the rescaled values
    """
    log_degree = torch.log(D + 1)
    return src * (log_degree / avg_d["log"])


def scale_attenuation(src, D, avg_d=None):
    """
    Multiply ``src`` by ``avg_d["log"] / log(D + 1)``, the reciprocal of the amplification factor.

    :param src: aggregated values
    :param D: node degrees
    :param avg_d: mapping whose ``"log"`` entry is the normalisation constant; it must be provided
        even though the parameter defaults to ``None``
    :return: the rescaled values
    """
    log_degree = torch.log(D + 1)
    return src * (avg_d["log"] / log_degree)


def scale_linear(src, D, avg_d=None):
    """
    Multiply ``src`` by ``D / avg_d["lin"]``.

    :param src: aggregated values
    :param D: node degrees
    :param avg_d: mapping whose ``"lin"`` entry is the normalisation constant; it must be provided
        even though the parameter defaults to ``None``
    :return: the rescaled values
    """
    return src * (D / avg_d["lin"])


def scale_inverse_linear(src, D, avg_d=None):
    """
    Multiply ``src`` by ``avg_d["lin"] / D``.

    :param src: aggregated values
    :param D: node degrees
    :param avg_d: mapping whose ``"lin"`` entry is the normalisation constant; it must be provided
        even though the parameter defaults to ``None``
    :return: the rescaled values
    """
    return src * (avg_d["lin"] / D)


# lookup table from scaler identifier to the function implementing it
SCALERS = {
    'identity': scale_identity,
    'amplification': scale_amplification,
    'attenuation': scale_attenuation,
    'linear': scale_linear,
    'inverse_linear': scale_inverse_linear,
}

0 comments on commit 6e1b80b

Please sign in to comment.