forked from chaitjo/geometric-gnn-dojo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor src dir into models, layers, & move utils
- Loading branch information
Showing
22 changed files
with
989 additions
and
870 deletions.
There are no files selected for viewing
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from models.schnet import SchNetModel | ||
from models.dimenet import DimeNetPPModel | ||
from models.spherenet import SphereNetModel | ||
from models.egnn import EGNNModel | ||
from models.gvpgnn import GVPGNNModel | ||
from models.tfn import TFNModel | ||
from models.mace import MACEModel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
from typing import Callable, Union | ||
|
||
import torch | ||
from torch.nn import functional as F | ||
from torch_geometric.nn import DimeNetPlusPlus | ||
from torch_scatter import scatter | ||
|
||
|
||
class DimeNetPPModel(DimeNetPlusPlus): | ||
""" | ||
DimeNet model from "Directional message passing for molecular graphs". | ||
This class extends the DimeNetPlusPlus base class for PyG. | ||
""" | ||
def __init__( | ||
self, | ||
hidden_channels: int = 128, | ||
in_dim: int = 1, | ||
out_dim: int = 1, | ||
num_layers: int = 4, | ||
int_emb_size: int = 64, | ||
basis_emb_size: int = 8, | ||
out_emb_channels: int = 256, | ||
num_spherical: int = 7, | ||
num_radial: int = 6, | ||
cutoff: float = 10, | ||
max_num_neighbors: int = 32, | ||
envelope_exponent: int = 5, | ||
num_before_skip: int = 1, | ||
num_after_skip: int = 2, | ||
num_output_layers: int = 3, | ||
act: Union[str, Callable] = 'swish' | ||
): | ||
""" | ||
Initializes an instance of the DimeNetPPModel class with the provided parameters. | ||
Parameters: | ||
- hidden_channels (int): Number of channels in the hidden layers (default: 128) | ||
- in_dim (int): Input dimension of the model (default: 1) | ||
- out_dim (int): Output dimension of the model (default: 1) | ||
- num_layers (int): Number of layers in the model (default: 4) | ||
- int_emb_size (int): Embedding size for interaction features (default: 64) | ||
- basis_emb_size (int): Embedding size for basis functions (default: 8) | ||
- out_emb_channels (int): Number of channels in the output embeddings (default: 256) | ||
- num_spherical (int): Number of spherical harmonics (default: 7) | ||
- num_radial (int): Number of radial basis functions (default: 6) | ||
- cutoff (float): Cutoff distance for interactions (default: 10) | ||
- max_num_neighbors (int): Maximum number of neighboring atoms to consider (default: 32) | ||
- envelope_exponent (int): Exponent of the envelope function (default: 5) | ||
- num_before_skip (int): Number of layers before the skip connections (default: 1) | ||
- num_after_skip (int): Number of layers after the skip connections (default: 2) | ||
- num_output_layers (int): Number of output layers (default: 3) | ||
- act (Union[str, Callable]): Activation function (default: 'swish' or callable) | ||
Note: | ||
- The `act` parameter can be either a string representing a built-in activation function, | ||
or a callable object that serves as a custom activation function. | ||
""" | ||
super().__init__( | ||
hidden_channels, | ||
out_dim, | ||
num_layers, | ||
int_emb_size, | ||
basis_emb_size, | ||
out_emb_channels, | ||
num_spherical, | ||
num_radial, | ||
cutoff, | ||
max_num_neighbors, | ||
envelope_exponent, | ||
num_before_skip, | ||
num_after_skip, | ||
num_output_layers, | ||
act | ||
) | ||
|
||
def forward(self, batch): | ||
|
||
i, j, idx_i, idx_j, idx_k, idx_kj, idx_ji = self.triplets( | ||
batch.edge_index, num_nodes=batch.atoms.size(0)) | ||
|
||
# Calculate distances. | ||
dist = (batch.pos[i] - batch.pos[j]).pow(2).sum(dim=-1).sqrt() | ||
|
||
# Calculate angles. | ||
pos_i = batch.pos[idx_i] | ||
pos_ji, pos_ki = batch.pos[idx_j] - pos_i, batch.pos[idx_k] - pos_i | ||
a = (pos_ji * pos_ki).sum(dim=-1) | ||
b = torch.cross(pos_ji, pos_ki).norm(dim=-1) | ||
angle = torch.atan2(b, a) | ||
|
||
rbf = self.rbf(dist) | ||
sbf = self.sbf(dist, angle, idx_kj) | ||
|
||
# Embedding block. | ||
x = self.emb(batch.atoms, rbf, i, j) | ||
P = self.output_blocks[0](x, rbf, i, num_nodes=batch.pos.size(0)) | ||
|
||
# Interaction blocks. | ||
for interaction_block, output_block in zip(self.interaction_blocks, | ||
self.output_blocks[1:]): | ||
x = interaction_block(x, rbf, sbf, idx_kj, idx_ji) | ||
P += output_block(x, rbf, i) | ||
|
||
return P.sum(dim=0) if batch is None else scatter(P, batch.batch, dim=0) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import torch | ||
from torch.nn import functional as F | ||
from torch_geometric.nn import global_add_pool, global_mean_pool | ||
|
||
from models.layers.egnn_layer import EGNNLayer | ||
|
||
|
||
class EGNNModel(torch.nn.Module): | ||
""" | ||
E-GNN model from "E(n) Equivariant Graph Neural Networks". | ||
""" | ||
def __init__( | ||
self, | ||
num_layers: int = 5, | ||
emb_dim: int = 128, | ||
in_dim: int = 1, | ||
out_dim: int = 1, | ||
activation: str = "relu", | ||
norm: str = "layer", | ||
aggr: str = "sum", | ||
pool: str = "sum", | ||
residual: bool = True, | ||
equivariant_pred: bool = False | ||
): | ||
""" | ||
Initializes an instance of the EGNNModel class with the provided parameters. | ||
Parameters: | ||
- num_layers (int): Number of layers in the model (default: 5) | ||
- emb_dim (int): Dimension of the node embeddings (default: 128) | ||
- in_dim (int): Input dimension of the model (default: 1) | ||
- out_dim (int): Output dimension of the model (default: 1) | ||
- activation (str): Activation function to be used (default: "relu") | ||
- norm (str): Normalization method to be used (default: "layer") | ||
- aggr (str): Aggregation method to be used (default: "sum") | ||
- pool (str): Global pooling method to be used (default: "sum") | ||
- residual (bool): Whether to use residual connections (default: True) | ||
- equivariant_pred (bool): Whether it is an equivariant prediction task (default: False) | ||
""" | ||
super().__init__() | ||
self.equivariant_pred = equivariant_pred | ||
self.residual = residual | ||
|
||
# Embedding lookup for initial node features | ||
self.emb_in = torch.nn.Embedding(in_dim, emb_dim) | ||
|
||
# Stack of GNN layers | ||
self.convs = torch.nn.ModuleList() | ||
for _ in range(num_layers): | ||
self.convs.append(EGNNLayer(emb_dim, activation, norm, aggr)) | ||
|
||
# Global pooling/readout function | ||
self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool] | ||
|
||
if self.equivariant_pred: | ||
# Linear predictor for equivariant tasks using geometric features | ||
self.pred = torch.nn.Linear(emb_dim + 3, out_dim) | ||
else: | ||
# MLP predictor for invariant tasks using only scalar features | ||
self.pred = torch.nn.Sequential( | ||
torch.nn.Linear(emb_dim, emb_dim), | ||
torch.nn.ReLU(), | ||
torch.nn.Linear(emb_dim, out_dim) | ||
) | ||
|
||
def forward(self, batch): | ||
|
||
h = self.emb_in(batch.atoms) # (n,) -> (n, d) | ||
pos = batch.pos # (n, 3) | ||
|
||
for conv in self.convs: | ||
# Message passing layer | ||
h_update, pos_update = conv(h, pos, batch.edge_index) | ||
|
||
# Update node features (n, d) -> (n, d) | ||
h = h + h_update if self.residual else h_update | ||
|
||
# Update node coordinates (no residual) (n, 3) -> (n, 3) | ||
pos = pos_update | ||
|
||
if not self.equivariant_pred: | ||
# Select only scalars for invariant prediction | ||
out = self.pool(h, batch.batch) # (n, d) -> (batch_size, d) | ||
else: | ||
out = self.pool(torch.cat([h, pos], dim=-1), batch.batch) | ||
|
||
return self.pred(out) # (batch_size, out_dim) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
import torch | ||
from torch.nn import functional as F | ||
from torch_geometric.nn import global_add_pool, global_mean_pool | ||
|
||
from models.mace_modules.blocks import RadialEmbeddingBlock | ||
import models.layers.gvp_layer as gvp | ||
|
||
|
||
class GVPGNNModel(torch.nn.Module): | ||
""" | ||
GVP-GNN model from "Equivariant Graph Neural Networks for 3D Macromolecular Structure". | ||
""" | ||
def __init__( | ||
self, | ||
r_max: float = 10.0, | ||
num_bessel: int = 8, | ||
num_polynomial_cutoff: int = 5, | ||
num_layers: int = 5, | ||
in_dim=1, | ||
out_dim=1, | ||
s_dim: int = 128, | ||
v_dim: int = 16, | ||
s_dim_edge: int = 32, | ||
v_dim_edge: int = 1, | ||
pool: str = "sum", | ||
residual: bool = True, | ||
equivariant_pred: bool = False | ||
): | ||
""" | ||
Initializes an instance of the GVPGNNModel class with the provided parameters. | ||
Parameters: | ||
- r_max (float): Maximum distance for Bessel basis functions (default: 10.0) | ||
- num_bessel (int): Number of Bessel basis functions (default: 8) | ||
- num_polynomial_cutoff (int): Number of polynomial cutoff basis functions (default: 5) | ||
- num_layers (int): Number of layers in the model (default: 5) | ||
- in_dim (int): Input dimension of the model (default: 1) | ||
- out_dim (int): Output dimension of the model (default: 1) | ||
- s_dim (int): Dimension of the node state embeddings (default: 128) | ||
- v_dim (int): Dimension of the node vector embeddings (default: 16) | ||
- s_dim_edge (int): Dimension of the edge state embeddings (default: 32) | ||
- v_dim_edge (int): Dimension of the edge vector embeddings (default: 1) | ||
- pool (str): Global pooling method to be used (default: "sum") | ||
- residual (bool): Whether to use residual connections (default: True) | ||
- equivariant_pred (bool): Whether it is an equivariant prediction task (default: False) | ||
""" | ||
super().__init__() | ||
|
||
self.r_max = r_max | ||
self.num_layers = num_layers | ||
self.equivariant_pred = equivariant_pred | ||
self.s_dim = s_dim | ||
self.v_dim = v_dim | ||
|
||
activations = (F.relu, None) | ||
_DEFAULT_V_DIM = (s_dim, v_dim) | ||
_DEFAULT_E_DIM = (s_dim_edge, v_dim_edge) | ||
|
||
# Node embedding | ||
self.emb_in = torch.nn.Embedding(in_dim, s_dim) | ||
self.W_v = torch.nn.Sequential( | ||
gvp.LayerNorm((s_dim, 0)), | ||
gvp.GVP((s_dim, 0), _DEFAULT_V_DIM, | ||
activations=(None, None), vector_gate=True) | ||
) | ||
|
||
# Edge embedding | ||
self.radial_embedding = RadialEmbeddingBlock( | ||
r_max=r_max, | ||
num_bessel=num_bessel, | ||
num_polynomial_cutoff=num_polynomial_cutoff, | ||
) | ||
self.W_e = torch.nn.Sequential( | ||
gvp.LayerNorm((self.radial_embedding.out_dim, 1)), | ||
gvp.GVP((self.radial_embedding.out_dim, 1), _DEFAULT_E_DIM, | ||
activations=(None, None), vector_gate=True) | ||
) | ||
|
||
# Stack of GNN layers | ||
self.layers = torch.nn.ModuleList( | ||
gvp.GVPConvLayer( | ||
_DEFAULT_V_DIM, _DEFAULT_E_DIM, | ||
activations=activations, vector_gate=True, | ||
residual=residual | ||
) | ||
for _ in range(num_layers) | ||
) | ||
|
||
# Global pooling/readout function | ||
self.pool = {"mean": global_mean_pool, "sum": global_add_pool}[pool] | ||
|
||
if self.equivariant_pred: | ||
# Linear predictor for equivariant tasks using geometric features | ||
self.pred = torch.nn.Linear(s_dim + v_dim * 3, out_dim) | ||
else: | ||
# MLP predictor for invariant tasks using only scalar features | ||
self.pred = torch.nn.Sequential( | ||
torch.nn.Linear(s_dim, s_dim), | ||
torch.nn.ReLU(), | ||
torch.nn.Linear(s_dim, out_dim) | ||
) | ||
|
||
def forward(self, batch): | ||
|
||
# Edge features | ||
vectors = batch.pos[batch.edge_index[0]] - batch.pos[batch.edge_index[1]] # [n_edges, 3] | ||
lengths = torch.linalg.norm(vectors, dim=-1, keepdim=True) # [n_edges, 1] | ||
|
||
h_V = self.emb_in(batch.atoms) # (n,) -> (n, d) | ||
h_E = ( | ||
self.radial_embedding(lengths), | ||
torch.nan_to_num(torch.div(vectors, lengths)).unsqueeze_(-2) | ||
) | ||
|
||
h_V = self.W_v(h_V) | ||
h_E = self.W_e(h_E) | ||
|
||
for layer in self.layers: | ||
h_V = layer(h_V, batch.edge_index, h_E) | ||
|
||
out = self.pool(gvp._merge(*h_V), batch.batch) # (n, d) -> (batch_size, d) | ||
|
||
if not self.equivariant_pred: | ||
# Select only scalars for invariant prediction | ||
out = out[:,:self.s_dim] | ||
|
||
return self.pred(out) # (batch_size, out_dim) |
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.