Skip to content

Commit

Permalink
added better checks
Browse files Browse the repository at this point in the history
  • Loading branch information
sarthakpati committed Jun 13, 2023
1 parent d2b39b1 commit 5ff57b9
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions GANDLF/utils/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,10 +526,8 @@ def get_image_from_tensor(input_tensor: torch.Tensor) -> sitk.Image:
"""
arr = input_tensor.cpu().numpy()
return_image = sitk.GetImageFromArray(arr)
# this is specifically the case for 2D rgb images
if arr.shape[1] == 3:
return_image = sitk.GetImageFromArray(arr)
elif arr.shape[0] == 1:
# this is specifically the case for 3D images
if (arr.shape[0] == 1) and (arr.shape[1] > 3):
return_image = sitk.GetImageFromArray(arr[0])

return return_image

0 comments on commit 5ff57b9

Please sign in to comment.