forked from moskomule/senet.pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
64 lines (55 loc) · 2.24 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
from pathlib import Path
import torch
from tqdm import tqdm
class Trainer(object):
    """Small train / evaluate / checkpoint driver around a torch model.

    The trainer moves the model (and every batch) to CUDA when
    ``torch.cuda.is_available()`` reported True at class-definition time,
    runs one pass over a data loader per call to :meth:`train` or
    :meth:`test`, and periodically writes checkpoints to ``save_dir``.

    Args:
        model: module called as ``model(data)`` on each batch.
        optimizer: object providing ``zero_grad()`` and ``step()``.
        loss_f: callable invoked as ``loss_f(output, target)``; its result
            must support ``.item()`` and ``.backward()``.
        save_dir: directory for checkpoints, or ``None`` to disable saving.
        save_freq: save a checkpoint every ``save_freq`` epochs. A falsy
            value (``0`` / ``None``) disables periodic saving.
    """

    # Evaluated once, when the class body is executed.
    cuda = torch.cuda.is_available()
    # Global cuDNN switch, set as a side effect of defining the class.
    torch.backends.cudnn.benchmark = True

    def __init__(self, model, optimizer, loss_f, save_dir=None, save_freq=5):
        self.model = model
        if self.cuda:
            model.cuda()
        self.optimizer = optimizer
        self.loss_f = loss_f
        self.save_dir = save_dir
        self.save_freq = save_freq

    def _iteration(self, data_loader, is_train=True):
        """Run one pass over ``data_loader`` and print a summary line.

        Args:
            data_loader: sized iterable of ``(data, target)`` batches that
                also exposes a sized ``.dataset`` attribute.
            is_train: when True, back-propagate and step the optimizer
                after every batch.

        Returns:
            ``(loop_loss, accuracy)``: two lists with one entry per batch.
            ``loop_loss`` holds each batch loss divided by the number of
            batches (so its sum is the mean batch loss); ``accuracy`` holds
            the count of correct predictions in each batch.
        """
        # Hoisted: the number of batches does not change during the pass.
        num_batches = len(data_loader)
        loop_loss = []
        accuracy = []
        for data, target in tqdm(data_loader, ncols=80):
            if self.cuda:
                data, target = data.cuda(), target.cuda()
            output = self.model(data)
            loss = self.loss_f(output, target)
            loop_loss.append(loss.item() / num_batches)
            # Index of the max along dim 1 is taken as the prediction;
            # assumes output is (batch, classes) -- TODO confirm with callers.
            predicted = output.max(1)[1]
            accuracy.append((predicted == target).sum().item())
            if is_train:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
        mode = "train" if is_train else "test"
        num_samples = len(data_loader.dataset)
        # Guard against an empty dataset instead of raising ZeroDivisionError.
        accuracy_ratio = sum(accuracy) / num_samples if num_samples else 0.0
        print(f">>>[{mode}] loss: {sum(loop_loss):.2f}/accuracy: {accuracy_ratio:.2%}")
        return loop_loss, accuracy

    def train(self, data_loader):
        """Put the model in training mode and run one optimizing pass.

        Returns:
            The ``(loop_loss, accuracy)`` lists produced by ``_iteration``.
        """
        self.model.train()
        with torch.enable_grad():
            loss, correct = self._iteration(data_loader)
        return loss, correct

    def test(self, data_loader):
        """Put the model in eval mode and run one pass without gradients.

        Returns:
            The ``(loop_loss, accuracy)`` lists produced by ``_iteration``.
        """
        self.model.eval()
        with torch.no_grad():
            loss, correct = self._iteration(data_loader, is_train=False)
        return loss, correct

    def loop(self, epochs, train_data, test_data, scheduler=None):
        """Train and evaluate for ``epochs`` epochs, checkpointing periodically.

        Args:
            epochs: number of epochs; epochs are numbered from 1.
            train_data: loader passed to :meth:`train`.
            test_data: loader passed to :meth:`test`.
            scheduler: optional object with a no-argument ``step()`` method,
                called once per epoch.
        """
        for ep in range(1, epochs + 1):
            print("epochs: {}".format(ep))
            self.train(train_data)
            # Step the scheduler only after the epoch's optimizer updates;
            # stepping first skips the first value of the LR schedule on
            # PyTorch >= 1.1.
            if scheduler is not None:
                scheduler.step()
            self.test(test_data)
            # Save on every save_freq-th epoch (the previous truthiness test
            # saved on every epoch that was NOT a multiple of save_freq).
            if self.save_freq and ep % self.save_freq == 0:
                self.save(ep)

    def save(self, epoch, **kwargs):
        """Write ``{"epoch", "weight"}`` to ``save_dir/model_epoch_<epoch>.pth``.

        Does nothing when ``save_dir`` is ``None``. Extra keyword arguments
        are accepted and ignored.
        """
        if self.save_dir is None:
            return
        model_out_path = Path(self.save_dir)
        state = {"epoch": epoch, "weight": self.model.state_dict()}
        # Create missing parents too, and tolerate the directory already
        # existing (avoids the check-then-create race).
        model_out_path.mkdir(parents=True, exist_ok=True)
        torch.save(state, model_out_path / "model_epoch_{}.pth".format(epoch))