This is the implementation of our ICLR 2022 paper: https://arxiv.org/pdf/2110.04624.pdf
Our model is tested in Linux with the following packages:
- CUDA >= 11.1
- PyTorch == 1.8.2 (LTS Version)
- Numpy >= 1.18.1
- tqdm
Our data is retreived from the Structural Antibody Database (SAbDab). The training, validation, and test data (compressed) is located in data/sabdab
.
To train a RefineGNN for CDR-H3, please run
python ab_train.py --cdr_type 3 --train_path data/sabdab/hcdr3_cluster/train_data.jsonl --val_path data/sabdab/hcdr3_cluster/val_data.jsonl --test_path data/sabdab/hcdr3_cluster/test_data.jsonl
The default hyperparameters are: hidden layer dimension --hidden_size 256
, number of message passing layers --depth 4
, KNN neighborhood size --K_neighbors 9
, and the framework residue block size --block_size 8
(multi-resolution modeling, section 3.3).
During training, this script will report perplexity (PPL) and root-mean-square-error (RMSD) over the validation set. You can also train a RefineGNN for a different CDR region by changing --cdr_type 2
(CDR-H2) and --cdr_type 1
(CDR-H1).
If you don't want to train RefineGNN from scratch, please load a pre-trained model and run inference on the test set by
python ab_train.py --cdr_type 3 --load_model ckpts/RefineGNN-hcdr3/model.best --epoch 0
where --epoch 0
means zero training epochs.
Note: The above training script usually requires 20~24GB GPU memory. The GPU memory consumption can be substantially reduced by removing the multi-resolution modeling component. If you have limited GPU memory, you can train a RefineGNN without multi-resolution modeling by
python baseline_train.py --cdr_type 3 --train_path data/sabdab/hcdr3_cluster/train_data.jsonl --val_path data/sabdab/hcdr3_cluster/val_data.jsonl --test_path data/sabdab/hcdr3_cluster/test_data.jsonl --architecture RefineGNN_attonly
The above training script usually consumes 4GB GPU memory. You can also train our AR-GNN baseline by setting --architecture AR-GNN
.
To train a RefineGNN for this task, please run
python ab_train.py --train_path data/rabd/train.jsonl --val_path data/rabd/val.jsonl --test_path data/rabd/test.jsonl
At test time, we generate 10000 CDR-H3 sequences for each antibody and select the top 100 candidates with the lowest perplexity. You can load a pre-trained model and run inference on the test set by
python rabd_test.py --load_model ckpts/RefineGNN-rabd/model.best
Besides CDR sequence design, we can also use RefineGNN to predict CDR loop structure given an antibody VH sequence. To train a RefineGNN for structure prediction alone, please run
python fold_train.py --cdr 123
--cdr 123
means the model will predict CDR-H1, CDR-H2, and CDR-H3 combined. You can change it to --cdr 3
if you want to predict CDR-H3 structure only.
For convenience, we have provided a pre-trained checkpoint for CDR-H1,2,3 joint structure prediction. You can print the predicted CDR structures using the following script:
python print_cdr.py --load_model ckpts/RefineGNN-hfold/model.best --save_dir pred_pdbs
The predicted structures are saved in pred_pdbs/*.pdb
. Each pdb file has a header line that reports the RMSD score. You can visualize them in PyMOL.