# test_multithreaded_augmenter.py -- forked from MIC-DKFZ/batchgenerators
# Copyright 2021 Division of Medical Image Computing, German Cancer Research Center (DKFZ)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from time import sleep
import numpy as np
from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.examples.multithreaded_dataloading import DummyDL, DummyDLWithShuffle
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.transforms.spatial_transforms import MirrorTransform, TransposeAxesTransform
from batchgenerators.transforms.utility_transforms import NumpyToTensor
from skimage.data import camera, checkerboard, astronaut, binary_blobs, coins
from skimage.transform import resize
from copy import deepcopy
class DummyDL2DImage(SlimDataLoaderBase):
    """Data loader built from five stock skimage sample images.

    Each image is converted to float, resized to 224x224 with order-1
    interpolation, cast back to float32 and given a leading channel axis,
    so the stored data has shape (5, 1, 224, 224).
    """

    def __init__(self, batch_size, num_threads=8):
        """
        :param batch_size: number of images per batch returned by generate_train_batch
        :param num_threads: forwarded to SlimDataLoaderBase (number_of_threads_in_multithreaded)
        """
        target_shape = (224, 224)
        # astronaut() is RGB, so average over the last axis to get a
        # single-channel image like the other four loaders.
        image_loaders = (
            camera,
            checkerboard,
            lambda: astronaut().mean(-1),
            binary_blobs,
            coins,
        )
        data = []
        for load in image_loaders:
            img = resize(load().astype(np.float64), target_shape, 1,
                         anti_aliasing=False, clip=True,
                         mode='reflect').astype(np.float32)
            data.append(img[None])  # add channel axis -> (1, 224, 224)
        data = np.stack(data)  # (5, 1, 224, 224)
        super(DummyDL2DImage, self).__init__(data, batch_size, num_threads)

    def generate_train_batch(self):
        """Return {'data': array of shape (batch_size, 1, 224, 224)}.

        Images are sampled with replacement, so a batch may repeat images.
        """
        idx = np.random.choice(len(self._data), self.batch_size)
        # i:i+1 slicing keeps the leading axis so vstack yields a batch
        batch = np.vstack([self._data[i:i + 1] for i in idx])
        return {'data': batch}
class TestMultiThreadedAugmenter(unittest.TestCase):
    """
    This test is inspired by the multithreaded example I did a while back
    """

    def setUp(self):
        np.random.seed(1234)
        self.num_threads = 4
        self.dl = DummyDL(self.num_threads)
        self.dl_with_shuffle = DummyDLWithShuffle(self.num_threads)
        self.dl_images = DummyDL2DImage(4, self.num_threads)

    def _assert_ascending_0_to_99(self, res):
        """Assert that res is exactly the numbers 0..99 in ascending order.

        NOTE: sorts res in place (same as the original inline checks did).
        """
        assert len(res) == 100
        res_copy = deepcopy(res)
        res.sort()
        # if sorting did not change the order, res was already ascending
        assert all(i == j for i, j in zip(res, res_copy))
        # and the sorted values are exactly 0..99
        assert all(i == j for i, j in zip(res, np.arange(0, 100)))

    def test_no_crash(self):
        """
        This one should just not crash, that's all
        :return:
        """
        dl = self.dl_images
        mt_dl = MultiThreadedAugmenter(dl, None, self.num_threads, 1, None, False)
        for _ in range(20):
            _ = mt_dl.next()

    def test_DummyDL(self):
        """
        DummyDL must return numbers from 0 to 99 in ascending order
        :return:
        """
        dl = DummyDL(1)
        res = [i for i in dl]
        self._assert_ascending_0_to_99(res)

    def test_order(self):
        """
        Coordinating workers in a multiprocessing envrionment is difficult. We want DummyDL in a multithreaded
        environment to still give us the numbers from 0 to 99 in ascending order
        :return:
        """
        mt = MultiThreadedAugmenter(self.dl, None, self.num_threads, 1, None, False)
        res = [i for i in mt]
        self._assert_ascending_0_to_99(res)

    def test_restart_and_order(self):
        """
        Coordinating workers in a multiprocessing envrionment is difficult. We want DummyDL in a multithreaded
        environment to still give us the numbers from 0 to 99 in ascending order.
        We want the MultiThreadedAugmenter to restart and return the same result in each run
        :return:
        """
        mt = MultiThreadedAugmenter(self.dl, None, self.num_threads, 1, None, False)
        # iterate the same augmenter three times; every run must restart
        # cleanly and reproduce the identical ascending sequence
        for _ in range(3):
            res = [i for i in mt]
            self._assert_ascending_0_to_99(res)

    def test_image_pipeline_and_pin_memory(self):
        '''
        This just should not crash
        :return:
        '''
        try:
            import torch
        except ImportError:
            # don't test if torch is not installed
            return
        tr_transforms = [
            MirrorTransform(),
            TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5),
            NumpyToTensor(keys='data', cast_to='float'),
        ]
        composed = Compose(tr_transforms)
        mt = MultiThreadedAugmenter(self.dl_images, composed, 4, 1, None, True)
        for _ in range(50):
            res = mt.next()
            assert isinstance(res['data'], torch.Tensor)
            assert res['data'].is_pinned()
        # let mt finish caching, otherwise it's going to print an error (which is not a problem and will not prevent
        # the success of the test but it does not look pretty)
        sleep(2)

    def test_image_pipeline(self):
        '''
        This just should not crash
        :return:
        '''
        tr_transforms = [
            MirrorTransform(),
            TransposeAxesTransform(transpose_any_of_these=(0, 1), p_per_sample=0.5),
        ]
        composed = Compose(tr_transforms)
        mt = MultiThreadedAugmenter(self.dl_images, composed, 4, 1, None, False)
        for _ in range(50):
            _ = mt.next()
        # let mt finish caching, otherwise it's going to print an error (which is not a problem and will not prevent
        # the success of the test but it does not look pretty)
        sleep(2)
if __name__ == '__main__':
    # freeze_support() is a no-op in normal runs; per the stdlib docs it is
    # required when a script that starts multiprocessing workers (as the
    # MultiThreadedAugmenter does) is frozen into a Windows executable.
    from multiprocessing import freeze_support
    freeze_support()
    unittest.main()