-
Notifications
You must be signed in to change notification settings - Fork 56
/
Copy pathFocalLoss_test.py
73 lines (59 loc) · 2.11 KB
/
FocalLoss_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import torch.nn as nn
from torch import optim
from seg_loss.focal_loss import FocalLoss_Ori, BinaryFocalLoss
def test_BFL():
    """Train a tiny 3-D conv net against ``BinaryFocalLoss`` and plot the curve.

    Inputs are random values in {-20, +20} and targets are random floats in
    {0., 1.}, both shaped (4, 1, 32, 32, 32). The network maps one channel to
    one channel at the same spatial size (kernel 3, padding 1, stride 1).
    It runs 100 Adam steps on the CPU, echoes every tenth step index, and
    then shows the recorded per-step losses. ``BinaryFocalLoss`` is used with
    its default arguments.
    """
    import matplotlib.pyplot as plt

    torch.manual_seed(123)
    vol_shape = (4, 1, 32, 32, 32)

    # Random bits -> {-0.5, 0.5} (int64 minus a float promotes to the default
    # float dtype) -> scaled to {-20, 20}.
    inputs = torch.randint(0, 2, vol_shape).sub(0.5).mul(40)
    # Binary 0/1 labels, cast to the inputs' float dtype.
    labels = torch.randint(0, 2, size=vol_shape).to(inputs.dtype)

    layers = [
        nn.Conv3d(1, 16, kernel_size=3, padding=1, stride=1),
        nn.BatchNorm3d(16),
        nn.ReLU(),
        nn.Conv3d(16, 1, kernel_size=3, padding=1, stride=1),
    ]
    net = nn.Sequential(*layers)
    loss_fn = BinaryFocalLoss()
    opt = optim.Adam(net.parameters(), lr=0.001)

    history = []
    for it in range(100):
        loss = loss_fn(net(inputs), labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
        history.append(loss.item())
        if not it % 10:
            print(it)

    plt.plot(history)
    plt.show()
def test_focal():
    """Smoke-test ``FocalLoss_Ori`` by training a tiny 3-D conv net on random labels.

    Setup:
    - Random inputs in [-20, 20) of shape (4, 1, 32, 32, 32).
    - Random integer labels in [0, num_class), of the same shape. One row of
      voxels is set to ``ignore_index`` (255) so the ignore path is exercised.
    - A two-layer Conv3d network that emits ``num_class`` channels, trained
      for 100 Adam steps with ``FocalLoss_Ori(gamma=2.0, reduction='mean')``.

    Every step's loss must be finite. The loss curve is plotted at the end.

    Runs on CUDA when available and on CPU otherwise. Data is always sampled
    on the CPU generator and then moved, so the seeded values are the same
    on either device.
    """
    import matplotlib.pyplot as plt

    torch.manual_seed(123)
    # Fall back to CPU so the check does not crash on machines without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_class = 5
    ignore_index = 255
    steps = 100
    shape = (4, 1, 32, 32, 32)
    # Labels keep the singleton channel axis. Presumably FocalLoss_Ori
    # accepts (N, 1, D, H, W) targets -- TODO confirm against seg_loss.
    target_shape = (4, 1, 32, 32, 32)

    datas = (40 * (torch.rand(shape) - 0.5)).to(device)
    # torch.randint already returns int64 (long), so no extra cast is needed.
    target = torch.randint(0, num_class, size=target_shape)
    target[0, 0, 0, 0, :] = ignore_index  # exercise the ignore_index path
    target = target.to(device)

    FL = FocalLoss_Ori(num_class=num_class, gamma=2.0,
                       ignore_index=ignore_index, reduction='mean')
    model = nn.Sequential(
        nn.Conv3d(1, 16, kernel_size=3, padding=1, stride=1),
        nn.BatchNorm3d(16),
        nn.ReLU(),
        nn.Conv3d(16, num_class, kernel_size=3, padding=1, stride=1),
    ).to(device)
    optimizer = optim.Adam(params=model.parameters(), lr=0.001)

    losses = []
    for i in range(steps):
        output = model(datas)
        loss = FL(output, target)
        # Fail fast on NaN/inf (e.g. if ignored voxels leak into the mean)
        # rather than silently plotting a broken curve.
        assert torch.isfinite(loss), f'non-finite focal loss at step {i}: {loss.item()}'
        losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 10 == 0:
            print(i, losses[-1])

    plt.plot(losses)
    plt.show()
if __name__ == "__main__":
    # When executed as a script only the multi-class focal-loss check runs;
    # test_BFL has to be invoked separately.
    test_focal()