Skip to content

Commit

Permalink
Merge pull request rampasek#4 from rampasek/upgrade_PyG204
Browse files Browse the repository at this point in the history
Upgrade to PyG v2.0.4
  • Loading branch information
rampasek authored Jul 20, 2022
2 parents e38969b + 0a2ff7c commit 2f08c5f
Show file tree
Hide file tree
Showing 50 changed files with 231 additions and 325 deletions.
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/recipe-for-a-general-powerful-scalable-graph/graph-regression-on-zinc)](https://paperswithcode.com/sota/graph-regression-on-zinc?p=recipe-for-a-general-powerful-scalable-graph)


![HGNet-viz](./GraphGPS.png)
![GraphGPS-viz](./GraphGPS.png)

How to build a graph Transformer? We provide a 3-part recipe on how to build graph Transformers with linear complexity. Our GPS recipe consists of choosing 3 main ingredients:
1. positional/structural encoding: [LapPE](https://arxiv.org/abs/2106.03893), [RWSE](https://arxiv.org/abs/2110.07875), [SignNet](https://arxiv.org/abs/2202.13013), [EquivStableLapPE](https://arxiv.org/abs/2203.00199)
2. local message-passing mechanism: [GatedGCN](https://arxiv.org/abs/1711.07553), [GINE](https://arxiv.org/abs/1905.12265), [PNA](https://arxiv.org/abs/2004.05718)
3. global attention mechanism: [Transformer](https://arxiv.org/abs/1706.03762), [Performer](https://arxiv.org/abs/2009.14794), [BigBird](https://arxiv.org/abs/2007.14062)

In this *GraphGPS* package we provide several positional/structural encodings and model choices, implementing the GPS recipe. GraphGPS is built using [PyG](https://www.pyg.org/) and [GraphGym from PyG2](https://pytorch-geometric.readthedocs.io/en/2.0.0/notes/graphgym.html).
Specifically PyG v2.0.2 is required.
Specifically *PyG v2.0.4* is required.


### Python environment setup with Conda
Expand All @@ -21,16 +21,16 @@ Specifically PyG v2.0.2 is required.
conda create -n graphgps python=3.9
conda activate graphgps

conda install pytorch=1.9 torchvision torchaudio -c pytorch -c nvidia
conda install pyg=2.0.2 -c pyg -c conda-forge
conda install pandas scikit-learn
conda install pytorch=1.10 torchvision torchaudio -c pytorch -c nvidia
conda install pyg=2.0.4 -c pyg -c conda-forge

# RDKit is required for OGB-LSC PCQM4Mv2 and datasets derived from it.
conda install openbabel fsspec rdkit -c conda-forge

pip install torchmetrics
pip install performer-pytorch
pip install torchmetrics==0.7.2
pip install ogb
pip install tensorboardX
pip install wandb

conda clean --all
Expand Down Expand Up @@ -88,7 +88,7 @@ python -m unittest -v unittests.test_eigvecs
## Citation

If you find this work useful, please cite our paper:
```
```bibtex
@article{rampasek2022GPS,
title={{Recipe for a General, Powerful, Scalable Graph Transformer}},
author={Ladislav Ramp\'{a}\v{s}ek and Mikhail Galkin and Vijay Prakash Dwivedi and Anh Tuan Luu and Guy Wolf and Dominique Beaini},
Expand Down
5 changes: 2 additions & 3 deletions graphgps/act/example.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_act

Expand All @@ -18,6 +19,4 @@ def forward(self, x):


register_act('swish', SWISH(inplace=cfg.mem.inplace))

register_act('lrelu_03',
nn.LeakyReLU(negative_slope=0.3, inplace=cfg.mem.inplace))
register_act('lrelu_03', nn.LeakyReLU(0.3, inplace=cfg.mem.inplace))
4 changes: 1 addition & 3 deletions graphgps/config/custom_gnn_config.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from torch_geometric.graphgym.register import register_config


@register_config('custom_gnn')
def custom_gnn_cfg(cfg):
"""Extending config group of GraphGym's built-in GNN for purposes of our
CustomGNN network model.
"""

# Use residual connections between the GNN layers.
cfg.gnn.residual = False


register_config('custom_gnn', custom_gnn_cfg)
3 changes: 1 addition & 2 deletions graphgps/config/dataset_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from torch_geometric.graphgym.register import register_config


@register_config('dataset_cfg')
def dataset_cfg(cfg):
"""Dataset-specific config options.
"""
Expand All @@ -13,5 +14,3 @@ def dataset_cfg(cfg):

# VOC/COCO Superpixels dataset version based on SLIC compactness parameter.
cfg.dataset.slic_compactness = 10

register_config('dataset_cfg', dataset_cfg)
10 changes: 3 additions & 7 deletions graphgps/config/defaults_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from torch_geometric.graphgym.register import register_config


@register_config('overwrite_defaults')
def overwrite_defaults_cfg(cfg):
"""Overwrite the default config values that are first set by GraphGym in
torch_geometric.graphgym.config.set_cfg
Expand All @@ -13,14 +14,12 @@ def overwrite_defaults_cfg(cfg):

# Overwrite default dataset name
cfg.dataset.name = 'none'

# Overwrite default rounding precision
cfg.round = 5


register_config('overwrite_defaults', overwrite_defaults_cfg)


@register_config('extended_cfg')
def extended_cfg(cfg):
"""General extended config options.
"""
Expand All @@ -33,6 +32,3 @@ def extended_cfg(cfg):
cfg.train.finetune = ""
# Freeze the pretrained part of the network, learning only the new head
cfg.train.freeze_pretrained = False


register_config('extended_cfg', extended_cfg)
7 changes: 2 additions & 5 deletions graphgps/config/example.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from yacs.config import CfgNode as CN

from torch_geometric.graphgym.register import register_config
from yacs.config import CfgNode as CN


@register_config('example')
def set_cfg_example(cfg):
r'''
This function sets the default config value for customized options
Expand All @@ -21,6 +21,3 @@ def set_cfg_example(cfg):

# then argument can be specified within the group
cfg.example_group.example_arg = 'example'


register_config('example', set_cfg_example)
4 changes: 1 addition & 3 deletions graphgps/config/gt_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from yacs.config import CfgNode as CN


@register_config('cfg_gt')
def set_cfg_gt(cfg):
"""Configuration for Graph Transformer-style models, e.g.:
- Spectral Attention Network (SAN) Graph Transformer.
Expand Down Expand Up @@ -69,6 +70,3 @@ def set_cfg_gt(cfg):
cfg.gt.bigbird.block_size = 3

cfg.gt.bigbird.layer_norm_eps = 1e-6


register_config('cfg_gt', set_cfg_gt)
4 changes: 1 addition & 3 deletions graphgps/config/optimizers_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from torch_geometric.graphgym.register import register_config


@register_config('extended_optim')
def extended_optim_cfg(cfg):
"""Extend optimizer config group that is first set by GraphGym in
torch_geometric.graphgym.config.set_cfg
Expand All @@ -24,6 +25,3 @@ def extended_optim_cfg(cfg):

# Clip gradient norms while training
cfg.optim.clip_grad_norm = False


register_config('extended_optim', extended_optim_cfg)
8 changes: 2 additions & 6 deletions graphgps/config/posenc_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from yacs.config import CfgNode as CN


@register_config('posenc')
def set_cfg_posenc(cfg):
"""Extend configuration with positional encoding options.
"""
Expand Down Expand Up @@ -45,11 +46,10 @@ def set_cfg_posenc(cfg):
# a separate variable in the PyG graph batch object.
pecfg.pass_as_var = False


# Config for EquivStable LapPE
cfg.posenc_EquivStableLapPE.enable = False
cfg.posenc_EquivStableLapPE.raw_norm_type = 'none'

# Config for Laplacian Eigen-decomposition for PEs that use it.
for name in ['posenc_LapPE', 'posenc_SignNet', 'posenc_EquivStableLapPE']:
pecfg = getattr(cfg, name)
Expand All @@ -68,7 +68,6 @@ def set_cfg_posenc(cfg):
cfg.posenc_SignNet.phi_out_dim = 4
cfg.posenc_SignNet.phi_hidden_dim = 64


for name in ['posenc_RWSE', 'posenc_HKdiagSE', 'posenc_ElstaticSE']:
pecfg = getattr(cfg, name)

Expand All @@ -86,6 +85,3 @@ def set_cfg_posenc(cfg):

# Override default, electrostatic kernel has fixed set of 10 measures.
cfg.posenc_ElstaticSE.kernel.times_func = 'range(10)'


register_config('posenc', set_cfg_posenc)
4 changes: 1 addition & 3 deletions graphgps/config/split_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from torch_geometric.graphgym.register import register_config


@register_config('split')
def set_cfg_split(cfg):
"""Reconfigure the default config value for dataset split options.
Expand All @@ -20,6 +21,3 @@ def set_cfg_split(cfg):
# Choose to run multiple splits in one program execution, if set,
# takes the precedence over cfg.dataset.split_index for split selection
cfg.run_multiple_splits = []


register_config('split', set_cfg_split)
4 changes: 1 addition & 3 deletions graphgps/config/wandb_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from yacs.config import CfgNode as CN


@register_config('cfg_wandb')
def set_cfg_wandb(cfg):
"""Weights & Biases tracker configuration.
"""
Expand All @@ -20,6 +21,3 @@ def set_cfg_wandb(cfg):

# Optional run name
cfg.wandb.name = ""


register_config('cfg_wandb', set_cfg_wandb)
8 changes: 2 additions & 6 deletions graphgps/encoder/ast_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
max_depth = 20


@register_node_encoder('ASTNode')
class ASTNodeEncoder(torch.nn.Module):
"""The Abstract Syntax Tree (AST) Node Encoder used for ogbg-code2 dataset.
Expand Down Expand Up @@ -59,9 +60,7 @@ def forward(self, batch):
return batch


register_node_encoder('ASTNode', ASTNodeEncoder)


@register_edge_encoder('ASTEdge')
class ASTEdgeEncoder(torch.nn.Module):
"""The Abstract Syntax Tree (AST) Edge Encoder used for ogbg-code2 dataset.
Expand All @@ -84,6 +83,3 @@ def forward(self, batch):
self.embedding_direction(batch.edge_attr[:, 1])
batch.edge_attr = embedding
return batch


register_edge_encoder('ASTEdge', ASTEdgeEncoder)
4 changes: 1 addition & 3 deletions graphgps/encoder/dummy_edge_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from torch_geometric.graphgym.register import register_edge_encoder


@register_edge_encoder('DummyEdge')
class DummyEdgeEncoder(torch.nn.Module):
def __init__(self, emb_dim):
super().__init__()
Expand All @@ -14,6 +15,3 @@ def forward(self, batch):
dummy_attr = batch.edge_index.new_zeros(batch.edge_index.shape[1])
batch.edge_attr = self.encoder(dummy_attr)
return batch


register_edge_encoder('DummyEdge', DummyEdgeEncoder)
4 changes: 1 addition & 3 deletions graphgps/encoder/equivstable_laplace_pos_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch_geometric.graphgym.register import register_node_encoder


@register_node_encoder('EquivStableLapPE')
class EquivStableLapPENodeEncoder(torch.nn.Module):
"""Equivariant and Stable Laplace Positional Embedding node encoder.
Expand Down Expand Up @@ -48,6 +49,3 @@ def forward(self, batch):
batch.pe_EquivStableLapPE = pos_enc

return batch


register_node_encoder('EquivStableLapPE', EquivStableLapPENodeEncoder)
21 changes: 9 additions & 12 deletions graphgps/encoder/example.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import torch
from torch_geometric.graphgym.register import (register_node_encoder,
register_edge_encoder)

from ogb.utils.features import get_bond_feature_dims

from torch_geometric.graphgym.register import (
register_edge_encoder,
register_node_encoder,
)


@register_node_encoder('example')
class ExampleNodeEncoder(torch.nn.Module):
"""
Provides an encoder for integer node features
Parameters:
num_classes - the number of classes for the embedding mapping to learn
"""
def __init__(self, emb_dim, num_classes=None):
super(ExampleNodeEncoder, self).__init__()
super().__init__()

self.encoder = torch.nn.Embedding(num_classes, emb_dim)
torch.nn.init.xavier_uniform_(self.encoder.weight.data)
Expand All @@ -25,12 +27,10 @@ def forward(self, batch):
return batch


register_node_encoder('example', ExampleNodeEncoder)


@register_edge_encoder('example')
class ExampleEdgeEncoder(torch.nn.Module):
def __init__(self, emb_dim):
super(ExampleEdgeEncoder, self).__init__()
super().__init__()

self.bond_embedding_list = torch.nn.ModuleList()
full_bond_feature_dims = get_bond_feature_dims()
Expand All @@ -48,6 +48,3 @@ def forward(self, batch):

batch.edge_attr = bond_embedding
return batch


register_edge_encoder('example', ExampleEdgeEncoder)
11 changes: 3 additions & 8 deletions graphgps/encoder/kernel_pos_encoder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import logging

import torch
import torch.nn as nn
from torch_geometric.graphgym.config import cfg
Expand Down Expand Up @@ -103,25 +101,22 @@ def forward(self, batch):
return batch


@register_node_encoder('RWSE')
class RWSENodeEncoder(KernelPENodeEncoder):
"""Random Walk Structural Encoding node encoder.
"""
kernel_type = 'RWSE'

register_node_encoder('RWSE', RWSENodeEncoder)


@register_node_encoder('HKdiagSE')
class HKdiagSENodeEncoder(KernelPENodeEncoder):
"""Heat kernel (diagonal) Structural Encoding node encoder.
"""
kernel_type = 'HKdiagSE'

register_node_encoder('HKdiagSE', HKdiagSENodeEncoder)


@register_node_encoder('ElstaticSE')
class ElstaticSENodeEncoder(KernelPENodeEncoder):
"""Electrostatic interactions Structural Encoding node encoder.
"""
kernel_type = 'ElstaticSE'

register_node_encoder('ElstaticSE', ElstaticSENodeEncoder)
4 changes: 1 addition & 3 deletions graphgps/encoder/laplace_pos_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch_geometric.graphgym.register import register_node_encoder


@register_node_encoder('LapPE')
class LapPENodeEncoder(torch.nn.Module):
"""Laplace Positional Embedding node encoder.
Expand Down Expand Up @@ -139,6 +140,3 @@ def forward(self, batch):
if self.pass_as_var:
batch.pe_LapPE = pos_enc
return batch


register_node_encoder('LapPE', LapPENodeEncoder)
Loading

0 comments on commit 2f08c5f

Please sign in to comment.