Skip to content

Latest commit

 

History

History
 
 

examples

Folders and files

NameName
Last commit message
Last commit date

parent directory

..
 
 
 
 
 
 
 
 
 
 

Examples

Training

We provide three example config files for the ET for training on QM9, MD17 and ANI1 respectively. To train on a QM9 target other than energy_U0, change the parameter dataset_arg in the QM9 config file. Changing the MD17 molecule to train on works analogously. To train an ET from scratch you can use the following code from the torchmd-net directory:

CUDA_VISIBLE_DEVICES=0,1 torchmd-train --conf examples/ET-{QM9,MD17,ANI1}.yaml

Use the CUDA_VISIBLE_DEVICES environment variable to select which and how many GPUs you want to train on. The example above selects GPUs with indices 0 and 1. The training code will want to save checkpoints and config files in a directory called logs/, which you can change either in the config .yaml file or as an additional command line argument: --log-dir path/to/log-dir.

Loading checkpoints

You can access several pretrained checkpoint files under the following URLs:

The checkpoints can be loaded using the load_model function in TorchMD-Net. Additional model arguments (e.g. turning on force prediction on top of energies) for inference can also be passed to the function. See the following example code for loading an ET pretrained on the ANI1 dataset:

from torchmdnet.models.model import load_model
model = load_model("ANI1-equivariant_transformer/epoch=359-val_loss=0.0004-test_loss=0.0120.ckpt", derivative=True)

The following example shows how to run inference on the model checkpoint. For single molecules, you just have to pass atomic numbers and position tensors, to evaluate the model on multiple molecules simultaneously, also include a batch tensor, which contains the molecule index of each atom.

import torch

# single molecule
z = torch.tensor([1, 1, 8], dtype=torch.long)
pos = torch.rand(3, 3)
energy, forces = model(z, pos)

# multiple molecules
z = torch.tensor([1, 1, 8, 1, 1, 8], dtype=torch.long)
pos = torch.rand(6, 3)
batch = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.long)
energies, forces = model(z, pos, batch)