Refactor src dir into models, layers, & move utils
chaitjo committed Jun 18, 2023
1 parent 7a817bb commit 545f55f
Showing 22 changed files with 989 additions and 870 deletions.
File renamed without changes.
File renamed without changes.
@@ -1,6 +1,6 @@
import time
import random
from tqdm import tqdm
from tqdm.autonotebook import tqdm # from tqdm import tqdm
import numpy as np
from sklearn.metrics import accuracy_score

7 changes: 7 additions & 0 deletions models/__init__.py
@@ -0,0 +1,7 @@
from models.schnet import SchNetModel
from models.dimenet import DimeNetPPModel
from models.spherenet import SphereNetModel
from models.egnn import EGNNModel
from models.gvpgnn import GVPGNNModel
from models.tfn import TFNModel
from models.mace import MACEModel
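
With these re-exports in place, downstream code can import any model straight from the package root. A minimal sketch (illustrative only, not part of the commit):

from models import EGNNModel

model = EGNNModel(num_layers=3, emb_dim=64)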
105 changes: 105 additions & 0 deletions models/dimenet.py
@@ -0,0 +1,105 @@
from typing import Callable, Union

import torch
from torch.nn import functional as F
from torch_geometric.nn import DimeNetPlusPlus
from torch_scatter import scatter


class DimeNetPPModel(DimeNetPlusPlus):
"""
    DimeNet++ model from "Fast and Uncertainty-Aware Directional Message Passing
    for Non-Equilibrium Molecules", an efficient successor to DimeNet
    ("Directional Message Passing for Molecular Graphs").
    This class extends the DimeNetPlusPlus base class from PyG.
"""
def __init__(
self,
hidden_channels: int = 128,
in_dim: int = 1,
out_dim: int = 1,
num_layers: int = 4,
int_emb_size: int = 64,
basis_emb_size: int = 8,
out_emb_channels: int = 256,
num_spherical: int = 7,
num_radial: int = 6,
cutoff: float = 10,
max_num_neighbors: int = 32,
envelope_exponent: int = 5,
num_before_skip: int = 1,
num_after_skip: int = 2,
num_output_layers: int = 3,
act: Union[str, Callable] = 'swish'
):
"""
Initializes an instance of the DimeNetPPModel class with the provided parameters.
Parameters:
- hidden_channels (int): Number of channels in the hidden layers (default: 128)
- in_dim (int): Input dimension of the model (default: 1)
- out_dim (int): Output dimension of the model (default: 1)
- num_layers (int): Number of layers in the model (default: 4)
- int_emb_size (int): Embedding size for interaction features (default: 64)
- basis_emb_size (int): Embedding size for basis functions (default: 8)
- out_emb_channels (int): Number of channels in the output embeddings (default: 256)
- num_spherical (int): Number of spherical harmonics (default: 7)
- num_radial (int): Number of radial basis functions (default: 6)
- cutoff (float): Cutoff distance for interactions (default: 10)
- max_num_neighbors (int): Maximum number of neighboring atoms to consider (default: 32)
- envelope_exponent (int): Exponent of the envelope function (default: 5)
- num_before_skip (int): Number of layers before the skip connections (default: 1)
- num_after_skip (int): Number of layers after the skip connections (default: 2)
- num_output_layers (int): Number of output layers (default: 3)
    - act (Union[str, Callable]): Activation function (default: 'swish')
Note:
- The `act` parameter can be either a string representing a built-in activation function,
or a callable object that serves as a custom activation function.
"""
super().__init__(
hidden_channels,
out_dim,
num_layers,
int_emb_size,
basis_emb_size,
out_emb_channels,
num_spherical,
num_radial,
cutoff,
max_num_neighbors,
envelope_exponent,
num_before_skip,
num_after_skip,
num_output_layers,
act
)

def forward(self, batch):

i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets(
batch.edge_index, num_nodes=batch.atoms.size(0))

# Calculate distances.
dist = (batch.pos[i] - batch.pos[j]).pow(2).sum(dim=-1).sqrt()

# Calculate angles.
pos_i = batch.pos[idx_i]
pos_ji, pos_ki = batch.pos[idx_j] - pos_i, batch.pos[idx_k] - pos_i
a = (pos_ji * pos_ki).sum(dim=-1)
        b = torch.cross(pos_ji, pos_ki, dim=-1).norm(dim=-1)
angle = torch.atan2(b, a)

rbf = self.rbf(dist)
sbf = self.sbf(dist, angle, idx_kj)

# Embedding block.
x = self.emb(batch.atoms, rbf, i, j)
P = self.output_blocks[0](x, rbf, i, num_nodes=batch.pos.size(0))

# Interaction blocks.
for interaction_block, output_block in zip(self.interaction_blocks,
self.output_blocks[1:]):
x = interaction_block(x, rbf, sbf, idx_kj, idx_ji)
P += output_block(x, rbf, i)

        return scatter(P, batch.batch, dim=0)  # (batch_size, out_dim)
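
A hedged usage sketch for the class above (illustrative, not part of the commit): it assumes a PyG Batch whose Data objects carry integer atoms, 3D pos, and a radius-graph edge_index, matching what forward reads.

import torch
from torch_geometric.data import Batch, Data
from torch_geometric.nn import radius_graph

from models.dimenet import DimeNetPPModel

pos = torch.randn(10, 3)                    # 10 atoms with 3D coordinates
atoms = torch.zeros(10, dtype=torch.long)   # dummy atom types
edge_index = radius_graph(pos, r=10.0)      # edges within the cutoff
batch = Batch.from_data_list([Data(atoms=atoms, pos=pos, edge_index=edge_index)])

model = DimeNetPPModel(num_layers=2)
out = model(batch)                          # -> shape (1, out_dim)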
87 changes: 87 additions & 0 deletions models/egnn.py
@@ -0,0 +1,87 @@
import torch
from torch.nn import functional as F
from torch_geometric.nn import global_add_pool, global_mean_pool

from models.layers.egnn_layer import EGNNLayer


class EGNNModel(torch.nn.Module):
"""
E-GNN model from "E(n) Equivariant Graph Neural Networks".
"""
def __init__(
self,
num_layers: int = 5,
emb_dim: int = 128,
in_dim: int = 1,
out_dim: int = 1,
activation: str = "relu",
norm: str = "layer",
aggr: str = "sum",
pool: str = "sum",
residual: bool = True,
equivariant_pred: bool = False
):
"""
Initializes an instance of the EGNNModel class with the provided parameters.
Parameters:
- num_layers (int): Number of layers in the model (default: 5)
- emb_dim (int): Dimension of the node embeddings (default: 128)
- in_dim (int): Input dimension of the model (default: 1)
- out_dim (int): Output dimension of the model (default: 1)
- activation (str): Activation function to be used (default: "relu")
- norm (str): Normalization method to be used (default: "layer")
- aggr (str): Aggregation method to be used (default: "sum")
- pool (str): Global pooling method to be used (default: "sum")
- residual (bool): Whether to use residual connections (default: True)
- equivariant_pred (bool): Whether it is an equivariant prediction task (default: False)
"""
super().__init__()
self.equivariant_pred = equivariant_pred
self.residual = residual

# Embedding lookup for initial node features
self.emb_in = torch.nn.Embedding(in_dim, emb_dim)

# Stack of GNN layers
self.convs = torch.nn.ModuleList()
for _ in range(num_layers):
self.convs.append(EGNNLayer(emb_dim, activation, norm, aggr))

# Global pooling/readout function
self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool]

if self.equivariant_pred:
# Linear predictor for equivariant tasks using geometric features
self.pred = torch.nn.Linear(emb_dim + 3, out_dim)
else:
# MLP predictor for invariant tasks using only scalar features
self.pred = torch.nn.Sequential(
torch.nn.Linear(emb_dim, emb_dim),
torch.nn.ReLU(),
torch.nn.Linear(emb_dim, out_dim)
)

def forward(self, batch):

h = self.emb_in(batch.atoms) # (n,) -> (n, d)
pos = batch.pos # (n, 3)

for conv in self.convs:
# Message passing layer
h_update, pos_update = conv(h, pos, batch.edge_index)

# Update node features (n, d) -> (n, d)
h = h + h_update if self.residual else h_update

# Update node coordinates (no residual) (n, 3) -> (n, 3)
pos = pos_update

if not self.equivariant_pred:
# Select only scalars for invariant prediction
out = self.pool(h, batch.batch) # (n, d) -> (batch_size, d)
else:
out = self.pool(torch.cat([h, pos], dim=-1), batch.batch)

return self.pred(out) # (batch_size, out_dim)
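
As with the other models, a hedged end-to-end sketch (illustrative; the complete-graph edge_index and dummy atom types are toy assumptions, not from the commit):

import torch
from torch_geometric.data import Batch, Data

from models.egnn import EGNNModel

data = Batch.from_data_list([Data(
    atoms=torch.zeros(6, dtype=torch.long),
    pos=torch.randn(6, 3),
    edge_index=torch.combinations(torch.arange(6)).t(),  # one direction of a complete graph
)])

invariant_model = EGNNModel(out_dim=1)                           # scalar prediction
equivariant_model = EGNNModel(out_dim=3, equivariant_pred=True)  # geometric prediction
print(invariant_model(data).shape)    # torch.Size([1, 1])
print(equivariant_model(data).shape)  # torch.Size([1, 3])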
127 changes: 127 additions & 0 deletions models/gvpgnn.py
@@ -0,0 +1,127 @@
import torch
from torch.nn import functional as F
from torch_geometric.nn import global_add_pool, global_mean_pool

from models.mace_modules.blocks import RadialEmbeddingBlock
import models.layers.gvp_layer as gvp


class GVPGNNModel(torch.nn.Module):
"""
GVP-GNN model from "Equivariant Graph Neural Networks for 3D Macromolecular Structure".
"""
def __init__(
self,
r_max: float = 10.0,
num_bessel: int = 8,
num_polynomial_cutoff: int = 5,
num_layers: int = 5,
        in_dim: int = 1,
        out_dim: int = 1,
s_dim: int = 128,
v_dim: int = 16,
s_dim_edge: int = 32,
v_dim_edge: int = 1,
pool: str = "sum",
residual: bool = True,
equivariant_pred: bool = False
):
"""
Initializes an instance of the GVPGNNModel class with the provided parameters.
Parameters:
- r_max (float): Maximum distance for Bessel basis functions (default: 10.0)
- num_bessel (int): Number of Bessel basis functions (default: 8)
- num_polynomial_cutoff (int): Number of polynomial cutoff basis functions (default: 5)
- num_layers (int): Number of layers in the model (default: 5)
- in_dim (int): Input dimension of the model (default: 1)
- out_dim (int): Output dimension of the model (default: 1)
- s_dim (int): Dimension of the node state embeddings (default: 128)
- v_dim (int): Dimension of the node vector embeddings (default: 16)
- s_dim_edge (int): Dimension of the edge state embeddings (default: 32)
- v_dim_edge (int): Dimension of the edge vector embeddings (default: 1)
- pool (str): Global pooling method to be used (default: "sum")
- residual (bool): Whether to use residual connections (default: True)
- equivariant_pred (bool): Whether it is an equivariant prediction task (default: False)
"""
super().__init__()

self.r_max = r_max
self.num_layers = num_layers
self.equivariant_pred = equivariant_pred
self.s_dim = s_dim
self.v_dim = v_dim

activations = (F.relu, None)
_DEFAULT_V_DIM = (s_dim, v_dim)
_DEFAULT_E_DIM = (s_dim_edge, v_dim_edge)

# Node embedding
self.emb_in = torch.nn.Embedding(in_dim, s_dim)
self.W_v = torch.nn.Sequential(
gvp.LayerNorm((s_dim, 0)),
gvp.GVP((s_dim, 0), _DEFAULT_V_DIM,
activations=(None, None), vector_gate=True)
)

# Edge embedding
self.radial_embedding = RadialEmbeddingBlock(
r_max=r_max,
num_bessel=num_bessel,
num_polynomial_cutoff=num_polynomial_cutoff,
)
self.W_e = torch.nn.Sequential(
gvp.LayerNorm((self.radial_embedding.out_dim, 1)),
gvp.GVP((self.radial_embedding.out_dim, 1), _DEFAULT_E_DIM,
activations=(None, None), vector_gate=True)
)

# Stack of GNN layers
self.layers = torch.nn.ModuleList(
gvp.GVPConvLayer(
_DEFAULT_V_DIM, _DEFAULT_E_DIM,
activations=activations, vector_gate=True,
residual=residual
)
for _ in range(num_layers)
)

# Global pooling/readout function
self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool]

if self.equivariant_pred:
# Linear predictor for equivariant tasks using geometric features
self.pred = torch.nn.Linear(s_dim + v_dim * 3, out_dim)
else:
# MLP predictor for invariant tasks using only scalar features
self.pred = torch.nn.Sequential(
torch.nn.Linear(s_dim, s_dim),
torch.nn.ReLU(),
torch.nn.Linear(s_dim, out_dim)
)

def forward(self, batch):

# Edge features
vectors = batch.pos[batch.edge_index[0]] - batch.pos[batch.edge_index[1]] # [n_edges, 3]
lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1]

h_V = self.emb_in(batch.atoms) # (n,) -> (n, d)
h_E = (
self.radial_embedding(lengths),
torch.nan_to_num(torch.div(vectors, lengths)).unsqueeze_(-2)
)

h_V = self.W_v(h_V)
h_E = self.W_e(h_E)

for layer in self.layers:
h_V = layer(h_V, batch.edge_index, h_E)

out = self.pool(gvp._merge(*h_V), batch.batch) # (n, d) -> (batch_size, d)

if not self.equivariant_pred:
# Select only scalars for invariant prediction
out = out[:,:self.s_dim]

return self.pred(out) # (batch_size, out_dim)
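
The forward pass above threads (scalar, vector) feature tuples through the GVP layers; a short shape sketch of that convention (stand-in tensors, illustrative only):

import torch

vectors = torch.randn(4, 3)                                 # one displacement per edge
lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True)  # (4, 1)
scalars = torch.randn(4, 8)                                 # stand-in for radial_embedding(lengths)
unit = torch.nan_to_num(vectors / lengths).unsqueeze(-2)    # (4, 1, 3) unit directions
h_E = (scalars, unit)  # GVP convention: (scalar channels (e, s), vector channels (e, v, 3))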
File renamed without changes.
13 changes: 8 additions & 5 deletions src/egnn_layers.py → models/layers/egnn_layer.py
@@ -5,11 +5,12 @@


class EGNNLayer(MessagePassing):
    """E(n) Equivariant GNN Layer

    Paper: E(n) Equivariant Graph Neural Networks, Satorras et al.
    """
    def __init__(self, emb_dim, activation="relu", norm="layer", aggr="add"):
        """
        Args:
            emb_dim: (int) - hidden dimension `d`
            activation: (str) - non-linearity within MLPs (swish/relu)
@@ -65,7 +66,9 @@ def message(self, h_i, h_j, pos_i, pos_j):
msg = torch.cat([h_i, h_j, dists], dim=-1)
msg = self.mlp_msg(msg)
# Scale magnitude of displacement vector
        pos_diff = pos_diff * self.mlp_pos(msg)
# NOTE: some papers divide pos_diff by (dists + 1) to stabilise model.
# NOTE: lucidrains clamps pos_diff between some [-n, +n], also for stability.
return msg, pos_diff

def aggregate(self, inputs, index):
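
The two NOTE comments above point at known stabilisation tricks; a hedged sketch of both combined (hypothetical helper, not code from this commit):

import torch

def stabilise_pos_diff(pos_diff: torch.Tensor, dists: torch.Tensor,
                       clamp_val: float = 100.0) -> torch.Tensor:
    """Illustrative: distance normalisation and clamping of coordinate updates."""
    pos_diff = pos_diff / (dists + 1)                     # divide by (dists + 1) for stability
    return pos_diff.clamp(min=-clamp_val, max=clamp_val)  # clamp to [-n, +n], as lucidrains does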