forked from TorchIO-project/torchio
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimages.py
105 lines (97 loc) · 3.65 KB
/
images.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from pathlib import Path
from collections.abc import Sequence
from torch.utils.data import Dataset
from ..utils import get_stem
from ..io import read_image, write_image
class ImagesDataset(Dataset):
    """Map-style dataset that returns one dictionary of images per subject.

    Each item is a ``sample`` dictionary with one entry per image of the
    subject. The optional ``transform`` is applied to the whole sample.
    """

    def __init__(
            self,
            subjects_list,
            transform=None,
            verbose=False,
            ):
        """
        Each element of subjects_list is a dictionary:

        subject = {
            'one_image': dict(
                path=path_to_one_image,
                type=torchio.INTENSITY,
            ),
            'another_image': dict(
                path=path_to_another_image,
                type=torchio.INTENSITY,
            ),
            'a_label': dict(
                path=path_to_a_label,
                type=torchio.LABEL,
            ),
        }

        See examples/example_multimodal.py for -obviously- an example.

        Args:
            subjects_list: Non-empty sequence of subject dictionaries as
                described above. Every image path is checked here, so a
                missing file fails at construction time rather than when the
                item is first accessed.
            transform: Optional callable applied to each sample dictionary
                in ``__getitem__``.
            verbose: If True, print a message before and after loading
                each image.

        Raises:
            TypeError, ValueError, KeyError, FileNotFoundError: As raised by
                ``parse_subjects_list``.
        """
        self.parse_subjects_list(subjects_list)
        self.subjects_list = subjects_list
        self.transform = transform
        self.verbose = verbose

    def __len__(self):
        """Return the number of subjects."""
        return len(self.subjects_list)

    def __getitem__(self, index):
        """Load every image of subject ``index`` and apply the transform.

        Returns:
            Dictionary mapping each image name to a dictionary with keys
            ``'data'``, ``'path'``, ``'affine'``, ``'stem'`` and ``'type'``.
            If ``self.transform`` is set, its return value is returned
            instead.
        """
        subject_dict = self.subjects_list[index]
        sample = {}
        for image_name, image_dict in subject_dict.items():
            # Resolve the path exactly as parse_subjects_list validated it.
            # Otherwise a '~/...' path passes validation but is then loaded
            # and recorded unexpanded.
            image_path = self._resolve_path(image_dict['path'])
            tensor, affine = self.load_image(image_path)
            sample[image_name] = dict(
                data=tensor,
                path=str(image_path),
                affine=affine,
                stem=get_stem(image_path),
                type=image_dict['type'],
            )
        # Apply transform (this is usually the bottleneck)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def load_image(self, path):
        """Read one image with ``read_image``, optionally logging progress.

        Returns:
            The ``(tensor, affine)`` pair returned by ``read_image``.
        """
        if self.verbose:
            print(f'Loading {path}...')
        tensor, affine = read_image(path)
        if self.verbose:
            print(f'Loaded array with shape {tensor.shape}')
        return tensor, affine

    @staticmethod
    def _resolve_path(path):
        """Return ``path`` as a ``Path`` with a leading ``~`` expanded.

        Shared by validation and loading so both refer to the same file.
        """
        return Path(path).expanduser()

    @staticmethod
    def parse_subjects_list(subjects_list):
        """Validate the structure of ``subjects_list`` and that all files exist.

        Raises:
            TypeError: If ``subjects_list`` is not a sequence, if one of its
                elements is not a dict, or if an image entry is not a dict.
            ValueError: If ``subjects_list`` or one of its elements is empty.
            KeyError: If an image dict lacks the ``'path'`` or ``'type'`` key.
            FileNotFoundError: If an image path does not point to a file.
        """
        def parse_path(path):
            path = ImagesDataset._resolve_path(path)
            # is_file() is also False for directories, which are rejected too.
            if not path.is_file():
                raise FileNotFoundError(f'{path} not found')

        if not isinstance(subjects_list, Sequence):
            raise TypeError(
                f'Subject list must be a sequence, not {type(subjects_list)}')
        if not subjects_list:
            raise ValueError('Subjects list is empty')
        for element in subjects_list:
            if not isinstance(element, dict):
                raise TypeError(
                    f'All elements must be dictionaries, not {type(element)}')
            if not element:
                raise ValueError(f'Element seems empty: {element}')
            subject_dict = element
            for image_name, image_dict in subject_dict.items():
                if not isinstance(image_dict, dict):
                    raise TypeError(
                        f'Type {type(image_dict)} found for {image_name},'
                        ' instead of type dict'
                    )
                for key in ('path', 'type'):
                    if key not in image_dict:
                        raise KeyError(
                            f'"{key}" key not found'
                            f' in image dict {image_dict}')
                parse_path(image_dict['path'])

    @staticmethod
    def save_sample(sample, output_paths_dict):
        """Write selected images of ``sample`` to disk with ``write_image``.

        Args:
            sample: Dictionary in which each ``sample[key]`` has ``'data'``
                and ``'affine'`` entries (e.g. a sample from ``__getitem__``).
            output_paths_dict: Maps image names in ``sample`` to the output
                path each image is written to.
        """
        for key, output_path in output_paths_dict.items():
            tensor = sample[key]['data']
            affine = sample[key]['affine']
            write_image(tensor, affine, output_path)