metrics_calculate.py.old (forked from open-mmlab/mmsegmentation)
from mmseg.core import eval_metrics, eval_attach_metrics
import mmcv
import os
import numpy as np
from mmcv.utils import print_log
from terminaltables import AsciiTable
from skimage import io
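
# Note: eval_attach_metrics is a helper specific to this fork (it is not in
# upstream mmsegmentation). The code below assumes it returns one scalar per
# name in `att_metrics`, on the same 0-100 percentage scale as the rounded
# mIoU/mAcc values, so the final division by 100.0 applies uniformly.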
def evaluate(results,
             gt_seg_maps,
             num_classes=2,
             metric='mIoU',
             logger=None,
             ignore_index=255,
             label_map=None,
             reduce_zero_label=False,
             class_names=['other', 'water'],
             att_metrics=['PRE', 'REC', 'F-measure', 'F-max', 'FPR', 'FNR'],
             efficient_test=False,
             **kwargs):
"""Evaluate the dataset.
Args:
results (list): Testing results of the dataset.
metric (str | list[str]): Metrics to be evaluated. 'mIoU' and
'mDice' are supported.
logger (logging.Logger | None | str): Logger used for printing
related information during evaluation. Default: None.
Returns:
dict[str, float]: Default metrics.
"""
    if isinstance(metric, str):
        metric = [metric]
    # Only the per-class metrics go through eval_metrics; the attached
    # metrics (PRE, REC, F-measure, ...) are handled separately below.
    allowed_metrics = ['mIoU', 'mDice']
    if not set(metric).issubset(set(allowed_metrics)):
        raise KeyError('metric {} is not supported'.format(metric))
    eval_results = {}
    ret_metrics = eval_metrics(
        results,
        gt_seg_maps,
        num_classes,
        ignore_index,
        metric,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
    # eval_metrics returns [aAcc, per-class Acc, per-class values for each
    # requested metric]; strip the leading 'm' (e.g. 'mIoU' -> 'IoU') for
    # the per-class column headers.
    class_table_data = [['Class'] + [m[1:] for m in metric] + ['Acc']]
    ret_metrics_round = [
        np.round(ret_metric * 100, 2) for ret_metric in ret_metrics
    ]
    for i in range(num_classes):
        class_table_data.append([class_names[i]] +
                                [m[i] for m in ret_metrics_round[2:]] +
                                [ret_metrics_round[1][i]])
    ret_metrics_mean = [
        np.round(np.nanmean(ret_metric) * 100, 2)
        for ret_metric in ret_metrics
    ]
    if att_metrics is not None:
        attach_metrics = eval_attach_metrics(
            results,
            gt_seg_maps,
            num_classes,
            ignore_index,
            att_metrics,
            label_map=label_map,
            reduce_zero_label=reduce_zero_label)
        summary_table_data = [
            ['Scope'] + ['m' + head for head in class_table_data[0][1:]] +
            list(att_metrics) + ['aAcc']
        ]
        summary_table_data.append(
            ['global'] + ret_metrics_mean[2:] + [ret_metrics_mean[1]] +
            list(attach_metrics) + [ret_metrics_mean[0]])
    else:
        summary_table_data = [
            ['Scope'] + ['m' + head for head in class_table_data[0][1:]] +
            ['aAcc']
        ]
        summary_table_data.append(['global'] + ret_metrics_mean[2:] +
                                  [ret_metrics_mean[1]] +
                                  [ret_metrics_mean[0]])
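    # With the defaults above, the summary header works out to:
    #   ['Scope', 'mIoU', 'mAcc', 'PRE', 'REC', 'F-measure', 'F-max',
    #    'FPR', 'FNR', 'aAcc']
    # and the 'global' row carries the matching values in the same order.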
    print_log('per class results:', logger)
    table = AsciiTable(class_table_data)
    print_log('\n' + table.table, logger=logger)
    print_log('Summary:', logger)
    table = AsciiTable(summary_table_data)
    print_log('\n' + table.table, logger=logger)
    # Convert the percentage values in the summary row back to fractions
    # for the returned dict.
    for i in range(1, len(summary_table_data[0])):
        eval_results[summary_table_data[0][i]] = summary_table_data[1][i] / 100.0
    # In efficient_test mode the results are filenames of temporary dumps;
    # remove them once evaluation is done.
    if mmcv.is_list_of(results, str):
        for file_name in results:
            os.remove(file_name)
    return eval_results
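

# A minimal smoke test for evaluate(), as a sketch only: the shapes, dtypes
# and random masks below are illustrative assumptions, not part of the
# original pipeline. att_metrics=None skips the fork-only attached metrics.
#
#     rng = np.random.default_rng(0)
#     preds = [rng.integers(0, 2, (64, 64), dtype=np.uint8) for _ in range(4)]
#     gts = [rng.integers(0, 2, (64, 64), dtype=np.uint8) for _ in range(4)]
#     evaluate(preds, gts, att_metrics=None)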
# ---- Driver: compare predicted masks against ground-truth labels. ----
seg_path = '/home/home2/zrd/data/val/validation-paper1/label_deal'
#test_path = '/home/home2/zrd/data/val/validation-paper1/cpnetplus_r50-d8-combine-box'
#test_path = '/home/home2/zrd/data/val/validation-paper1/cpnet_r50-d8-combine-zrdy/'
#test_path = '/home/home2/zrd/data/val/validation-paper1/cpnetplus_inter_r50-d8-combine-zrdy/'
#test_path = '/home/home2/zrd/data/val/validation-paper1/cpnetplus_intra_r50-d8-combine-zrdy/'
#test_path = '/home/home2/zrd/data/val/validation-paper1/cpnetplus_m-v2-d8_512x512_160k_combine-zrdy/'
#test_path = '/home/home2/zrd/data/val/validation-paper1/cpnetplus_r50-d8-combine-zrdy/'
#test_path = '/home/home2/zrd/data/val/validation-paper1/cpnetplus_rs101-d8-combine-zrdy/'
#test_path = '/home/home2/zrd/data/val/validation-paper1/configs/deeplabv3plus/deeplabv3plus_r50-d8_512x512_80k_combine-zrdy/'
#test_path = '/home/home2/zrd/data/val/validation-paper1/pspnet_r50-d8_combine-zrdy/'
#test_path = '/home/home2/zrd/data/val/validation-paper1/wl/'
#test_path = '/home/home2/zrd/data/val/validation-paper1/wasr/'
#test_path = '/home/home2/zrd/data/val/validation-paper1/segnet/'
#test_path = '/home/home2/zrd/data/val/validation-paper1/wl-largeBN-mobilenetv2/'
test_path = '/home/home2/zrd/data/val/validation-paper1/wl-largeBN/'
seg_list = []
test_list = []
for file_path in os.listdir(seg_path):
    if os.path.isfile(os.path.join(seg_path, file_path)):
        seg_file = os.path.join(seg_path, file_path)
        test_file = os.path.join(test_path, file_path)
        seg = io.imread(seg_file)
        test = io.imread(test_file)
        # Only score pairs whose shapes match; mismatched pairs are
        # silently skipped.
        if seg.shape == test.shape:
            test_list.append(test)
            seg_list.append(seg)
evaluate(test_list, seg_list)
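
# If the prediction folder can be missing files, a guarded variant of the
# pairing loop (a sketch, not part of the original script) would skip and
# report them instead of letting io.imread raise:
#
#     test_file = os.path.join(test_path, file_path)
#     if not os.path.isfile(test_file):
#         print_log('missing prediction for {}'.format(file_path))
#         continue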