inference.py
import torch
from nnunet_mednext import create_mednext_v1
import data_loader
import yaml
import argparse
import os
import pdb
import numpy as np
'''
--------------------------------- Attention !!! -----------------------------------------
This script only provides an example of how inference can be run.
Participants may need to modify the script and/or its parameters to get reasonable/good results (e.g., change in_size, out_size, etc.).
The sample cases used here are actually from the train or validation split.
For the challenge, the train/validation/test splits are mutually exclusive. The final ranking should be run on the test split.
'''
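# Example invocation (the config path below is hypothetical; point it to your own YAML):
#   python inference.py configs/inference.yaml --phase test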
def offset_spatial_crop(roi_center=None, roi_size=None):
    """
    Compute the start and end offsets for cropping a spatial region of the data,
    based on the specified `roi_center` and `roi_size`.

    Parameters:
        roi_center (tuple of int, optional): The center point of the region of interest (ROI).
        roi_size (tuple of int, optional): The size of the ROI in each spatial dimension.

    Returns:
        start & end: start and end offsets of the crop.
    """
    if roi_center is None or roi_size is None:
        raise ValueError("Both `roi_center` and `roi_size` must be specified.")

    roi_center = [int(round(c)) for c in roi_center]
    roi_size = [int(round(s)) for s in roi_size]

    start = []
    end = []
    for i, (center, size) in enumerate(zip(roi_center, roi_size)):
        half_size = size // 2  # int(round(size / 2))
        start_i = max(center - half_size, 0)  # ensure we don't go below 0
        end_i = max(start_i + size, start_i)
        # end_i = min(center + half_size + (size % 2), ori_size[i])
        start.append(start_i)
        end.append(end_i)

    return start, end
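# A minimal sketch of what offset_spatial_crop returns (hypothetical numbers, not part
# of the challenge data): for a 96^3 crop centered at voxel (60, 60, 40), the start is
# clipped to 0 along the last axis while the end still extends the full crop size there.
#   start, end = offset_spatial_crop(roi_center=(60, 60, 40), roi_size=(96, 96, 96))
#   start == [12, 12, 0], end == [108, 108, 96]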
def cropped2ori(crop_data, ori_size, isocenter, trans_in_size):
    '''
    crop_data: the cropped data
    ori_size: the original size of the data
    isocenter: the isocenter of the original data
    trans_in_size: the in_size parameter in the transformation of the loader
    '''
    assert (np.array(trans_in_size) == np.array(crop_data.shape)).all()

    start_coords, end_coords = offset_spatial_crop(roi_center=isocenter, roi_size=trans_in_size)

    # remove the padding
    crop_start, crop_end = [], []
    for i in range(len(ori_size)):
        if end_coords[i] > ori_size[i]:
            diff = end_coords[i] - ori_size[i]
            crop_start.append(diff // 2)
            crop_end.append(crop_data.shape[i] - diff + diff // 2)
        else:
            crop_start.append(0)
            crop_end.append(crop_data.shape[i])
    crop_data = crop_data[crop_start[0]: crop_end[0], crop_start[1]: crop_end[1], crop_start[2]: crop_end[2]]

    pad_out = np.zeros(ori_size)
    # NumPy clips slice end indices that run past the array bounds, so this assignment stays inside ori_size
    pad_out[start_coords[0]: end_coords[0], start_coords[1]: end_coords[1], start_coords[2]: end_coords[2]] = crop_data
    return pad_out
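# A minimal usage sketch (hypothetical sizes, not taken from the repo): map a 96^3
# network-space prediction back onto a 128 x 128 x 80 image grid, assuming the loader
# cropped around the same isocenter with in_size = (96, 96, 96).
#   crop_pred = np.random.rand(96, 96, 96).astype(np.float32)
#   full_pred = cropped2ori(crop_pred, ori_size=[128, 128, 80], isocenter=[64, 64, 40], trans_in_size=[96, 96, 96])
#   full_pred.shape == (128, 128, 80)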
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run inference with a trained model.')
    parser.add_argument('cfig_path', type=str)
    parser.add_argument('--phase', default='test', type=str)
    args = parser.parse_args()

    cfig = yaml.load(open(args.cfig_path), Loader=yaml.FullLoader)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ------------ data loader -----------------#
    loaders = data_loader.GetLoader(cfig=cfig['loader_params'])
    test_loader = loaders.test_dataloader()

    if cfig['model_from_lightning']:  # when the model was trained with train_lightning.py
        from train_lightning import GDPLightningModel
        pl_module = GDPLightningModel.load_from_checkpoint(cfig['save_model_path'], cfig=cfig, strict=True)
        model = pl_module.model.to(device)
    else:  # when the model was trained with train.py
        model = create_mednext_v1(num_input_channels=cfig['model_params']['num_input_channels'],
                                  num_classes=cfig['model_params']['out_channels'],
                                  model_id=cfig['model_params']['model_id'],  # S, B, M and L are valid model ids
                                  kernel_size=cfig['model_params']['kernel_size'],
                                  deep_supervision=cfig['model_params']['deep_supervision']
                                  ).to(device)

        # load the pretrained weights
        model.load_state_dict(torch.load(cfig['save_model_path'], map_location=device))

    with torch.no_grad():
        model.eval()
        for batch_idx, data_dict in enumerate(test_loader):
            # forward pass
            outputs = model(data_dict['data'].to(device))

            if cfig['act_sig']:
                outputs = torch.sigmoid(outputs.clone())
            outputs = outputs * cfig['scale_out']

            if 'label' in data_dict.keys():
                print('L1 error is ', torch.nn.L1Loss()(outputs, data_dict['label'].to(device)).item())

            if cfig['loader_params']['in_size'] != cfig['loader_params']['out_size']:
                outputs = torch.nn.functional.interpolate(outputs, size=cfig['loader_params']['in_size'], mode='area')

            for index in range(len(outputs)):
                crop_data = outputs[index][0].cpu().numpy()
                ori_size = data_dict['ori_img_size'][index].numpy().tolist()
                isocenter = data_dict['ori_isocenter'][index].numpy().tolist()
                trans_in_size = cfig['loader_params']['in_size']

                pred2orisize = cropped2ori(crop_data, ori_size, isocenter, trans_in_size) * cfig['loader_params']['dose_div_factor']
                np.save(os.path.join(cfig['save_pred_path'], data_dict['id'][index] + '_pred.npy'), pred2orisize)
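# To inspect a saved prediction afterwards (hypothetical case id, assuming the loop
# above has already written it to cfig['save_pred_path']):
#   pred = np.load(os.path.join(cfig['save_pred_path'], 'case_0001_pred.npy'))
#   print(pred.shape, pred.min(), pred.max())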