-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference.py
151 lines (105 loc) · 5.57 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import torch
from Utils import create_directory, load_ckp, pickle_load, pickle_save
import CONSTANTS
import math, os, time
import argparse
from models.model import TFun, TFun_submodel
from Dataset.MyDataset import TestDataset
def _str2bool(value):
    """Parse a boolean command-line value (true/false, yes/no, 1/0, case-insensitive).

    Fixes the classic argparse pitfall: ``type=bool`` treats any non-empty
    string as True, so ``--load_weights False`` used to parse as True.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognizable boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % value)


parser = argparse.ArgumentParser()
parser.add_argument('--no-cuda', action='store_true', default=False, help='Disables CUDA training.')
parser.add_argument('--seed', type=int, default=42, help='Random seed.')
# type=_str2bool (not type=bool) so "--load_weights False" actually yields False.
parser.add_argument("--load_weights", default=False, type=_str2bool, help='Load weights from saved model')
args = parser.parse_args()

# Seed torch for reproducibility.
torch.manual_seed(args.seed)

# Use CUDA only when requested and available.
args.cuda = not args.no_cuda and torch.cuda.is_available()
if args.cuda:
    # NOTE(review): GPU index 1 is hard-coded — confirm this is the intended device.
    device = 'cuda:1'
else:
    device = 'cpu'

# All test proteins produced by the t1/t2 output step.
all_test = pickle_load(CONSTANTS.ROOT_DIR + "test/output_t1_t2/test_proteins")

# GO namespaces (cellular component, molecular function, biological process)
# and the sub-model variants; 'full' is the combined model.
ontologies = ["cc", "mf", "bp"]
models = ['esm2_t48', 'msa_1b', 'interpro', 'full']
def write_output(results, terms, filepath, cutoff=0.001):
    """Write per-protein term scores to a tab-separated file.

    Each output line is ``protein<TAB>term<TAB>score`` (score to 3 decimals),
    sorted by descending score; scores at or below ``cutoff`` are omitted.

    Args:
        results: mapping of protein id -> list of scores, parallel to ``terms``.
        terms: term identifiers, one per score column.
        filepath: destination path for the TSV file.
        cutoff: minimum (exclusive) score to emit.
    """
    with open(filepath, 'w') as out:
        for protein, scores in results.items():
            assert len(terms) == len(scores)
            # Rank this protein's terms from highest to lowest score.
            ranked = sorted(zip(terms, scores), key=lambda pair: pair[1], reverse=True)
            for term, score in ranked:
                if score > cutoff:
                    out.write('%s\t%s\t%0.3f\n' % (protein, term, score))
def get_term_indicies(ontology):
    """Load and partition GO-term indices for one ontology.

    Returns a 4-tuple of tensors on ``device``:
    ``(full, frequent, rare, rare_2)`` where ``rare`` holds indices in the
    broader set but not the frequent set.  For 'bp' the hierarchy has an
    extra level (``rare_2`` = full minus mid); for 'cc'/'mf' ``rare_2`` is None.

    Args:
        ontology: one of 'cc', 'mf', 'bp'.
    """
    _term_indicies = pickle_load(CONSTANTS.ROOT_DIR + "{}/term_indicies".format(ontology))
    if ontology == 'bp':
        full_term_indicies, mid_term_indicies, freq_term_indicies = _term_indicies[0], _term_indicies[5], _term_indicies[30]
        # Build the membership sets ONCE; the original rebuilt set(...) for every
        # element inside the comprehension, making the filtering O(n^2).
        mid_set = set(mid_term_indicies)
        freq_set = set(freq_term_indicies)
        rare_term_indicies_2 = torch.tensor([i for i in full_term_indicies if i not in mid_set]).to(device)
        rare_term_indicies = torch.tensor([i for i in mid_term_indicies if i not in freq_set]).to(device)
        full_term_indicies, freq_term_indicies = torch.tensor(_term_indicies[0]).to(device), torch.tensor(freq_term_indicies).to(device)
    else:
        full_term_indicies = _term_indicies[0]
        freq_term_indicies = _term_indicies[30]
        # Same hoisting as above: one set build, O(1) membership per element.
        freq_set = set(freq_term_indicies)
        rare_term_indicies = torch.tensor([i for i in full_term_indicies if i not in freq_set]).to(device)
        full_term_indicies = torch.tensor(full_term_indicies).to(device)
        freq_term_indicies = torch.tensor(freq_term_indicies).to(device)
        rare_term_indicies_2 = None
    return full_term_indicies, freq_term_indicies, rare_term_indicies, rare_term_indicies_2
# --- Inference driver ----------------------------------------------------
# For every ontology and every sub-model variant, run the test set through
# the corresponding checkpointed model and dump per-protein term scores to
# a TSV under evaluation/raw_predictions/.
# NOTE(review): inference runs without torch.no_grad(); gradients are tracked
# needlessly — confirm whether wrapping these loops is safe.
for ontology in ontologies:
    # Serialized test dataset and index->term mapping for this ontology.
    data_pth = CONSTANTS.ROOT_DIR + "test/dataset/{}".format(ontology)
    sorted_terms = pickle_load(CONSTANTS.ROOT_DIR+"/{}/sorted_terms".format(ontology))
    for sub_model in models:
        tst_dataset = TestDataset(data_pth=data_pth, submodel=sub_model)
        tstloader = torch.utils.data.DataLoader(tst_dataset, batch_size=500, shuffle=False)
        # terms, term_indicies, sub_indicies = get_term_indicies(ontology=ontology, submodel=sub_model)
        full_term_indicies, freq_term_indicies, rare_term_indicies, rare_term_indicies_2 = get_term_indicies(ontology=ontology)
        # Constructor arguments shared by the full model and the sub-models.
        kwargs = {
            'device': device,
            'ont': ontology,
            'full_indicies': full_term_indicies,
            'freq_indicies': freq_term_indicies,
            'rare_indicies': rare_term_indicies,
            'rare_indicies_2': rare_term_indicies_2,
            'sub_model': sub_model,
            'load_weights': True,
            'group': ""
        }
        if sub_model == "full":
            # Combined model: output restricted to the full term index set.
            print("Generating for {} {}".format(ontology, sub_model))
            ckp_dir = CONSTANTS.ROOT_DIR + '{}/models/{}_gcn_old/'.format(ontology, sub_model)
            # NOTE(review): ckp_pth is unused here — load_ckp receives the directory.
            ckp_pth = ckp_dir + "current_checkpoint.pt"
            model = TFun(**kwargs)
            # load model
            model = load_ckp(checkpoint_dir=ckp_dir, model=model, best_model=False, model_only=True)
            model.to(device)
            model.eval()
            results = {}  # protein id -> list of scores
            for data in tstloader:
                # data[:4] are the feature tensors; data[4] appears to hold the
                # protein identifiers (used as result keys) — confirm in TestDataset.
                _features, _proteins = data[:4], data[4]
                output = model(_features)
                # Keep only the columns for terms in the full index set.
                output = torch.index_select(output, 1, full_term_indicies)
                output = output.tolist()
                for i, j in zip(_proteins, output):
                    results[i] = j
            terms = [sorted_terms[i] for i in full_term_indicies]
            filepath = CONSTANTS.ROOT_DIR + 'evaluation/raw_predictions/transfew/'
            create_directory(filepath)
            write_output(results, terms, filepath+'{}.tsv'.format(ontology), cutoff=0.01)
        else:
            # Single-representation sub-model (esm2_t48 / msa_1b / interpro):
            # output covers only the frequent-term index set.
            print("Generating for {} {}".format(ontology, sub_model))
            ckp_dir = CONSTANTS.ROOT_DIR + '{}/models/{}/'.format(ontology, sub_model)
            ckp_pth = ckp_dir + "current_checkpoint.pt"
            model = TFun_submodel(**kwargs)
            # NOTE(review): here .to(device) runs BEFORE load_ckp, while the full
            # branch does it after — confirm both orders are intended.
            model.to(device)
            # print("Loading model checkpoint @ {}".format(ckp_pth))
            model = load_ckp(checkpoint_dir=ckp_dir, model=model, best_model=False, model_only=True)
            model.eval()
            results = {}  # protein id -> list of scores
            for data in tstloader:
                _features, _proteins = data[0], data[1]
                output = model(_features).tolist()
                for i, j in zip(_proteins, output):
                    results[i] = j
            terms = [sorted_terms[i] for i in freq_term_indicies]
            filepath = CONSTANTS.ROOT_DIR + 'evaluation/raw_predictions/{}/'.format(sub_model)
            create_directory(filepath)
            write_output(results, terms, filepath+'{}.tsv'.format(ontology), cutoff=0.01)