-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_pretrain.py
56 lines (45 loc) · 1.87 KB
/
test_pretrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 14 15:54:13 2023
@author: dfhuang
"""
import argparse
from utils import *
from evaluate_metrics import *
import pandas as pd
import torch
from MCDTA import mcDTA
from train import load_features
# Run on the first CUDA GPU when one is usable; otherwise stay on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Channel counts handed to mcDTA as mol_sequence_channel / pro_sequence_channel.
mol_sequence_dim = 64
protein_sequence_dim = 20
# load corresponding features
def test(set_cate='test',
         model_path='./models/2024-03-25T14_23_47--best_model_param.pkl',
         result_dir='./results'):
    """Evaluate a saved mcDTA checkpoint on one dataset split.

    Loads the features for ``set_cate``, restores the model weights from
    ``model_path``, runs ``Model.test`` and prints the metrics in the order
    ``[mae, rmse, pearson, spearman, ci, r_squared]``.  The predictions and
    targets, rounded to two decimals, are written to
    ``<result_dir>/<set_cate>_result.xlsx``.

    Args:
        set_cate: Dataset split name forwarded to ``load_features`` and used
            as the prefix of the result file name.
        model_path: Path of the pickled ``state_dict`` to restore.
        result_dir: Directory that receives the Excel file; created when it
            does not exist yet.

    Returns:
        None.  Results are printed and written to disk.
    """
    import os  # local import: only this function needs it

    X_test, y_test = load_features(set_cate=set_cate, proportion=1.0)

    ########################## load model ################################
    Model = mcDTA(mol_sequence_channel=mol_sequence_dim,
                  pro_sequence_channel=protein_sequence_dim,
                  out_channel=128, n_output=1)
    with open(model_path, 'rb') as handle:
        # Map onto the module-level `device` rather than a hard-coded
        # 'cuda:0', so the checkpoint also loads on CPU-only machines.
        state_dict = torch.load(handle, map_location=device)
    Model.model.load_state_dict(state_dict)

    y_pred, y_true = Model.test(X_test, y_test)
    # .cpu() is a no-op for CPU tensors and is required before .numpy()
    # whenever the outputs live on the GPU.
    y_pred = y_pred.detach().cpu().numpy().flatten()
    y_true = y_true.detach().cpu().numpy().flatten()
    assert len(y_pred) == len(y_true)

    test_result = [mae(y_true, y_pred), rmse(y_true, y_pred),
                   pearson(y_true, y_pred), spearman(y_true, y_pred),
                   ci(y_true, y_pred), r_squared(y_true, y_pred)]
    print(test_result)

    #### Store data
    y_pred = [round(i, 2) for i in y_pred]
    y_true = [round(i, 2) for i in y_true]
    df = pd.DataFrame({'y_pred': y_pred, 'y_true': y_true})
    # Create the output directory up front so to_excel does not fail on a
    # fresh checkout.
    os.makedirs(result_dir, exist_ok=True)
    df.to_excel(os.path.join(result_dir, '%s_result.xlsx' % set_cate), index=False)
    return
def func():
    """Parse the command line and evaluate the requested dataset split.

    ``--Set`` / ``-S`` names the split passed to ``test``.  It defaults to
    ``'test'`` so that running the script without arguments evaluates the
    test split instead of forwarding ``None`` as the split name.
    """
    parser = argparse.ArgumentParser(description='parse dataset categories parameters')
    parser.add_argument('--Set', '-S', type=str, default='test',
                        help='the dataset name (default: %(default)s)')
    args = parser.parse_args()
    test(set_cate=args.Set)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    func()