Skip to content

Commit 0b2f022

Browse files
authored
Merge pull request ivadomed#383 from ivadomed/cg/roc
Find optimal threshold with ROC analysis
2 parents bcad0ce + f269258 commit 0b2f022

18 files changed

+235
-42
lines changed

docs/source/configuration_file.rst

+3-3
Original file line numberDiff line numberDiff line change
@@ -303,9 +303,9 @@ UNet3D (Optional)
303303
Testing parameters
304304
------------------
305305

306-
- ``binarize_prediction``: Bool. Binarize output predictions using a
307-
threshold of 0.5. If ``false``, output predictions are float between
308-
0 and 1.
306+
- ``binarize_prediction``: Float. Threshold (between 0 and 1) used to binarize
307+
the predictions before computing the validation metrics. To use soft predictions
308+
(i.e. no binarisation, float between 0 and 1) for metric computation, indicate -1.
309309

310310
uncertainty
311311
^^^^^^^^^^^

ivadomed/config/config.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
"film_layers": [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
7373
},
7474
"testing_parameters": {
75-
"binarize_prediction": true,
75+
"binarize_prediction": -1,
7676
"uncertainty": {
7777
"epistemic": false,
7878
"aleatoric": false,

ivadomed/config/config_classification.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@
6767
"applied": true
6868
},
6969
"testing_parameters": {
70-
"binarize_prediction": true,
70+
"binarize_prediction": -1,
7171
"uncertainty": {
7272
"epistemic": false,
7373
"aleatoric": false,

ivadomed/config/config_sctTesting.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
"depth": 2
6565
},
6666
"testing_parameters": {
67-
"binarize_prediction": true,
67+
"binarize_prediction": -1,
6868
"uncertainty": {
6969
"epistemic": false,
7070
"aleatoric": false,

ivadomed/config/config_small.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
"film_layers": [0, 1, 0, 0, 0, 0, 0, 0]
7070
},
7171
"testing_parameters": {
72-
"binarize_prediction": true,
72+
"binarize_prediction": -1,
7373
"uncertainty": {
7474
"epistemic": false,
7575
"aleatoric": false,

ivadomed/config/config_spineGeHemis.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@
8282
"n_filters": 1
8383
},
8484
"testing_parameters": {
85-
"binarize_prediction": true,
85+
"binarize_prediction": -1,
8686
"uncertainty": {
8787
"epistemic": false,
8888
"aleatoric": false,

ivadomed/config/config_tumorSeg.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
"n_filters": 8
7676
},
7777
"testing_parameters": {
78-
"binarize_prediction": true,
78+
"binarize_prediction": -1,
7979
"uncertainty": {
8080
"epistemic": false,
8181
"aleatoric": false,

ivadomed/config/config_vertebral_labeling.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@
6868
"film_layers": [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]
6969
},
7070
"testing_parameters": {
71-
"binarize_prediction": false,
71+
"binarize_prediction": -1,
7272
"uncertainty": {
7373
"epistemic": false,
7474
"aleatoric": false,

ivadomed/evaluation.py

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def evaluate(bids_path, log_directory, path_preds, target_suffix, eval_params):
5050
# 3D evaluation
5151
nib_pred = nib.load(fname_pred)
5252
data_pred = nib_pred.get_fdata()
53+
5354
h, w, d = data_pred.shape[:3]
5455
n_classes = len(fname_gt)
5556
data_gt = np.zeros((h, w, d, n_classes))

ivadomed/main.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,20 @@ def get_parser():
3535
' The parameter indicates the number of 2D slices used to generate GIFs, one GIF '
3636
'per slice. A GIF shows predictions of a given slice from the validation '
3737
'sub-dataset. They are saved within the log directory.')
38+
optional_args.add_argument('-t', '--thr-increment', dest="thr_increment", required=False, type=float,
39+
help='A threshold analysis is performed at the end of the training using the trained '
40+
'model and the validation sub-dataset to find the optimal binarization threshold. '
41+
'The specified value indicates the increment between 0 and 1 used during the '
42+
'analysis (e.g. 0.1). Plot is saved under "log_directory/thr.png" and the '
43+
'optimal threshold in "log_directory/config_file.json as "binarize_prediction" '
44+
'parameter.')
3845
optional_args.add_argument('-h', '--help', action='help', default=argparse.SUPPRESS,
3946
help='Shows function documentation.')
4047

4148
return parser
4249

4350

44-
def run_command(context, n_gif=0):
51+
def run_command(context, n_gif=0, thr_increment=None):
4552
"""Run main command.
4653
4754
This function is central in the ivadomed project as training / testing / evaluation commands are run via this
@@ -53,7 +60,9 @@ def run_command(context, n_gif=0):
5360
n_gif (int): Generates a GIF during training if larger than zero, one frame per epoch for a given slice. The
5461
parameter indicates the number of 2D slices used to generate GIFs, one GIF per slice. A GIF shows
5562
predictions of a given slice from the validation sub-dataset. They are saved within the log directory.
56-
63+
thr_increment (float): A threshold analysis is performed at the end of the training using the trained model and
64+
the validation sub-dataset to find the optimal binarization threshold. The specified value indicates the
65+
increment between 0 and 1 used during the ROC analysis (e.g. 0.1).
5766
Returns:
5867
If "train" command: Returns floats: best loss score for both training and validation.
5968
If "test" command: Returns dict: of averaged metrics computed on the testing sub dataset.
@@ -172,7 +181,7 @@ def run_command(context, n_gif=0):
172181
print('Model directory already exists: {}'.format(path_model))
173182

174183
# RUN TRAINING
175-
best_training_dice, best_training_loss, best_validation_dice, best_validation_loss = imed_training.train(
184+
best_training_dice, best_training_loss, best_validation_dice, best_validation_loss, thr = imed_training.train(
176185
model_params=model_params,
177186
dataset_train=ds_train,
178187
dataset_val=ds_valid,
@@ -182,8 +191,13 @@ def run_command(context, n_gif=0):
182191
cuda_available=cuda_available,
183192
metric_fns=metric_fns,
184193
n_gif=n_gif,
194+
thr_increment=thr_increment,
185195
debugging=context["debugging"])
186196

197+
# Update threshold in config file
198+
if thr_increment:
199+
context["testing_parameters"]["binarize_prediction"] = thr
200+
187201
# Save config file within log_directory and log_directory/model_name
188202
with open(os.path.join(log_directory, "config_file.json"), 'w') as fp:
189203
json.dump(context, fp, indent=4)
@@ -265,7 +279,9 @@ def run_main():
265279
context = json.load(fhandle)
266280

267281
# Run command
268-
run_command(context=context, n_gif=args.gif)
282+
run_command(context=context,
283+
n_gif=args.gif if args.gif is not None else 0,
284+
thr_increment=args.thr_increment if args.thr_increment else None)
269285

270286

271287
if __name__ == "__main__":

ivadomed/metrics.py

+46-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from collections import defaultdict
22

3-
from scipy import spatial
3+
import matplotlib.pyplot as plt
44
import numpy as np
5+
from scipy import spatial
56

67

78
class MetricManager(object):
@@ -15,7 +16,7 @@ class MetricManager(object):
1516
result_dict (dict): Dictionary storing metrics.
1617
num_samples (int): Number of samples.
1718
"""
18-
19+
1920
def __init__(self, metric_fns):
2021
self.metric_fns = metric_fns
2122
self.num_samples = 0
@@ -275,3 +276,46 @@ def multi_class_dice_score(im1, im2):
275276
dice_per_class += dice_score(im1[i,], im2[i,], empty_score=1.0)
276277

277278
return dice_per_class / n_classes
279+
280+
281+
def plot_roc_curve(tpr, fpr, opt_thr_idx, fname_out):
282+
"""Plot ROC curve.
283+
284+
Args:
285+
tpr (list): True positive rates.
286+
fpr (list): False positive rates.
287+
opt_thr_idx (int): Index of the optimal threshold.
288+
fname_out (str): Output filename.
289+
"""
290+
plt.figure()
291+
lw = 2
292+
plt.plot(fpr, tpr, color='darkorange', lw=lw, marker='o')
293+
plt.plot([fpr[opt_thr_idx]], [tpr[opt_thr_idx]], color="darkgreen", marker="o", linestyle="None")
294+
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
295+
plt.xlim([0.0, 1.0])
296+
plt.ylim([0.0, 1.05])
297+
plt.xlabel('False Positive Rate')
298+
plt.ylabel('True Positive Rate')
299+
plt.title('ROC curve')
300+
plt.savefig(fname_out)
301+
302+
303+
def plot_dice_thr(thr_list, dice_list, opt_thr_idx, fname_out):
304+
"""Plot Dice results against thresholds.
305+
306+
Args:
307+
thr_list (list): Thresholds list.
308+
dice_list (list): Dice results.
309+
opt_thr_idx (int): Index of the optimal threshold.
310+
fname_out (str): Output filename.
311+
"""
312+
plt.figure()
313+
lw = 2
314+
plt.plot(thr_list, dice_list, color='darkorange', lw=lw, marker='o')
315+
plt.plot([thr_list[opt_thr_idx]], [dice_list[opt_thr_idx]], color="darkgreen", marker="o", linestyle="None")
316+
plt.xlim([0.0, 1.0])
317+
plt.ylim([min(dice_list) - 0.02, max(dice_list) + 0.02])
318+
plt.xlabel('Thresholds')
319+
plt.ylabel('Dice')
320+
plt.title('Threshold analysis')
321+
plt.savefig(fname_out)

ivadomed/models.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,7 @@ def forward(self, x):
798798
out = out[:, 1:, ]
799799
else:
800800
if self.relu_activation:
801-
out = nn.ReLU()(x) / nn.ReLU()(x).max() if bool(nn.ReLU()(x).max()) else nn.ReLU()(x)
801+
out = nn.ReLU()(seg_layer) / nn.ReLU()(seg_layer).max() if bool(nn.ReLU()(seg_layer).max()) else nn.ReLU()(seg_layer)
802802
else:
803803
out = torch.sigmoid(seg_layer)
804804
return out

ivadomed/scripts/automate_training.py

+27-13
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import argparse
1111
import copy
12+
from functools import partial
1213
import json
1314
import logging
1415
import os
@@ -44,11 +45,15 @@ def get_parser():
4445
help="Keep a constant dataset split for all configs and iterations")
4546
parser.add_argument("-l", "--all-logs", dest="all_logs", action='store_true',
4647
help="Keep all log directories for each iteration.")
48+
parser.add_argument('-t', '--thr-increment', dest="thr_increment", required=False, type=float,
49+
help="A threshold analysis is performed at the end of the training using the trained model and "
50+
"the validation sub-dataset to find the optimal binarization threshold. The specified "
51+
"value indicates the increment between 0 and 1 used during the analysis (e.g. 0.1).")
4752

4853
return parser
4954

5055

51-
def train_worker(config):
56+
def train_worker(config, thr_incr):
5257
current = mp.current_process()
5358
# ID of process used to assign a GPU
5459
ID = int(current.name[-1]) - 1
@@ -59,7 +64,8 @@ def train_worker(config):
5964
# Call ivado cmd_train
6065
try:
6166
# Save best validation score
62-
best_training_dice, best_training_loss, best_validation_dice, best_validation_loss = ivado.run_command(config)
67+
best_training_dice, best_training_loss, best_validation_dice, best_validation_loss = \
68+
ivado.run_command(config, thr_increment=thr_incr)
6369

6470
except:
6571
logging.exception('Got exception on main handler')
@@ -74,13 +80,6 @@ def train_worker(config):
7480

7581

7682
def test_worker(config):
77-
current = mp.current_process()
78-
# ID of process used to assign a GPU
79-
ID = int(current.name[-1]) - 1
80-
81-
# Use GPU i from the array specified in the config file
82-
config["gpu"] = config["gpu"][ID]
83-
8483
# Call ivado cmd_eval
8584
try:
8685
# Save best test score
@@ -130,7 +129,8 @@ def make_category(base_item, keys, values, is_all_combin=False):
130129
return items, names
131130

132131

133-
def automate_training(config, param, fixed_split, all_combin, n_iterations=1, run_test=False, all_logs=False):
132+
def automate_training(config, param, fixed_split, all_combin, n_iterations=1, run_test=False, all_logs=False,
133+
thr_increment=None):
134134
"""Automate multiple training processes on multiple GPUs.
135135
136136
Hyperparameter optimization of models is tedious and time-consuming. This function automatizes this optimization
@@ -157,6 +157,9 @@ def automate_training(config, param, fixed_split, all_combin, n_iterations=1, ru
157157
Flag: --n-iteration, -n
158158
run_test (bool): If True, the trained model is also run on the testing subdataset. flag: --run-test
159159
all_logs (bool): If True, all the log directories are kept for every iteration. Flag: --all-logs, -l
160+
thr_increment (float): A threshold analysis is performed at the end of the training using the trained model and
161+
the validation sub-dataset to find the optimal binarization threshold. The specified value indicates the
162+
increment between 0 and 1 used during the ROC analysis (e.g. 0.1). Flag: -t, --thr-increment
160163
"""
161164
# Load initial config
162165
with open(config, "r") as fhandle:
@@ -240,12 +243,13 @@ def automate_training(config, param, fixed_split, all_combin, n_iterations=1, ru
240243
"_n=" + str(i).zfill(2))
241244
else:
242245
config["log_directory"] += "_n=" + str(i).zfill(2)
243-
validation_scores = pool.map(train_worker, config_list)
246+
validation_scores = pool.map(partial(train_worker, thr_incr=thr_increment), config_list)
244247
val_df = pd.DataFrame(validation_scores, columns=[
245248
'log_directory', 'best_training_dice', 'best_training_loss', 'best_validation_dice',
246249
'best_validation_loss'])
247250

248251
if run_test:
252+
new_config_list = []
249253
for config in config_list:
250254
# Delete path_pred
251255
path_pred = os.path.join(config['log_directory'], 'pred_masks')
@@ -255,7 +259,13 @@ def automate_training(config, param, fixed_split, all_combin, n_iterations=1, ru
255259
except OSError as e:
256260
print("Error: %s - %s." % (e.filename, e.strerror))
257261

258-
test_results = pool.map(test_worker, config_list)
262+
# Take the config file within the log_directory because binarize_prediction may have been updated
263+
json_path = os.path.join(config['log_directory'], 'config_file.json')
264+
with open(json_path) as f:
265+
config = json.load(f)
266+
new_config_list.append(config)
267+
268+
test_results = pool.map(test_worker, new_config_list)
259269

260270
df_lst = []
261271
# Merge all eval df together to have a single excel file
@@ -318,9 +328,13 @@ def automate_training(config, param, fixed_split, all_combin, n_iterations=1, ru
318328
def main():
319329
parser = get_parser()
320330
args = parser.parse_args()
331+
332+
# Get thr increment if available
333+
thr_increment = args.thr_increment if args.thr_increment else None
334+
321335
# Run automate training
322336
automate_training(args.config, args.params, bool(args.fixed_split), bool(args.all_combin), int(args.n_iterations),
323-
bool(args.run_test), args.all_logs)
337+
bool(args.run_test), args.all_logs, thr_increment)
324338

325339

326340
if __name__ == '__main__':

ivadomed/testing.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -175,15 +175,17 @@ def run_inference(test_loader, model, model_params, testing_params, ofolder, cud
175175
fname_out=fname_pred,
176176
slice_axis=slice_axis,
177177
kernel_dim='2d',
178-
bin_thr=0.9 if testing_params["binarize_prediction"] else -1)
178+
bin_thr=testing_params["binarize_prediction"])
179179
# TODO: Adapt to multilabel
180-
preds_npy_list.append(output_nii.get_fdata()[:, :, :, 0])
180+
output_data = output_nii.get_fdata()[:, :, :, 0]
181+
preds_npy_list.append(output_data)
182+
181183
gt_npy_list.append(nib.load(fname_tmp).get_fdata())
182184

183185
output_nii_shape = output_nii.get_fdata().shape
184186
if len(output_nii_shape) == 4 and output_nii_shape[-1] > 1:
185187
imed_utils.save_color_labels(np.stack(pred_tmp_lst, -1),
186-
testing_params["binarize_prediction"],
188+
testing_params["binarize_prediction"] > 0,
187189
fname_tmp,
188190
fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
189191
imed_utils.AXIS_DCT[testing_params['slice_axis']])
@@ -221,8 +223,10 @@ def run_inference(test_loader, model, model_params, testing_params, ofolder, cud
221223
fname_out=fname_pred,
222224
slice_axis=slice_axis,
223225
kernel_dim='3d',
224-
bin_thr=0.5 if testing_params["binarize_prediction"] else -1)
225-
preds_npy_list.append(output_nii.get_fdata().transpose(3, 0, 1, 2))
226+
bin_thr=testing_params["binarize_prediction"])
227+
output_data = output_nii.get_fdata().transpose(3, 0, 1, 2)
228+
preds_npy_list.append(output_data)
229+
226230
gt_lst = []
227231
for gt in metadata[0]['gt_filenames']:
228232
# For multi-label, if all labels are not in every image
@@ -236,7 +240,7 @@ def run_inference(test_loader, model, model_params, testing_params, ofolder, cud
236240

237241
if pred_undo.shape[0] > 1:
238242
imed_utils.save_color_labels(pred_undo,
239-
testing_params['binarize_prediction'],
243+
testing_params['binarize_prediction'] > 0,
240244
batch['input_metadata'][smp_idx][0]['input_filenames'],
241245
fname_pred.split(".nii.gz")[0] + '_color.nii.gz',
242246
slice_axis)

0 commit comments

Comments
 (0)