framework_mark.py
import torch
import torch.nn as nn
from torch.autograd import Variable as V  # legacy (pre-0.4) autograd API


class MyFrame():
    def __init__(self, net, loss, lr=2e-4, evalmode=False):
        self.net = net().cuda()
        self.net = torch.nn.DataParallel(self.net, device_ids=range(torch.cuda.device_count()))
        self.optimizer = torch.optim.Adam(params=self.net.parameters(), lr=lr)
        self.loss = loss()
        self.old_lr = lr
        if evalmode:
            # Freeze BatchNorm statistics when the frame is used for evaluation.
            for i in self.net.modules():
                if isinstance(i, nn.BatchNorm2d):
                    i.eval()

    def set_input(self, img_batch, mask_batch=None, mask_mark_batch=None, img_id=None):
        self.img = img_batch
        self.mask = mask_batch
        self.mask_mark = mask_mark_batch
        self.img_id = img_id

    def forward(self, volatile=False):
        # Move the current batch to the GPU. `Variable(..., volatile=...)`
        # is the legacy autograd API; on PyTorch >= 0.4 the flag is a no-op
        # and inference should be wrapped in `torch.no_grad()` instead.
        self.img = V(self.img.cuda(), volatile=volatile)
        if self.mask is not None:
            self.mask = V(self.mask.cuda(), volatile=volatile)
            self.mask_mark = V(self.mask_mark.cuda(), volatile=volatile)

    def optimize(self):
        # One training step: forward pass, combined BCE/Dice loss weighted
        # by the mark mask, backprop, and a parameter update.
        self.forward()
        self.optimizer.zero_grad()
        pred = self.net(self.img)
        loss_bce, loss_dice, loss = self.loss(self.mask, pred, self.mask_mark)
        loss.backward()
        self.optimizer.step()
        return loss_bce.data, loss_dice.data, loss.data

    def save(self, path):
        torch.save(self.net.state_dict(), path)

    def load(self, path):
        self.net.load_state_dict(torch.load(path))

    def update_lr(self, new_lr, mylog, factor=False):
        # With factor=True, `new_lr` is a divisor applied to the current
        # learning rate rather than an absolute value.
        if factor:
            new_lr = self.old_lr / new_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr
        print('update learning rate: %f -> %f' % (self.old_lr, new_lr), file=mylog)
        print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
        self.old_lr = new_lr
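
# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not part of the original file).
# MyFrame expects a network *class* and a loss *class*, each constructed
# with no arguments, where calling the loss as loss(mask, pred, mask_mark)
# returns (loss_bce, loss_dice, loss). DummyNet and DummyLoss below are
# hypothetical stand-ins so the sketch is self-contained; substitute the
# project's real classes. Requires at least one CUDA device, since MyFrame
# moves the model and batches to the GPU.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn.functional as F

    class DummyNet(nn.Module):
        def __init__(self):
            super(DummyNet, self).__init__()
            self.conv = nn.Conv2d(3, 1, kernel_size=3, padding=1)

        def forward(self, x):
            return torch.sigmoid(self.conv(x))

    class DummyLoss(nn.Module):
        def forward(self, mask, pred, mask_mark):
            # Hypothetical combined objective: mark-weighted BCE plus a
            # Dice-style overlap term.
            bce = F.binary_cross_entropy(pred, mask, weight=mask_mark)
            inter = (pred * mask).sum()
            dice = 1 - (2 * inter + 1) / (pred.sum() + mask.sum() + 1)
            return bce, dice, bce + dice

    solver = MyFrame(DummyNet, DummyLoss, lr=2e-4)
    img = torch.rand(2, 3, 64, 64)
    mask = (torch.rand(2, 1, 64, 64) > 0.5).float()
    mark = torch.ones_like(mask)
    solver.set_input(img, mask, mark)
    print(solver.optimize())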