-
Notifications
You must be signed in to change notification settings - Fork 56
/
Copy pathdemo.py
102 lines (88 loc) · 2.88 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt
from seg_loss import MultiDiceLoss
def IoU(pred, target):
    """Return the intersection-over-union of two binary masks.

    Both inputs are flattened before comparison, so they only need the same
    number of elements. The intersection is computed by element-wise
    multiplication, so the result is a true IoU only when the masks hold 0/1
    (or boolean) values. For multi-class label maps, compare one class at a
    time, e.g. ``IoU(pred == c, target == c)``.

    Args:
        pred: predicted mask as a ``numpy.ndarray``.
        target: ground-truth mask as a ``numpy.ndarray`` with as many
            elements as ``pred``.

    Returns:
        The IoU as a NumPy floating scalar. A small ``eps`` in the denominator
        makes two empty masks yield 0 instead of dividing by zero.

    Raises:
        AssertionError: if either argument is not a ``numpy.ndarray``.
    """
    assert isinstance(pred, np.ndarray), 'prediction should be numpy.ndarray'
    assert isinstance(target, np.ndarray), 'target should be numpy.ndarray'
    eps = 1e-6
    # ravel returns a view when possible; flatten always copies.
    pred = pred.ravel()
    target = target.ravel()
    inter = np.sum(pred * target).astype(np.float32)
    union = np.sum(pred).astype(np.float32) + np.sum(target).astype(np.float32) - inter
    return inter / (union + eps)
# Input volumes for one case (three MR modalities) and its label mask.
image_paths = ['./Data/338/t2w.mha', './Data/338/adc.mha', './Data/338/dwi.mha']
seg_path = './Data/338/mask.mha'
# In-plane downscaling factor applied to every slice.
scale = 0.75


def _resize_slices(volume, interpolation):
    """Resize every slice of a (z, y, x) volume in-plane by ``scale``.

    ``cv2.resize`` operates on (rows, cols, channels) images. The slice axis
    is therefore moved last and treated as channels. The result is returned
    in the original (z, y, x) axis order.
    """
    planar = np.transpose(volume, (1, 2, 0))
    hei, wid, _ = planar.shape
    # The interpolation flag must be passed by keyword: the third positional
    # parameter of cv2.resize is ``dst``, so a positional flag is ignored and
    # INTER_LINEAR is used instead.
    resized = cv2.resize(planar, (int(scale * wid), int(scale * hei)),
                         interpolation=interpolation)
    if resized.ndim == 2:
        # cv2 drops a size-1 channel axis; restore it for one-slice volumes.
        resized = resized[:, :, np.newaxis]
    return np.transpose(resized, (2, 0, 1))


images = []
for image_path in image_paths:
    image = sitk.GetArrayFromImage(sitk.ReadImage(image_path))
    # Cast before resizing so bicubic overshoot is not saturated to an
    # integer pixel range.
    image = _resize_slices(np.asarray(image, dtype=np.float32), cv2.INTER_CUBIC)
    # Per-volume z-score normalisation.
    image = (image - np.mean(image)) / np.std(image)
    images.append(image)
images = np.array(images)

seg = sitk.GetArrayFromImage(sitk.ReadImage(seg_path))
# Nearest neighbour keeps label values intact (no blended classes at borders).
seg = _resize_slices(seg, cv2.INTER_NEAREST)
print(images.shape)
print(seg.shape)
seg = np.asarray(seg, np.int64)

# Add a batch axis: data is (1, modalities, z, y, x), target is (1, z, y, x).
# Fall back to the CPU when no GPU is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = torch.from_numpy(images).unsqueeze(0).to(device)
target = torch.from_numpy(seg).unsqueeze(0).to(device)
class Net(nn.Module):
    """Small fully convolutional 3D segmentation network.

    The network has three 3x3x3 convolutions with ``padding=1`` (stride 1, so
    the spatial size is preserved): ``in_channels -> 32 -> 64 -> out_channels``.
    The first two stages apply BatchNorm3d + ReLU. The last stage applies a
    softmax over the channel axis, so ``forward`` returns per-voxel class
    probabilities.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.Conv1 = self._conv_bn_relu(in_channels, 32)
        self.Conv2 = self._conv_bn_relu(32, 64)
        self.Conv3 = nn.Sequential(
            nn.Conv3d(64, out_channels, 3, padding=1),
            nn.Softmax(dim=1),
        )

    @staticmethod
    def _conv_bn_relu(c_in, c_out):
        """Build one Conv3d(3x3x3, same padding) -> BatchNorm3d -> ReLU stage."""
        return nn.Sequential(
            nn.Conv3d(c_in, c_out, 3, padding=1),
            nn.BatchNorm3d(c_out),
            nn.ReLU(),
        )

    def forward(self, input):
        """Map a 5D (N, in_channels, D, H, W) tensor to (N, out_channels, D, H, W)."""
        result = input
        for stage in (self.Conv1, self.Conv2, self.Conv3):
            result = stage(result)
        return result
iters = 100
num_classes = 3
# Put the model on whatever device the input tensors were moved to.
model = Net(len(image_paths), num_classes).to(data.device)
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.99))
dice_loss = MultiDiceLoss([0.5] * num_classes, num_class=num_classes, k=1, ohem=True)
ious = []
losses = []
# The labels never change, so copy them to host memory once rather than
# on every iteration.
target_np = target.cpu().numpy()
for _ in range(iters):
    optimizer.zero_grad()
    out1 = model(data)
    loss, _ = dice_loss(out1, target)
    loss.backward()
    optimizer.step()
    # Hard prediction: index of the highest-scoring channel per voxel.
    _, pred1 = out1.max(1)
    pred_np = pred1.cpu().numpy()
    # IoU multiplies its inputs element-wise, so it is only meaningful on
    # binary masks. Score each class separately and average.
    # NOTE(review): class 0 is assumed to be background and excluded; confirm
    # against the mask labels.
    class_ious = [IoU(pred_np == c, target_np == c) for c in range(1, num_classes)]
    ious.append(float(np.mean(class_ious)))
    losses.append(loss.item())
plt.plot(ious, label='mean foreground IoU')
plt.plot(losses, label='dice loss')
plt.xlabel('iteration')
plt.legend()
plt.show()
print(np.min(losses))