
Commit fbc70a6

Merge pull request #6 from ZhouGengmo/add_data_process

add data preprocess file; add training workflow

2 parents: 0ad37fb + e1c5289

9 files changed: +368 −143 lines

README.md (+74 −3)
### Usage

#### Dependencies

The dependencies of Uni-p*K*<sub>a</sub> are the same as those of Uni-Mol:

- [Uni-Core](https://github.com/dptech-corp/Uni-Core); check its [Installation Documentation](https://github.com/dptech-corp/Uni-Core#installation).
- rdkit==2022.9.3, install via `pip install rdkit-pypi==2022.9.3`

The recommended environment is the docker image:

```
docker pull dptechnology/unimol:latest-pytorch1.11.0-cuda11.3
```

See details in the [Uni-Mol](https://github.com/dptech-corp/Uni-Mol/tree/main/unimol#dependencies) repository.
### Ready-to-run training workflow
#### Data

The raw data can be downloaded from [AISSquare](https://www.aissquare.com/datasets/detail?pageType=datasets&name=Uni-pKa-Dataset).

#### Pretrain with ChEMBL

First, preprocess the ChEMBL training and validation sets, and then pretrain the model:

```bash
# Preprocess the training set
python ./scripts/preprocess_pka.py --raw-csv-file Datasets/tsv/chembl_train.tsv --processed-lmdb-dir chembl --task-name train

# Preprocess the validation set
python ./scripts/preprocess_pka.py --raw-csv-file Datasets/tsv/chembl_valid.tsv --processed-lmdb-dir chembl --task-name valid

# Copy the necessary dict file
cp -r unimol/examples/* chembl

# Pretrain the model
bash pretrain_pka.sh
```

Note: the `head_name` in the subsequent scripts must match the `task_name` in `pretrain_pka.sh`.
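That naming constraint can be illustrated with a hypothetical value; the actual string is whatever `task_name` is set to in your copy of `pretrain_pka.sh`:

```shell
# Hypothetical names for illustration only -- check your own scripts.
task_name="unimol_pka"   # as set in pretrain_pka.sh (assumed value)
head_name="unimol_pka"   # must be identical in finetune_pka.sh and infer_pka.sh
```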
#### Finetune with dwar-iBond

Next, preprocess the dwar-iBond dataset and finetune the model:

```bash
# Preprocess
python ./scripts/preprocess_pka.py --raw-csv-file Datasets/tsv/dwar-iBond.tsv --processed-lmdb-dir dwar --task-name dwar-iBond

# Copy the necessary dict file
cp -r unimol/examples/* dwar

# Finetune the model
bash finetune_pka.sh
```
#### Infer p*K*<sub>a</sub>

Infer with the finetuned model, taking novartis_acid as an example:

```bash
# Preprocess
python ./scripts/preprocess_pka.py --raw-csv-file Datasets/tsv/novartis_acid.tsv --processed-lmdb-dir novartis_acid --task-name novartis_acid

# Copy the necessary examples from unimol
cp -r unimol/examples/* novartis_acid

# Run inference
bash infer_pka.sh
```

To test with other external test datasets, it may be necessary to modify `data_path`, `infer_task`, and `results_path` in `infer_pka.sh`.
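Those three variables can be sketched for a hypothetical external test set (`my_testset` is a placeholder, not a dataset shipped with the repository):

```shell
# Hypothetical edit inside infer_pka.sh -- the variable names come from the
# text above; the values are placeholders for your own dataset.
data_path="./my_testset"              # directory with the preprocessed LMDB files
infer_task="my_testset"               # should match the --task-name used in preprocessing
results_path="./my_testset_results"   # where the per-fold .pkl predictions are written
```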
#### Obtain the result files and calculate the metrics

After inference, extract the results to CSV files and calculate the performance metrics (e.g., MAE, RMSE) on the results:

```bash
python ./scripts/infer_mean_ensemble.py --task pka --nfolds 5 --results-path novartis_acid_results
```

The metrics are calculated using the average of the 5-fold model predictions.
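The averaging and metric computation can be sketched with synthetic numbers (the SMILES and values below are made up; real inputs come from the per-fold CSV files the script writes):

```python
import numpy as np
import pandas as pd

# Synthetic stand-in for five folds of predictions on two molecules.
folds = [
    pd.DataFrame({
        "smiles": ["CCO", "CC(=O)O"],
        "predict": [15.9 + 0.1 * i, 4.7 + 0.05 * i],
        "target": [15.9, 4.76],
    })
    for i in range(5)
]

# Stack all folds, then average the predictions per molecule,
# mirroring the groupby-mean step in infer_mean_ensemble.py.
combined = pd.concat(folds, ignore_index=True)
mean_results = combined.groupby("smiles", as_index=False).agg(
    {"predict": "mean", "target": "mean"}
)

# MAE and RMSE on the ensemble-averaged predictions.
err = mean_results["predict"] - mean_results["target"]
mae = np.abs(err).mean()
rmse = np.sqrt((err ** 2).mean())
print(f"MAE: {mae:.4f}, RMSE: {rmse:.4f}")
```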

scripts/finetune_pka.sh (−51)

This file was deleted.

scripts/infer_free_energy.sh (−27)

This file was deleted.

scripts/infer_mean_ensemble.py (+86)

```python
# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import glob
import os

import numpy as np
import pandas as pd


def cal_metrics(df):
    """Compute MAE and RMSE between the predict and target columns."""
    mae = np.abs(df["predict"] - df["target"]).mean()
    mse = ((df["predict"] - df["target"]) ** 2).mean()
    rmse = np.sqrt(mse)
    return mae, rmse


def get_csv_results(results_path, nfolds, task):
    all_smi_list, all_predict_list, all_target_list = [], [], []

    for fold_idx in range(nfolds):
        print(f"Processing fold {fold_idx}...")
        fold_path = os.path.join(results_path, f"fold_{fold_idx}")
        pkl_files = glob.glob(f"{fold_path}/*.pkl")
        fold_data = pd.read_pickle(pkl_files[0])

        smi_list, predict_list, target_list = [], [], []
        for batch in fold_data:
            sz = batch["bsz"]
            for i in range(sz):
                smi_list.append(batch["smi_name"][i])
                predict_list.append(batch["predict"][i].cpu().item())
                target_list.append(batch["target"][i].cpu().item())
        fold_df = pd.DataFrame({"smiles": smi_list, "predict": predict_list, "target": target_list})
        fold_df.to_csv(f"{fold_path}/fold_{fold_idx}.csv", index=False, sep="\t")

        # Accumulate for the final combined results
        all_smi_list.extend(smi_list)
        all_predict_list.extend(predict_list)
        all_target_list.extend(target_list)

    print(f"Combining results from {nfolds} folds into a single file...")
    combined_df = pd.DataFrame({"smiles": all_smi_list, "predict": all_predict_list, "target": all_target_list})
    combined_df.to_csv(f"{results_path}/all_results.csv", index=False, sep="\t")

    print("Calculating mean results for each SMILES...")
    mean_results = combined_df.groupby("smiles", as_index=False).agg({
        "predict": "mean",
        "target": "mean",
    })
    mean_results.to_csv(f"{results_path}/mean_results.csv", index=False, sep="\t")
    if task == "pka":
        print("MAE and RMSE for this task...")
        mae, rmse = cal_metrics(mean_results)
        print(f"MAE: {round(mae, 4)}, RMSE: {round(rmse, 4)}")
    print("Done!")


def main():
    parser = argparse.ArgumentParser(description="Model infer result mean ensemble")
    parser.add_argument(
        "--results-path",
        type=str,
        default="results",
        help="path to the saved infer results",
    )
    parser.add_argument(
        "--nfolds",
        default=5,
        type=int,
        help="cross validation split folds",
    )
    parser.add_argument(
        "--task",
        default="pka",
        type=str,
        choices=["pka", "free_energy"],
    )
    args = parser.parse_args()
    get_csv_results(args.results_path, args.nfolds, args.task)


if __name__ == "__main__":
    main()
```
