Skip to content

Commit

Permalink
fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
FabianIsensee committed Aug 26, 2021
1 parent 4df0ca5 commit 7c5c557
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 13 deletions.
8 changes: 4 additions & 4 deletions batchgenerators/augmentations/spatial_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def augment_rot90(sample_data, sample_seg, num_rot=(1, 2, 3), axes=(0, 1, 2)):
return sample_data, sample_seg


def augment_resize(sample_data, sample_seg, target_size, order=3, order_seg=1, cval_seg=0):
def augment_resize(sample_data, sample_seg, target_size, order=3, order_seg=1):
"""
Reshapes data (and seg) to target_size
:param sample_data: np.ndarray or list/tuple of np.ndarrays, must be (c, x, y(, z))) (if list/tuple then each entry
Expand All @@ -69,14 +69,14 @@ def augment_resize(sample_data, sample_seg, target_size, order=3, order_seg=1, c
if sample_seg is not None:
target_seg = np.ones([sample_seg.shape[0]] + target_size_here)
for c in range(sample_seg.shape[0]):
target_seg[c] = resize_segmentation(sample_seg[c], target_size_here, order_seg, cval_seg)
target_seg[c] = resize_segmentation(sample_seg[c], target_size_here, order_seg)
else:
target_seg = None

return sample_data, target_seg


def augment_zoom(sample_data, sample_seg, zoom_factors, order=3, order_seg=1, cval_seg=0):
def augment_zoom(sample_data, sample_seg, zoom_factors, order=3, order_seg=1):
"""
zooms data (and seg) by factor zoom_factors
:param sample_data: np.ndarray or list/tuple of np.ndarrays, must be (c, x, y(, z))) (if list/tuple then each entry
Expand Down Expand Up @@ -105,7 +105,7 @@ def augment_zoom(sample_data, sample_seg, zoom_factors, order=3, order_seg=1, cv
if sample_seg is not None:
target_seg = np.ones([sample_seg.shape[0]] + target_shape_here)
for c in range(sample_seg.shape[0]):
target_seg[c] = resize_segmentation(sample_seg[c], target_shape_here, order_seg, cval_seg)
target_seg[c] = resize_segmentation(sample_seg[c], target_shape_here, order_seg)
else:
target_seg = None

Expand Down
12 changes: 4 additions & 8 deletions batchgenerators/transforms/spatial_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __call__(self, **data_dict):


class ZoomTransform(AbstractTransform):
def __init__(self, zoom_factors=1, order=3, order_seg=1, cval_seg=0, concatenate_list=False, data_key="data",
def __init__(self, zoom_factors=1, order=3, order_seg=1, concatenate_list=False, data_key="data",
label_key="seg"):
"""
Zooms 'data' (and 'seg') by zoom_factors
Expand All @@ -75,7 +75,6 @@ def __init__(self, zoom_factors=1, order=3, order_seg=1, cval_seg=0, concatenate
"""
self.concatenate_list = concatenate_list
self.cval_seg = cval_seg
self.order_seg = order_seg
self.data_key = data_key
self.label_key = label_key
Expand Down Expand Up @@ -104,8 +103,7 @@ def __call__(self, **data_dict):
sample_seg = None
if seg is not None:
sample_seg = seg[b]
res_data, res_seg = augment_zoom(data[b], sample_seg, self.zoom_factors, self.order, self.order_seg,
self.cval_seg)
res_data, res_seg = augment_zoom(data[b], sample_seg, self.zoom_factors, self.order, self.order_seg)
results.append((res_data, res_seg))

if concatenate:
Expand All @@ -122,7 +120,7 @@ def __call__(self, **data_dict):

class ResizeTransform(AbstractTransform):

def __init__(self, target_size, order=3, order_seg=1, cval_seg=0, concatenate_list=False, data_key="data",
def __init__(self, target_size, order=3, order_seg=1, concatenate_list=False, data_key="data",
label_key="seg"):
"""
Reshapes 'data' (and 'seg') to target_size
Expand All @@ -139,7 +137,6 @@ def __init__(self, target_size, order=3, order_seg=1, cval_seg=0, concatenate_li
"""
self.concatenate_list = concatenate_list
self.cval_seg = cval_seg
self.order_seg = order_seg
self.data_key = data_key
self.label_key = label_key
Expand Down Expand Up @@ -168,8 +165,7 @@ def __call__(self, **data_dict):
sample_seg = None
if seg is not None:
sample_seg = seg[b]
res_data, res_seg = augment_resize(data[b], sample_seg, self.target_size, self.order, self.order_seg,
self.cval_seg)
res_data, res_seg = augment_resize(data[b], sample_seg, self.target_size, self.order, self.order_seg)
results.append((res_data, res_seg))

if concatenate:
Expand Down
2 changes: 1 addition & 1 deletion batchgenerators/transforms/utility_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ def __call__(self, **data_dict):
return self.list_of_transforms[i](**data_dict)


class OneOfTransform_perSample(AbstractTransform):
class OneOfTransformPerSample(AbstractTransform):
def __init__(self, list_of_transforms: List, relevant_keys: Union[Tuple[str, ...], List[str]],
p: Tuple[float, ...] = None):
"""
Expand Down

0 comments on commit 7c5c557

Please sign in to comment.