9
9
10
10
import argparse
11
11
import copy
12
+ from functools import partial
12
13
import json
13
14
import logging
14
15
import os
@@ -44,11 +45,15 @@ def get_parser():
44
45
help = "Keep a constant dataset split for all configs and iterations" )
45
46
parser .add_argument ("-l" , "--all-logs" , dest = "all_logs" , action = 'store_true' ,
46
47
help = "Keep all log directories for each iteration." )
48
+ parser .add_argument ('-t' , '--thr-increment' , dest = "thr_increment" , required = False , type = float ,
49
+ help = "A threshold analysis is performed at the end of the training using the trained model and "
50
+ "the validation sub-dataset to find the optimal binarization threshold. The specified "
51
+ "value indicates the increment between 0 and 1 used during the analysis (e.g. 0.1)." )
47
52
48
53
return parser
49
54
50
55
51
- def train_worker (config ):
56
+ def train_worker (config , thr_incr ):
52
57
current = mp .current_process ()
53
58
# ID of process used to assign a GPU
54
59
ID = int (current .name [- 1 ]) - 1
@@ -59,7 +64,8 @@ def train_worker(config):
59
64
# Call ivado cmd_train
60
65
try :
61
66
# Save best validation score
62
- best_training_dice , best_training_loss , best_validation_dice , best_validation_loss = ivado .run_command (config )
67
+ best_training_dice , best_training_loss , best_validation_dice , best_validation_loss = \
68
+ ivado .run_command (config , thr_increment = thr_incr )
63
69
64
70
except :
65
71
logging .exception ('Got exception on main handler' )
@@ -74,13 +80,6 @@ def train_worker(config):
74
80
75
81
76
82
def test_worker (config ):
77
- current = mp .current_process ()
78
- # ID of process used to assign a GPU
79
- ID = int (current .name [- 1 ]) - 1
80
-
81
- # Use GPU i from the array specified in the config file
82
- config ["gpu" ] = config ["gpu" ][ID ]
83
-
84
83
# Call ivado cmd_eval
85
84
try :
86
85
# Save best test score
@@ -130,7 +129,8 @@ def make_category(base_item, keys, values, is_all_combin=False):
130
129
return items , names
131
130
132
131
133
- def automate_training (config , param , fixed_split , all_combin , n_iterations = 1 , run_test = False , all_logs = False ):
132
+ def automate_training (config , param , fixed_split , all_combin , n_iterations = 1 , run_test = False , all_logs = False ,
133
+ thr_increment = None ):
134
134
"""Automate multiple training processes on multiple GPUs.
135
135
136
136
Hyperparameter optimization of models is tedious and time-consuming. This function automatizes this optimization
@@ -157,6 +157,9 @@ def automate_training(config, param, fixed_split, all_combin, n_iterations=1, ru
157
157
Flag: --n-iteration, -n
158
158
run_test (bool): If True, the trained model is also run on the testing subdataset. flag: --run-test
159
159
all_logs (bool): If True, all the log directories are kept for every iteration. Flag: --all-logs, -l
160
+ thr_increment (float): A threshold analysis is performed at the end of the training using the trained model and
161
+ the validation sub-dataset to find the optimal binarization threshold. The specified value indicates the
162
+ increment between 0 and 1 used during the ROC analysis (e.g. 0.1). Flag: -t, --thr-increment
160
163
"""
161
164
# Load initial config
162
165
with open (config , "r" ) as fhandle :
@@ -240,12 +243,13 @@ def automate_training(config, param, fixed_split, all_combin, n_iterations=1, ru
240
243
"_n=" + str (i ).zfill (2 ))
241
244
else :
242
245
config ["log_directory" ] += "_n=" + str (i ).zfill (2 )
243
- validation_scores = pool .map (train_worker , config_list )
246
+ validation_scores = pool .map (partial ( train_worker , thr_incr = thr_increment ) , config_list )
244
247
val_df = pd .DataFrame (validation_scores , columns = [
245
248
'log_directory' , 'best_training_dice' , 'best_training_loss' , 'best_validation_dice' ,
246
249
'best_validation_loss' ])
247
250
248
251
if run_test :
252
+ new_config_list = []
249
253
for config in config_list :
250
254
# Delete path_pred
251
255
path_pred = os .path .join (config ['log_directory' ], 'pred_masks' )
@@ -255,7 +259,13 @@ def automate_training(config, param, fixed_split, all_combin, n_iterations=1, ru
255
259
except OSError as e :
256
260
print ("Error: %s - %s." % (e .filename , e .strerror ))
257
261
258
- test_results = pool .map (test_worker , config_list )
262
+ # Take the config file within the log_directory because binarize_prediction may have been updated
263
+ json_path = os .path .join (config ['log_directory' ], 'config_file.json' )
264
+ with open (json_path ) as f :
265
+ config = json .load (f )
266
+ new_config_list .append (config )
267
+
268
+ test_results = pool .map (test_worker , new_config_list )
259
269
260
270
df_lst = []
261
271
# Merge all eval df together to have a single excel file
@@ -318,9 +328,13 @@ def automate_training(config, param, fixed_split, all_combin, n_iterations=1, ru
318
328
def main ():
319
329
parser = get_parser ()
320
330
args = parser .parse_args ()
331
+
332
+ # Get thr increment if available
333
+ thr_increment = args .thr_increment if args .thr_increment else None
334
+
321
335
# Run automate training
322
336
automate_training (args .config , args .params , bool (args .fixed_split ), bool (args .all_combin ), int (args .n_iterations ),
323
- bool (args .run_test ), args .all_logs )
337
+ bool (args .run_test ), args .all_logs , thr_increment )
324
338
325
339
326
340
if __name__ == '__main__' :
0 commit comments