Skip to content

Commit

Permalink
ensure 2d rgb cases are picked up correctly and all aux cases are han…
Browse files Browse the repository at this point in the history
…dled
  • Loading branch information
sarthakpati committed Jun 13, 2023
1 parent 590b137 commit 832a998
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion GANDLF/utils/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,4 +524,12 @@ def get_image_from_tensor(input_tensor: torch.Tensor) -> sitk.Image:
Returns:
sitk.Image: The converted sitk image.
"""
return sitk.GetImageFromArray(input_tensor.cpu().numpy())
arr = input_tensor.cpu().numpy()
# this is specifically the case for 2D rgb images
return_image = sitk.GetImageFromArray(arr)
if arr.shape[1] == 3:
return_image = sitk.GetImageFromArray(arr)
elif arr.shape[0] == 1:
return_image = sitk.GetImageFromArray(arr[0])

return return_image

0 comments on commit 832a998

Please sign in to comment.