-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
176 lines (147 loc) · 7.91 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
# Standard library
import argparse
import copy
import os
import sys
from collections import OrderedDict
from os.path import join

# Third-party
import joblib
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from tqdm import tqdm

# Project-local
import models
from config import parse_args
from lib.common import setpu_seed, dict_round
from lib.dataset import TestDataset
from lib.extract_patches import *  # also brings in numpy as np and patch helpers
from lib.logger import Logger, Print_Logger
from lib.metrics import Evaluate
from lib.pre_processing import my_PreProc
from lib.visualize import save_img, group_images, concat_result

# Fix RNG seeds so test-time patch extraction/evaluation is reproducible
setpu_seed(2021)
class Test():
    """Patch-based test-time pipeline: extract overlapping patches from the
    test images, run network inference per patch, recompose full-size
    probability maps, evaluate inside the FOV masks, and save visualizations.
    """

    def __init__(self, args):
        self.args = args
        # Stride must not exceed patch size, otherwise patches would not overlap
        # and recompone_overlap could not cover the whole image.
        assert (args.stride_height <= args.test_patch_height and args.stride_width <= args.test_patch_width)
        # Experiment output directory (checkpoints, logs, results)
        self.path_experiment = join(args.outf, args.save)
        # Extract overlapping test patches; the path list txt enumerates the
        # image / ground-truth / FOV file locations.
        self.patches_imgs_test, self.test_imgs, self.test_masks, self.test_FOVs, self.new_height, self.new_width = get_data_test_overlap(
            test_data_path_list=args.test_data_path_list,
            patch_height=args.test_patch_height,
            patch_width=args.test_patch_width,
            stride_height=args.stride_height,
            stride_width=args.stride_width
        )
        # Original (pre-padding) image size, used to crop the recomposed maps back
        self.img_height = self.test_imgs.shape[2]
        self.img_width = self.test_imgs.shape[3]
        test_set = TestDataset(self.patches_imgs_test)
        # shuffle=False: patch order must be preserved for recomposition
        self.test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=3)
        self.lc = args.largest_connected
        self.binary_images = None

    def inference(self, net):
        """Run the network over all test patches and cache the per-patch
        foreground probability maps in ``self.pred_patches`` (N, 1, H, W)."""
        net.eval()
        preds = []
        with torch.no_grad():
            for batch_idx, inputs in tqdm(enumerate(self.test_loader), total=len(self.test_loader)):
                inputs = inputs.cuda()
                outputs = net(inputs)
                # Channel 1 is the vessel/foreground probability
                outputs = outputs[:, 1].data.cpu().numpy()
                preds.append(outputs)
        predictions = np.concatenate(preds, axis=0)
        self.pred_patches = np.expand_dims(predictions, axis=1)

    def evaluate(self):
        """Recompose full-size predictions, score them inside the FOV, save
        metrics/ROC data, and return a rounded metrics dict."""
        self.pred_imgs = recompone_overlap(
            self.pred_patches, self.new_height, self.new_width, self.args.stride_height, self.args.stride_width)
        # Crop the padded recomposition back to the original image size
        self.pred_imgs = self.pred_imgs[:, :, 0:self.img_height, 0:self.img_width]
        # Only pixels inside the FOV masks are evaluated
        y_scores, y_true, self.binary_images = pred_only_in_FOV(
            self.pred_imgs, self.test_masks, self.test_FOVs, largest_connected=self.lc)
        evaluator = Evaluate(save_path=self.path_experiment, lc=self.lc)
        evaluator.add_batch(y_true, y_scores, self.binary_images)
        log = evaluator.save_all_result(plot_curve=True, save_name="performance.txt")
        # Persist labels and scores so ROC/PR curves can be re-plotted when
        # doing k-fold cross-validation.
        np.save('{}/result.npy'.format(self.path_experiment), np.asarray([y_true, y_scores]))
        return dict_round(log, 6)

    def save_segmentation_result(self):
        """Save per-image composite visualizations (input | prediction | GT)
        under ``<experiment>/result_img``."""
        img_path_list, _, _ = load_file_path_txt(self.args.test_data_path_list)
        img_name_list = [item.split('/')[-1].split('.')[0] for item in img_path_list]
        kill_border(self.pred_imgs, self.test_FOVs)  # only for visualization
        self.save_img_path = join(self.path_experiment, 'result_img')
        if not os.path.exists(join(self.save_img_path)):
            os.makedirs(self.save_img_path)
        # self.test_imgs = my_PreProc(self.test_imgs)  # Uncomment to save the pre-processed image
        for i in range(self.test_imgs.shape[0]):
            total_img = concat_result(self.test_imgs[i], self.pred_imgs[i], self.test_masks[i], self.lc)
            save_img(total_img, join(self.save_img_path, "Result_" + img_name_list[i] + '.png'))

    def val(self):
        """Lightweight evaluation used as validation at each training epoch;
        returns an ordered dict of the headline metrics (no files written)."""
        self.pred_imgs = recompone_overlap(
            self.pred_patches, self.new_height, self.new_width, self.args.stride_height, self.args.stride_width)
        self.pred_imgs = self.pred_imgs[:, :, 0:self.img_height, 0:self.img_width]
        # NOTE(review): unlike evaluate(), largest_connected is left at its
        # default here — confirm this asymmetry is intentional.
        y_scores, y_true, self.binary_images = pred_only_in_FOV(self.pred_imgs, self.test_masks, self.test_FOVs)
        evaluator = Evaluate(save_path=self.path_experiment, lc=self.lc)
        evaluator.add_batch(y_true, y_scores, self.binary_images)
        confusion, accuracy, specificity, sensitivity, precision = evaluator.confusion_matrix()
        log = OrderedDict([('F1 Score', evaluator.f1_score()),
                           ('Sensitivity', sensitivity),
                           ('Specifity', specificity),
                           ('Accuracy', accuracy),
                           ('AUC-ROC', evaluator.auc_roc()),
                           ('Mean-IOU', evaluator.jaccard_index())])
        return dict_round(log, 6)

    def predict_wo_gt(self, net):
        """Run inference and save result images in one call (intended for data
        without separate evaluation); no metrics dict is returned."""
        # Reuse the standard patch-inference loop instead of duplicating it.
        self.inference(net)
        self.pred_imgs = recompone_overlap(
            self.pred_patches, self.new_height, self.new_width, self.args.stride_height, self.args.stride_width)
        self.pred_imgs = self.pred_imgs[:, :, 0:self.img_height, 0:self.img_width]
        # Kept for its side effect of computing self.binary_images (used by
        # concat_result via self.lc); the scores themselves are not returned.
        y_scores, y_true, self.binary_images = pred_only_in_FOV(
            self.pred_imgs, self.test_masks, self.test_FOVs, largest_connected=self.lc)
        img_path_list, _, _ = load_file_path_txt(self.args.test_data_path_list)
        img_name_list = [item.split('/')[-1].split('.')[0] for item in img_path_list]
        kill_border(self.pred_imgs, self.test_FOVs)  # only for visualization
        self.save_img_path = join(self.path_experiment, 'result_img')
        if not os.path.exists(join(self.save_img_path)):
            os.makedirs(self.save_img_path)
        for i in range(self.test_imgs.shape[0]):
            total_img = concat_result(self.test_imgs[i], self.pred_imgs[i], self.test_masks[i], self.lc)
            # Extra backslash split handles Windows-style paths in the txt list
            save_img(total_img, join(self.save_img_path, "Result_" + img_name_list[i].split('\\')[-1] + '.png'))
if __name__ == '__main__':
    # Basic setup: parse CLI args and tee stdout into the experiment's log file
    args = parse_args()
    save_path = join(args.outf, args.save)
    sys.stdout = Print_Logger(os.path.join(save_path, 'test_log.txt'))
    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    cudnn.benchmark = True

    # Model under test; swap in other architectures from `models` as needed
    # (e.g. models.UNetFamily.U_Net(1, 2) or models.LadderNet(...)).
    net = models.Dense_Unet(in_chan=1).to(device)

    # Load the best checkpoint saved by training.
    # map_location ensures a GPU-trained checkpoint also loads on CPU-only hosts.
    print('==> Loading checkpoint...')
    checkpoint = torch.load(join(save_path, 'best_model.pth'), map_location=device)
    net.load_state_dict(checkpoint['net'])

    tester = Test(args)
    # Use tester.predict_wo_gt(net) instead to only save result images.
    tester.inference(net)
    print(tester.evaluate())
    tester.save_segmentation_result()