# instantiators.py
# Forked from sbelharbi/wsol-min-max-entropy-interpretability.

# `Sequence` lives in collections.abc on Python 3; importing it from `collections` was removed in Python 3.10.
from collections.abc import Sequence
import warnings
from torch.optim import SGD
from torch.optim import Adam
import torch.optim.lr_scheduler as lr_scheduler
from deepmil import models, criteria
from deepmil import lr_scheduler as my_lr_scheduler
from tools import Dict2Obj, count_nb_params
import loader
import prepocess_offline
import stain_tools.stain_augmentor as stain_augmentors


def instantiate_train_loss(args):
    """
    Instantiate the train loss.
    :param args: object. Contains the configuration of the exp that has been read from the yaml file.
    :return: train_loss: instance of deepmil.criteria.TotalLoss()
    """
    return criteria.TotalLoss()


def instantiate_eval_loss(args):
    """
    Instantiate the evaluation (test phase) loss.
    :param args: object. Contains the configuration of the exp that has been read from the yaml file.
    :return: eval_loss: instance of deepmil.criteria.TotalLossEval()
    """
    return criteria.TotalLossEval()


def instantiate_models(args):
    """Instantiate the necessary models.
    Input:
        args: object. Contains the configuration of the exp that has been read from the yaml file.
    Output:
        model: instance of a min-max entropy model from deepmil.models; it embeds the instances and pools
               the per-class scores.
    """
    p = Dict2Obj(args.model)
    model = models.__dict__[p.name](pretrained=p.pretrained, num_masks=p.num_masks,
                                    sigma=p.sigma, w=p.w, num_classes=p.num_classes, scale=p.scale,
                                    modalities=p.modalities, kmax=p.kmax, kmin=p.kmin, alpha=p.alpha,
                                    dropout=p.dropout, nbr_times_erase=args.nbr_times_erase,
                                    sigma_erase=args.sigma_erase)
    print("Min-max entropy model `{}` was successfully instantiated. Nbr. params: {} .... [OK]".format(
        model.__class__.__name__, count_nb_params(model)))
    return model
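
# For reference, instantiate_models() reads `args.model` as a dict whose keys mirror the keyword
# arguments above. A minimal, hypothetical example (the class name and the values are illustrative
# assumptions, not the repo's defaults):
#
#   args.model = {
#       "name": "resnet18",       # must match a class exposed in deepmil.models
#       "pretrained": True, "num_masks": 1, "sigma": 0.15, "w": 8.0,
#       "num_classes": 2, "scale": 1.0, "modalities": 4,
#       "kmax": 0.1, "kmin": 0.1, "alpha": 0.6, "dropout": 0.0,
#   }
#   # `nbr_times_erase` and `sigma_erase` are read from the top level of the configuration
#   # (args.nbr_times_erase / args.sigma_erase), not from `args.model`.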


def instantiate_optimizer(args, model):
    """Instantiate an optimizer.
    Input:
        args: object. Contains the configuration of the exp that has been read from the yaml file.
        model: a pytorch model with parameters.
    Output:
        optimizer: a pytorch optimizer.
        lrate_scheduler: a pytorch learning rate scheduler (or None).
    """
    if args.optimizer["name"] == "sgd":
        optimizer = SGD(model.parameters(), lr=args.optimizer["lr"], momentum=args.optimizer["momentum"],
                        dampening=args.optimizer["dampening"], weight_decay=args.optimizer["weight_decay"],
                        nesterov=args.optimizer["nesterov"])
    elif args.optimizer["name"] == "adam":
        optimizer = Adam(params=model.parameters(), lr=args.optimizer["lr"], betas=args.optimizer["betas"],
                         eps=args.optimizer["eps"], weight_decay=args.optimizer["weight_decay"],
                         amsgrad=args.optimizer["amsgrad"])
    else:
        raise ValueError("Unsupported optimizer `{}` .... [NOT OK]".format(args.optimizer["name"]))

    print("Optimizer `{}` was successfully instantiated .... [OK]".format(
        [key + ":" + str(args.optimizer[key]) for key in args.optimizer.keys()]))

    if args.optimizer["lr_scheduler"]:
        if args.optimizer["lr_scheduler"]["name"] == "step":
            lr_scheduler_ = args.optimizer["lr_scheduler"]
            lrate_scheduler = lr_scheduler.StepLR(optimizer,
                                                  step_size=lr_scheduler_["step_size"],
                                                  gamma=lr_scheduler_["gamma"],
                                                  last_epoch=lr_scheduler_["last_epoch"])
            print("Learning rate scheduler `{}` was successfully instantiated .... [OK]".format(
                [key + ":" + str(lr_scheduler_[key]) for key in lr_scheduler_.keys()]))
        elif args.optimizer["lr_scheduler"]["name"] == "mystep":
            lr_scheduler_ = args.optimizer["lr_scheduler"]
            lrate_scheduler = my_lr_scheduler.MyStepLR(optimizer,
                                                       step_size=lr_scheduler_["step_size"],
                                                       gamma=lr_scheduler_["gamma"],
                                                       last_epoch=lr_scheduler_["last_epoch"],
                                                       min_lr=lr_scheduler_["min_lr"])
            print("Learning rate scheduler `{}` was successfully instantiated .... [OK]".format(
                [key + ":" + str(lr_scheduler_[key]) for key in lr_scheduler_.keys()]))
        elif args.optimizer["lr_scheduler"]["name"] == "multistep":
            lr_scheduler_ = args.optimizer["lr_scheduler"]
            lrate_scheduler = lr_scheduler.MultiStepLR(optimizer,
                                                       milestones=lr_scheduler_["milestones"],
                                                       gamma=lr_scheduler_["gamma"],
                                                       last_epoch=lr_scheduler_["last_epoch"])
            print("Learning rate scheduler `{}` was successfully instantiated .... [OK]".format(
                [key + ":" + str(lr_scheduler_[key]) for key in lr_scheduler_.keys()]))
        else:
            raise ValueError("Unsupported learning rate scheduler `{}` .... [NOT OK]".format(
                args.optimizer["lr_scheduler"]["name"]))
    else:
        lrate_scheduler = None

    return optimizer, lrate_scheduler
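
# For reference, a hypothetical `optimizer` section consumed by instantiate_optimizer(); the keys
# mirror the dictionary accesses above, the values are illustrative only:
#
#   args.optimizer = {
#       "name": "sgd", "lr": 0.001, "momentum": 0.9, "dampening": 0.0,
#       "weight_decay": 1e-4, "nesterov": True,
#       "lr_scheduler": {
#           "name": "mystep",     # one of "step", "mystep", "multistep"
#           "step_size": 40, "gamma": 0.1, "last_epoch": -1, "min_lr": 1e-7,
#       },
#   }
#   # With "name": "adam", the keys "betas", "eps" and "amsgrad" are expected instead of
#   # "momentum"/"dampening"/"nesterov"; "multistep" expects "milestones" instead of "step_size".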


def instantiate_preprocessor(args):
    """
    Instantiate a preprocessor class from the prepocess_offline module.
    :param args: object. Contains the configuration of the exp that has been read from the yaml file.
    :return: an instance of a preprocessor, or None.
    """
    if args.preprocessor:
        if args.preprocessor["name"] == "Preprocessor":
            if "stain" in args.preprocessor.keys():
                stain = Dict2Obj(args.preprocessor["stain"])
                name_classes = args.name_classes
                preprocessor = prepocess_offline.__dict__["Preprocessor"](stain, name_classes)
                print("Preprocessor `{}` was successfully instantiated with the stain preprocessing ON "
                      ".... [OK]".format(args.preprocessor["name"]))
                return preprocessor
            else:
                raise ValueError("Unknown preprocessing operation .... [NOT OK]")
        else:
            raise ValueError("Unsupported preprocessor `{}` .... [NOT OK]".format(args.preprocessor["name"]))
    else:
        print("Proceeding WITHOUT preprocessor .... [OK]")
        return None
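
# For reference, a hypothetical `preprocessor` section; only the keys read in this module are shown,
# and the contents of "stain" (e.g. the normalization method) are an assumption that depends on
# prepocess_offline.Preprocessor:
#
#   args.preprocessor = {
#       "name": "Preprocessor",
#       "stain": {"method": "macenko"},   # "method" is reused by instantiate_stain_augmentor() below
#   }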


def instantiate_patch_splitter(args, deterministic=True):
    """
    Instantiate the patch splitter and its relevant instances.
    It is used for every set. However, for the train set, deterministic is set to False to allow dropping
    some patches IF requested. Over the valid and test sets, deterministic is True.
    :param args: object. Contains the configuration of the exp that has been read from the yaml file.
    :param deterministic: True/False. If False, dropping some patches is allowed IF it was requested. Should be
           set to False only for the train set.
    :return: an instance of a patch splitter.
    """
    assert args.patch_splitter is not None, "We need a patch splitter, and you didn't specify one! .... [NOT OK]"
    patch_splitter_conf = Dict2Obj(args.patch_splitter)
    random_cropper = Dict2Obj(args.random_cropper)
    if patch_splitter_conf.name == "PatchSplitter":
        keep = 1.  # default value for the deterministic scenario: keep all patches (evaluation phase).
        if not deterministic:
            keep = patch_splitter_conf.keep
        h = patch_splitter_conf.h
        w = patch_splitter_conf.w
        h_ = patch_splitter_conf.h_
        w_ = patch_splitter_conf.w_
        # Instantiate the patch transform, if any.
        patch_transform = None
        if patch_splitter_conf.patch_transform:
            error_msg = "We support only one patch transform (or none) for now ... [NOT OK]"
            assert not isinstance(patch_splitter_conf.patch_transform, Sequence), error_msg
            patch_transform_config = Dict2Obj(patch_splitter_conf.patch_transform)
            if patch_transform_config.name == "PseudoFoveation":
                scale_factor = patch_transform_config.scale_factor
                int_eps = patch_transform_config.int_eps
                num_workers = patch_transform_config.num_workers
                patch_transform = loader.__dict__["PseudoFoveation"](h, w, h_, w_, scale_factor, int_eps, num_workers)
                print("Patch transform `{}` was successfully instantiated WITHIN a patch splitter `{}` "
                      "with `{}` workers .... [OK]".format(
                          patch_transform_config.name, patch_splitter_conf.name, num_workers))
            elif patch_transform_config.name == "FastApproximationPseudoFoveation":
                scale_factor = patch_transform_config.scale_factor
                int_eps = patch_transform_config.int_eps
                nbr_kernels = patch_transform_config.nbr_kernels
                use_gpu = patch_transform_config.use_gpu
                gpu_id = patch_transform_config.gpu_id
                if gpu_id is None:
                    gpu_id = int(args.cudaid)
                    warnings.warn("You didn't specify the CUDA device ID to run `FastApproximationPseudoFoveation`. "
                                  "We set it to the same device where the model will run: `cuda:{}` "
                                  ".... [NOT OK]".format(args.cudaid))
                assert args.num_workers in [0, 1], "'config.num_workers' must be in {0, 1} if " \
                                                   "loader.FastApproximationPseudoFoveation() is used. " \
                                                   "Multiprocessing does not play well when the DataLoader " \
                                                   "also uses multiprocessing .... [NOT OK]"
                patch_transform = loader.__dict__["FastApproximationPseudoFoveation"](
                    h, w, h_, w_, scale_factor, int_eps, nbr_kernels, use_gpu, gpu_id
                )
                print("Patch transform `{}` was successfully instantiated WITHIN a patch splitter `{}` "
                      "with `{}` kernels, use_gpu=`{}` and CUDA ID `{}` .... [OK]".format(
                          patch_transform_config.name, patch_splitter_conf.name, nbr_kernels, use_gpu, gpu_id))
            else:
                raise ValueError("Unsupported patch transform `{}` .... [NOT OK]".format(patch_transform_config.name))
        else:
            print("Proceeding WITHOUT any patch transform ..... [OK]")
        if patch_transform:
            patch_transform = [patch_transform]
        padding_mode = patch_splitter_conf.padding_mode
        assert hasattr(random_cropper, "make_cropped_perfect_for_split"), \
            "The random cropper `{}` does not have the attribute `make_cropped_perfect_for_split` " \
            "which we expect .... [NOT OK]".format(random_cropper.name)
        if random_cropper.make_cropped_perfect_for_split and not deterministic:
            padding_mode = None
        patch_splitter = loader.__dict__["PatchSplitter"](
            h, w, h_, w_, padding_mode, patch_transforms=patch_transform, keep=keep
        )
        print("Patch splitter `{}` was successfully instantiated .... [OK]".format(patch_splitter_conf.name))
    else:
        raise ValueError("Unsupported patch splitter `{}` .... [NOT OK]".format(patch_splitter_conf.name))
    return patch_splitter
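
# For reference, a hypothetical `patch_splitter` section matching the attribute accesses above
# (all values are illustrative; "patch_transform" may be omitted or set to None):
#
#   args.patch_splitter = {
#       "name": "PatchSplitter",
#       "h": 512, "w": 512,       # first pair of patch dimensions
#       "h_": 480, "w_": 480,     # second pair of patch dimensions (see loader.PatchSplitter)
#       "keep": 0.5,              # fraction of patches kept when deterministic=False
#       "padding_mode": "reflect",
#       "patch_transform": {
#           "name": "PseudoFoveation",
#           "scale_factor": 8, "int_eps": 1e-6, "num_workers": 4,
#       },
#   }
#   # "FastApproximationPseudoFoveation" expects "nbr_kernels", "use_gpu" and "gpu_id"
#   # instead of "num_workers".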


def instantiate_stain_augmentor(args):
    """
    Instantiate the stain augmentor.
    The possible classes are located in stain_tools.stain_augmentor.
    :param args: object. Contains the configuration of the exp that has been read from the yaml file.
    :return: an instance of a stain augmentor, or None.
    """
    if args.stain_augmentor:
        error_msg = "You requested stain augmentation, but there was no stain normalization. It seems " \
                    "inconsistent. Modify the code in order to accept stain augmentation without stain " \
                    "normalization. Stain extraction is time consuming. To augment the stains, we use the same " \
                    "reference stain as in the stain normalization phase. If you want stain augmentation anyway, " \
                    "you need to provide a stain matrix, because stain extraction takes about 15 to 25 seconds " \
                    "per high-resolution H&E image of size h x w: ~1500x2000."
        assert "stain" in args.preprocessor.keys(), error_msg
        method = args.preprocessor["stain"]["method"]
        s_augmentor_config = Dict2Obj(args.stain_augmentor)
        if s_augmentor_config.name == "StainAugmentor":
            sigma1 = s_augmentor_config.sigma1
            sigma2 = s_augmentor_config.sigma2
            augment_background = s_augmentor_config.augment_background
            stain_augmentor = stain_augmentors.__dict__["StainAugmentor"](method, sigma1, sigma2, augment_background)
            print("Stain augmentor `{}` was successfully instantiated .... [OK]".format(s_augmentor_config.name))
            return stain_augmentor
        else:
            raise ValueError("Unsupported stain augmentor name `{}` .... [NOT OK]".format(s_augmentor_config.name))
    else:
        print("Proceeding WITHOUT stain augmentation .... [OK]")
        return None
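
# For reference, a hypothetical `stain_augmentor` section (the sigma values are illustrative; the
# normalization method itself is taken from args.preprocessor["stain"]["method"], as done above):
#
#   args.stain_augmentor = {
#       "name": "StainAugmentor",
#       "sigma1": 0.2, "sigma2": 0.2,
#       "augment_background": False,
#   }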


def instantiante_random_cropper(args):
    """
    Instantiate a random cropper. It is used for sampling sub-images from an original image in the train set.
    Classes are located in loader.*
    :param args: object. Contains the configuration of the exp that has been read from the yaml file.
    :return: an instance of a random cropper, or None.
    """
    if args.random_cropper:
        r_cropper_config = Dict2Obj(args.random_cropper)
        patch_splitter_config = Dict2Obj(args.patch_splitter)
        if r_cropper_config.name == "RandomCropper":
            min_height = r_cropper_config.min_height
            min_width = r_cropper_config.min_width
            max_height = r_cropper_config.max_height
            max_width = r_cropper_config.max_width
            make_cropped_perfect_for_split = r_cropper_config.make_cropped_perfect_for_split
            h, w, h_, w_ = None, None, None, None
            if make_cropped_perfect_for_split:
                assert patch_splitter_config.name == "PatchSplitter", \
                    "We expected the class `PatchSplitter` but found `{}` .... [NOT OK]".format(
                        patch_splitter_config.name)
                h = patch_splitter_config.h
                w = patch_splitter_config.w
                h_ = patch_splitter_config.h_
                w_ = patch_splitter_config.w_
            random_cropper = loader.__dict__["RandomCropper"](
                min_height, min_width, max_height, max_width, make_cropped_perfect_for_split, h, w, h_, w_)
            print("Random cropper `{}` was successfully instantiated .... [OK]".format(r_cropper_config.name))
            return random_cropper
        else:
            raise ValueError("Unsupported random cropper `{}` .... [NOT OK]".format(r_cropper_config.name))
    else:
        return None
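
# For reference, a hypothetical `random_cropper` section (values are illustrative):
#
#   args.random_cropper = {
#       "name": "RandomCropper",
#       "min_height": 600, "min_width": 600,
#       "max_height": 1024, "max_width": 1024,
#       "make_cropped_perfect_for_split": True,   # when True, the cropper also receives the
#                                                 # PatchSplitter's h, w, h_, w_ (see above)
#   }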