Skip to content

Commit

Permalink
ENH: Improve logic for infering render type
Browse files Browse the repository at this point in the history
Assume data with two dimensions and shape (N, 2) or (N, 3) is a point set, all
other data is assumed to be an image.
  • Loading branch information
bnmajor committed Jan 26, 2023
1 parent 45dc683 commit a82c678
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 22 deletions.
8 changes: 0 additions & 8 deletions itkwidgets/_initialization_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,3 @@ def init_params_dict(itk_viewer):
'y_slice': itk_viewer.setYSlice,
'z_slice': itk_viewer.setZSlice,
}

def init_key_aliases():
return {
'data': 'image',
'image': 'image',
'label_image': 'labelImage',
'point_set': 'pointSets',
}
19 changes: 9 additions & 10 deletions itkwidgets/integrations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ def _get_viewer_point_set(point_set):


def _detect_render_type(data, input_type) -> RenderType:
if input_type == 'image' or input_type == 'label_image':
return RenderType.IMAGE
elif input_type == 'point_set':
return RenderType.POINT_SET
if isinstance(data, itkwasm.Image):
return RenderType.IMAGE
elif isinstance(data, itkwasm.PointSet):
Expand All @@ -164,12 +168,7 @@ def _detect_render_type(data, input_type) -> RenderType:
# We may need to do more introspection
return RenderType.IMAGE
elif isinstance(data, np.ndarray):
if input_type == 'point_set':
return RenderType.POINT_SET
else:
return RenderType.IMAGE
elif isinstance(data, zarr.Group):
if input_type == 'point_set':
if data.ndim == 2 and data.shape[1] < 4:
return RenderType.POINT_SET
else:
return RenderType.IMAGE
Expand All @@ -190,26 +189,26 @@ def _detect_render_type(data, input_type) -> RenderType:
elif isinstance(data, vtk.vtkPolyData):
return RenderType.POINT_SET
if isinstance(data, dask.array.core.Array):
if input_type == 'point_set':
if data.ndim ==2 and data.shape[1] < 4:
return RenderType.POINT_SET
else:
return RenderType.IMAGE
if HAVE_TORCH:
import torch
if isinstance(data, torch.Tensor):
if input_type == 'point_set':
if data.dim == 2 and data.shape[1] < 4:
return RenderType.POINT_SET
else:
return RenderType.IMAGE
if HAVE_XARRAY:
import xarray as xr
if isinstance(data, xr.DataArray):
if input_type == 'point_set':
if data.dims == 2 and data.shape[1] < 4:
return RenderType.POINT_SET
else:
return RenderType.IMAGE
if isinstance(data, xr.Dataset):
if input_type == 'point_set':
if data.dims == 2 and data.shape[1] < 4:
return RenderType.POINT_SET
else:
return RenderType.IMAGE
3 changes: 2 additions & 1 deletion itkwidgets/render_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
class RenderType(Enum):
"""Rendered data types"""
IMAGE = "image"
LABELIMAGE = "labelImage"
GEOMETRY = "geometry"
POINT_SET = "point_set"
POINT_SET = "pointSets"
6 changes: 3 additions & 3 deletions itkwidgets/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import uuid

from ._type_aliases import Gaussians, Style, Image, Point_Set
from ._initialization_params import init_params_dict, init_key_aliases
from ._initialization_params import init_params_dict
from ._method_types import deferred_methods
from .integrations import _detect_render_type, _get_viewer_image, _get_viewer_point_set
from .integrations.environment import ENVIRONMENT, Env
Expand Down Expand Up @@ -169,17 +169,17 @@ def init_data(self, input_data):
result= None
for (input_type, data) in input_data:
render_type = _detect_render_type(data, input_type)
key = init_key_aliases()[input_type]
if render_type is RenderType.IMAGE:
if input_type == 'label_image':
result = _get_viewer_image(data, label=True)
render_type = RenderType.LABELIMAGE
else:
result = _get_viewer_image(data, label=False)
elif render_type is RenderType.POINT_SET:
result = _get_viewer_point_set(data)
if result is None:
raise RuntimeError(f"Could not process the viewer {input_type}")
_init_data[key] = result
_init_data[render_type.value] = result
return _init_data

async def run_queued_requests(self):
Expand Down

0 comments on commit a82c678

Please sign in to comment.