Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianIsensee committed Aug 26, 2021
1 parent 35622ba commit 5209f92
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
2 changes: 1 addition & 1 deletion batchgenerators/augmentations/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ def resize_multichannel_image(multichannel_image, new_shape, order=3):
new_shp = [multichannel_image.shape[0]] + list(new_shape)
result = np.zeros(new_shp, dtype=multichannel_image.dtype)
for i in range(multichannel_image.shape[0]):
result[i] = resize(multichannel_image[i].astype(float), new_shape, order, "edge", clip=True, anti_aliasing=False)
result[i] = resize(multichannel_image[i].astype(float), new_shape, order, clip=True, anti_aliasing=False)
return result.astype(tpe)


Expand Down
19 changes: 15 additions & 4 deletions tests/test_spatial_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,30 @@ class AugmentResize(unittest.TestCase):

def setUp(self):
np.random.seed(123)
self.data_3D = np.random.random((2, 4, 5, 6))
self.data_3D = np.random.random((2, 12, 14, 31))
self.seg_3D = np.random.random(self.data_3D.shape)

def test_resize(self):
data_resized, seg_resized = augment_resize(self.data_3D, self.seg_3D, target_size=2)
data_resized, seg_resized = augment_resize(self.data_3D, self.seg_3D, target_size=15)

mean_resized = float(np.mean(data_resized))
mean_original = float(np.mean(self.data_3D))

self.assertAlmostEqual(mean_original, mean_resized, places=2)

self.assertTrue(all((data_resized.shape[i] == 15 and seg_resized.shape[i] == 15) for i in
range(1, len(data_resized.shape))))

def test_resize2(self):
data_resized, seg_resized = augment_resize(self.data_3D, self.seg_3D, target_size=(7, 5, 6))

mean_resized = float(np.mean(data_resized))
mean_original = float(np.mean(self.data_3D))

self.assertAlmostEqual(mean_original, mean_resized, places=2)

self.assertTrue(all((data_resized.shape[i] == 2 and seg_resized.shape[i] == 2) for i in
range(len(data_resized.shape))))
self.assertTrue(all([i == j for i, j in zip(data_resized.shape[1:], (7, 5, 6))]))
self.assertTrue(all([i == j for i, j in zip(seg_resized.shape[1:], (7, 5, 6))]))


class AugmentRot90(unittest.TestCase):
Expand Down

0 comments on commit 5209f92

Please sign in to comment.