PCQM4Mv2 test-dev (rampasek#17)
* Added GELU

* added activation function configurability and attention-weight logging

* PCQM4Mv2 inference data loader for `dev` and `challenge` splits; + PNA fix

* activation function configurability

* run modes for PCQM4M official inference and attention-weight logging

* PCQM4Mv2 GPS 16-layer deep architecture and corresponding inference config

* updated run cmds script

* PCQM4Mv2 configs

* updated README with PCQM4M instructions
rampasek authored Nov 11, 2022
1 parent 6305368 commit 8b9309c
Showing 17 changed files with 418 additions and 68 deletions.
49 changes: 46 additions & 3 deletions README.md
@@ -51,8 +51,51 @@ python main.py --cfg configs/SAN/zinc-SAN.yaml wandb.use False
python main.py --cfg tests/configs/graph/zinc.yaml wandb.use False
```

## Running GraphGPS on OGB-LSC PCQM4Mv2
### Training
```bash
# "small" GPS (GatedGCN+Transformer) with RWSE: 5layers, 304dim, 6152001 params
python main.py --cfg configs/GPS/pcqm4m-GPS+RWSE.yaml
# "medium" GPS (GatedGCN+Transformer) with RWSE: 10layers, 384dim, 19414641 params
python main.py --cfg configs/GPS/pcqm4m-GPSmedium+RWSE.yaml
# "deep" GPS (GatedGCN+Transformer) with RWSE: 16layers, 256dim, 13807345 params
python main.py --cfg configs/GPS/pcqm4m-GPSdeep+RWSE.yaml
```
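
The parameter counts quoted in the comments above can be checked with the standard PyTorch idiom; a minimal sketch (the `count_params` helper is ours, not part of the repo):

```python
import torch.nn as nn

def count_params(model: nn.Module) -> int:
    # Trainable parameters only; called on the instantiated GPSModel this
    # should reproduce the counts quoted above (e.g. 6152001 for "small").
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
```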

### Expected performance
- Note 1: For training, we set aside 150k molecules from the official training set as a custom validation set for model selection / early stopping; a sketch of this custom split follows the results table below.
The official `valid` set serves as the test set in our training setup.
For running inference on `test-dev` and `test-challenge`, see further below.

- Note 2: GPS-medium took ~48h and GPS-deep ~60h to train on a single NVIDIA A100 GPU. Your reproduced results may vary slightly.

- Note 3: This version of GPS **does not** use 3D atomic position information.

| Model config | parameters | train MAE | custom valid MAE | official valid MAE |
|--------------|-----------:|----------:|-----------------:|-------------------:|
| GPS-small | 6,152,001 | 0.0638 | 0.0849 | 0.0937 |
| GPS-medium | 19,414,641 | 0.0726 | 0.0805 | 0.0858 |
| GPS-deep | 13,807,345 | 0.0641 | 0.0796 | 0.0852 |
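
The custom split is, in essence, a 150k-molecule holdout carved out of the official training set. Below is a rough, illustrative sketch of the idea in plain OGB terms; the actual split logic and seed live in the GraphGPS data loading code, so treat the seed and variable names here as assumptions:

```python
# Illustrative sketch only: hold out 150k training molecules as a custom
# validation set and repurpose the official `valid` split as the test set.
import numpy as np
from ogb.lsc import PCQM4Mv2Dataset

dataset = PCQM4Mv2Dataset(root='datasets')
split = dataset.get_idx_split()  # keys: 'train', 'valid', 'test-dev', 'test-challenge'

rng = np.random.default_rng(0)   # the seed here is an arbitrary placeholder
perm = rng.permutation(np.asarray(split['train']))
custom_valid_idx = perm[:150_000]      # used for model selection / early stopping
custom_train_idx = perm[150_000:]
test_idx = np.asarray(split['valid'])  # official valid doubles as our test set
```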

### Inference and submission files for OGB-LSC leaderboard
You need a saved pretrained model from the previous step. Running it with an "inference" config loads the official
`valid`, `test-dev`, and `test-challenge` splits, runs inference on them, and applies the official OGB Evaluator.

```bash
# You can download our pretrained GPS-deep (151 MB).
wget https://www.dropbox.com/s/aomimvak4gb6et3/pcqm4m-GPS%2BRWSE.deep.zip
unzip pcqm4m-GPS+RWSE.deep.zip -d pretrained/

# Run inference and official OGB Evaluator.
python main.py --cfg configs/GPS/pcqm4m-GPSdeep-inference.yaml

# Result files for OGB-LSC Leaderboard.
results/pcqm4m-GPSdeep-inference/0/y_pred_pcqm4m-v2_test-challenge.npz
results/pcqm4m-GPSdeep-inference/0/y_pred_pcqm4m-v2_test-dev.npz
```
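
The two `.npz` files above already follow the OGB-LSC submission format. A minimal sketch for sanity-checking them, assuming the standard `y_pred` key that OGB's `save_test_submission` writes:

```python
# Sanity-check the generated submission files; assumes the standard OGB-LSC
# layout where each .npz stores a single array under the 'y_pred' key.
import numpy as np
from ogb.lsc import PCQM4Mv2Evaluator

for split in ('test-dev', 'test-challenge'):
    path = f'results/pcqm4m-GPSdeep-inference/0/y_pred_pcqm4m-v2_{split}.npz'
    y_pred = np.load(path)['y_pred']
    print(f'{split}: {y_pred.shape[0]} predictions, mean {y_pred.mean():.4f}')

# The same evaluator reports MAE wherever labels are available (e.g. `valid`):
evaluator = PCQM4Mv2Evaluator()
# evaluator.eval({'y_true': y_true, 'y_pred': y_pred})  # -> {'mae': ...}
```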


-### Benchmarking GPS on 11 datasets
+## Benchmarking GPS on 11 datasets
See the `run/run_experiments.sh` script to run multiple random seeds for each of the 11 datasets. We rely on the Slurm job scheduling system.

Alternatively, you can run them in the terminal following the example below. Configs for all 11 datasets are in `configs/GPS/`.
@@ -70,7 +113,7 @@ To use W&B logging, set `wandb.use True` and have a `gtransformers` entity set-up



-### Unit tests
+## Unit tests

To run all unit tests, execute from the project root directory:

@@ -87,7 +130,7 @@ python -m unittest -v unittests.test_eigvecs

## Citation

-If you find this work useful, please cite our paper:
+If you find this work useful, please cite our NeurIPS 2022 paper:
```bibtex
@article{rampasek2022GPS,
title={{Recipe for a General, Powerful, Scalable Graph Transformer}},
3 changes: 2 additions & 1 deletion configs/GPS/pcqm4m-GPS+RWSE.yaml
@@ -27,7 +27,8 @@ train:
  mode: custom
  batch_size: 256
  eval_period: 1
-  ckpt_period: 100
+  ckpt_best: True
+  # ckpt_period: 100
model:
  type: GPSModel
  loss_fun: l1
60 changes: 60 additions & 0 deletions configs/GPS/pcqm4m-GPSdeep+RWSE.yaml
@@ -0,0 +1,60 @@
out_dir: results
metric_best: mae
metric_agg: argmin
wandb:
  use: True
  project: pcqm4m
dataset:
  format: OGB
  name: PCQM4Mv2-full
  task: graph
  task_type: regression
  transductive: False
  node_encoder: True
  node_encoder_name: Atom+RWSE
  node_encoder_bn: False
  edge_encoder: True
  edge_encoder_name: Bond
  edge_encoder_bn: False
posenc_RWSE:
  enable: True
  kernel:
    times_func: range(1,17)
  model: Linear
  dim_pe: 20
  raw_norm_type: BatchNorm
train:
  mode: custom
  batch_size: 256
  eval_period: 1
  ckpt_best: True
  # ckpt_period: 100
model:
  type: GPSModel
  loss_fun: l1
  graph_pooling: mean
gt:
  layer_type: CustomGatedGCN+Transformer
  layers: 16
  n_heads: 8
  dim_hidden: 256  # `gt.dim_hidden` must match `gnn.dim_inner`
  dropout: 0.1
  attn_dropout: 0.1
  layer_norm: False
  batch_norm: True
gnn:
  head: san_graph
  layers_pre_mp: 0
  layers_post_mp: 3  # Not used when `gnn.head: san_graph`
  dim_inner: 256  # `gt.dim_hidden` must match `gnn.dim_inner`
  batchnorm: True
  act: gelu
  dropout: 0.0
optim:
  clip_grad_norm: True
  optimizer: adamW
  weight_decay: 0.0
  base_lr: 0.0002
  max_epoch: 150
  scheduler: linear_with_warmup
  num_warmup_epochs: 10
configs/GPS/pcqm4m-GPSdeep-inference.yaml
@@ -6,7 +6,7 @@ wandb:
  project: scratch  # W&B project for debugging runs.
dataset:
  format: OGB
-  name: PCQM4Mv2-full
+  name: PCQM4Mv2-inference
  task: graph
  task_type: regression
  transductive: False
@@ -17,10 +17,10 @@ dataset:
  edge_encoder_name: Bond
  edge_encoder_bn: False
pretrained:
-  dir: pretrained/pcqm4m-GPS+RWSE.medium
+  dir: pretrained/pcqm4m-GPS+RWSE.deep
  reset_prediction_head: False
train:
-  mode: inference-only
+  mode: PCQM4Mv2-inference
  batch_size: 256
model:
  type: GPSModel
@@ -30,5 +30,5 @@ gnn:
  head: san_graph
  layers_post_mp: 3  # Not used when `gnn.head: san_graph`
  batchnorm: True
-  act: relu
+  act: gelu
  dropout: 0.0
3 changes: 2 additions & 1 deletion configs/GPS/pcqm4m-GPSmedium+RWSE.yaml
@@ -27,7 +27,8 @@ train:
  mode: custom
  batch_size: 256
  eval_period: 1
-  ckpt_period: 100
+  ckpt_best: True
+  # ckpt_period: 100
model:
  type: GPSModel
  loss_fun: l1
3 changes: 3 additions & 0 deletions graphgps/act/example.py
@@ -20,3 +20,6 @@ def forward(self, x):

register_act('swish', SWISH(inplace=cfg.mem.inplace))
register_act('lrelu_03', nn.LeakyReLU(0.3, inplace=cfg.mem.inplace))

+# Add Gaussian Error Linear Unit (GELU).
+register_act('gelu', nn.GELU())
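
Once registered, an activation is resolved by name through GraphGym's registry, which is how the new `gnn.act: gelu` config option reaches the model code in the diffs below (`san_graph.py`, `gatedgcn_layer.py`). A minimal, self-contained illustration of the mechanism (the `gelu_demo` name is ours, chosen to avoid clashing with the repo's `gelu` registration):

```python
# Minimal illustration of the registry: register an activation under a name,
# then resolve it by that name, exactly as `register.act_dict[cfg.gnn.act]`
# does in the model code.
import torch
import torch.nn as nn
from torch_geometric.graphgym.register import act_dict, register_act

register_act('gelu_demo', nn.GELU())
act = act_dict['gelu_demo']
print(act(torch.tensor([-1.0, 0.0, 1.0])))  # GELU applied elementwise
```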
9 changes: 5 additions & 4 deletions graphgps/encoder/kernel_pos_encoder.py
@@ -54,19 +54,20 @@ def __init__(self, dim_emb, expand_x=True):
        else:
            self.raw_norm = None

+        activation = nn.ReLU()  # register.act_dict[cfg.gnn.act]
        if model_type == 'mlp':
            layers = []
            if n_layers == 1:
                layers.append(nn.Linear(num_rw_steps, dim_pe))
-                layers.append(nn.ReLU())
+                layers.append(activation)
            else:
                layers.append(nn.Linear(num_rw_steps, 2 * dim_pe))
-                layers.append(nn.ReLU())
+                layers.append(activation)
                for _ in range(n_layers - 2):
                    layers.append(nn.Linear(2 * dim_pe, 2 * dim_pe))
-                    layers.append(nn.ReLU())
+                    layers.append(activation)
                layers.append(nn.Linear(2 * dim_pe, dim_pe))
-                layers.append(nn.ReLU())
+                layers.append(activation)
            self.pe_encoder = nn.Sequential(*layers)
        elif model_type == 'linear':
            self.pe_encoder = nn.Linear(num_rw_steps, dim_pe)
18 changes: 10 additions & 8 deletions graphgps/encoder/laplace_pos_encoder.py
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
+import torch_geometric.graphgym.register as register
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_node_encoder

@@ -49,6 +50,7 @@ def __init__(self, dim_emb, expand_x=True):
        else:
            self.raw_norm = None

+        activation = nn.ReLU()  # register.act_dict[cfg.gnn.act]
        if model_type == 'Transformer':
            # Transformer model for LapPE
            encoder_layer = nn.TransformerEncoderLayer(d_model=dim_pe,
@@ -60,15 +62,15 @@ def __init__(self, dim_emb, expand_x=True):
            # DeepSet model for LapPE
            layers = []
            if n_layers == 1:
-                layers.append(nn.ReLU())
+                layers.append(activation)
            else:
                self.linear_A = nn.Linear(2, 2 * dim_pe)
-                layers.append(nn.ReLU())
+                layers.append(activation)
                for _ in range(n_layers - 2):
                    layers.append(nn.Linear(2 * dim_pe, 2 * dim_pe))
-                    layers.append(nn.ReLU())
+                    layers.append(activation)
                layers.append(nn.Linear(2 * dim_pe, dim_pe))
-                layers.append(nn.ReLU())
+                layers.append(activation)
            self.pe_encoder = nn.Sequential(*layers)

        self.post_mlp = None
@@ -77,15 +79,15 @@ def __init__(self, dim_emb, expand_x=True):
            layers = []
            if post_n_layers == 1:
                layers.append(nn.Linear(dim_pe, dim_pe))
-                layers.append(nn.ReLU())
+                layers.append(activation)
            else:
                layers.append(nn.Linear(dim_pe, 2 * dim_pe))
-                layers.append(nn.ReLU())
+                layers.append(activation)
                for _ in range(post_n_layers - 2):
                    layers.append(nn.Linear(2 * dim_pe, 2 * dim_pe))
-                    layers.append(nn.ReLU())
+                    layers.append(activation)
                layers.append(nn.Linear(2 * dim_pe, dim_pe))
-                layers.append(nn.ReLU())
+                layers.append(activation)
            self.post_mlp = nn.Sequential(*layers)


2 changes: 1 addition & 1 deletion graphgps/finetuning.py
@@ -82,7 +82,7 @@ def load_pretrained_model_cfg(cfg):
    # Copy over GNN cfg but not those for the prediction head
    compare_cfg(cfg, pretrained_cfg, 'gnn.head')
    compare_cfg(cfg, pretrained_cfg, 'gnn.layers_post_mp')
-    compare_cfg(cfg, pretrained_cfg, 'gnn.act')
+    compare_cfg(cfg, pretrained_cfg, 'gnn.act', strict=True)
    compare_cfg(cfg, pretrained_cfg, 'gnn.dropout')
    head = cfg.gnn.head
    post_mp = cfg.gnn.layers_post_mp
4 changes: 2 additions & 2 deletions graphgps/head/san_graph.py
@@ -1,5 +1,4 @@
import torch.nn as nn
-import torch.nn.functional as F

import torch_geometric.graphgym.register as register
from torch_geometric.graphgym import cfg
@@ -27,6 +26,7 @@ def __init__(self, dim_in, dim_out, L=2):
            nn.Linear(dim_in // 2 ** L, dim_out, bias=True))
        self.FC_layers = nn.ModuleList(list_FC_layers)
        self.L = L
+        self.activation = register.act_dict[cfg.gnn.act]

    def _apply_index(self, batch):
        return batch.graph_feature, batch.y
@@ -35,7 +35,7 @@ def forward(self, batch):
        graph_emb = self.pooling_fun(batch.x, batch.batch)
        for l in range(self.L):
            graph_emb = self.FC_layers[l](graph_emb)
-            graph_emb = F.relu(graph_emb)
+            graph_emb = self.activation(graph_emb)
        graph_emb = self.FC_layers[self.L](graph_emb)
        batch.graph_feature = graph_emb
        pred, label = self._apply_index(batch)
15 changes: 8 additions & 7 deletions graphgps/layer/gatedgcn_layer.py
@@ -1,12 +1,11 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
+import torch_geometric.graphgym.register as register
import torch_geometric.nn as pyg_nn
from torch_geometric.graphgym.models.layer import LayerConfig
+from torch_scatter import scatter

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_layer
-from torch_scatter import scatter


class GatedGCNLayer(pyg_nn.conv.MessagePassing):
@@ -15,9 +14,10 @@ class GatedGCNLayer(pyg_nn.conv.MessagePassing):
    Residual Gated Graph ConvNets
    https://arxiv.org/pdf/1711.07553.pdf
    """
-    def __init__(self, in_dim, out_dim, dropout, residual,
+    def __init__(self, in_dim, out_dim, dropout, residual, act='relu',
                 equivstable_pe=False, **kwargs):
        super().__init__(**kwargs)
+        self.activation = register.act_dict[act]
        self.A = pyg_nn.Linear(in_dim, out_dim, bias=True)
        self.B = pyg_nn.Linear(in_dim, out_dim, bias=True)
        self.C = pyg_nn.Linear(in_dim, out_dim, bias=True)
@@ -29,7 +29,7 @@ def __init__(self, in_dim, out_dim, dropout, residual,
        self.EquivStablePE = equivstable_pe
        if self.EquivStablePE:
            self.mlp_r_ij = nn.Sequential(
-                nn.Linear(1, out_dim), nn.ReLU(),
+                nn.Linear(1, out_dim), self.activation,
                nn.Linear(out_dim, 1),
                nn.Sigmoid())

@@ -69,8 +69,8 @@ def forward(self, batch):
        x = self.bn_node_x(x)
        e = self.bn_edge_e(e)

-        x = F.relu(x)
-        e = F.relu(e)
+        x = self.activation(x)
+        e = self.activation(e)

        x = F.dropout(x, self.dropout, training=self.training)
        e = F.dropout(e, self.dropout, training=self.training)
@@ -145,6 +145,7 @@ def __init__(self, layer_config: LayerConfig, **kwargs):
            out_dim=layer_config.dim_out,
            dropout=0.,  # Dropout is handled by GraphGym's `GeneralLayer` wrapper
            residual=False,  # Residual connections are handled by GraphGym's `GNNStackStage` wrapper
+            act=layer_config.act,
            **kwargs)

    def forward(self, batch):
