utils.py (forked from colemiller94/gatedgan) · 109 lines (87 loc) · 4.02 KB
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
from https://github.com/aitorzip/PyTorch-CycleGAN/blob/master/utils.py
"""
import time
import sys
import datetime
from visdom import Visdom
import numpy as np
import torch.nn.init

def label2tensor(label, tensor):
    # Copy per-sample labels into a pre-allocated tensor, one slice per sample.
    for i in range(label.size(0)):
        tensor[i].fill_(label[i])
    return tensor

def tensor2image(tensor):
    # Map the first image of the batch from [-1, 1] back to [0, 255] for display.
    image = 127.5 * (tensor[0].cpu().float().numpy() + 1.0)
    if image.shape[0] == 1:
        # Replicate single-channel images to 3 channels so they render as RGB.
        image = np.tile(image, (3, 1, 1))
    return image.astype(np.uint8)
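
# Usage sketch (not part of the original file): tensor2image expects a batch of
# images scaled to [-1, 1] and returns the first one as a (3, H, W) uint8 array,
# the layout Visdom's image() expects. The random batch below is a hypothetical
# stand-in for a generator output.
def _example_tensor2image():
    fake_batch = torch.rand(4, 3, 256, 256) * 2 - 1  # assumed shape (N, C, H, W)
    return tensor2image(fake_batch)                   # -> uint8 array of shape (3, 256, 256)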

def weights_init_normal(m):
    # Draw Conv weights from N(0, 0.02); BatchNorm weights from N(1, 0.02) with zero bias.
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm2d') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)
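
# Usage sketch (not part of the original file): weights_init_normal is intended
# to be passed to nn.Module.apply(), which invokes it on every submodule. The
# tiny model below is a hypothetical stand-in for the GAN's generator or
# discriminator.
def _example_weight_init():
    import torch.nn as nn
    net = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64))
    net.apply(weights_init_normal)  # re-initializes Conv and BatchNorm2d layers in place
    return net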

class Logger():
    def __init__(self, n_epochs, batches_epoch):
        # Visdom connection used for live loss curves and image previews.
        # self.viz = Visdom(server='', port=)
        self.viz = Visdom()
        self.n_epochs = n_epochs
        self.batches_epoch = batches_epoch
        self.epoch = 1
        self.batch = 1
        self.prev_time = time.time()
        self.mean_period = 0
        self.losses = {}
        self.loss_windows = {}
        self.image_windows = {}

    def log(self, losses=None, images=None, category=None):
        # `category` is accepted but not used here.
        self.mean_period += (time.time() - self.prev_time)
        self.prev_time = time.time()

        sys.stdout.write('\rEpoch %03d/%03d [%04d/%04d] -- ' % (self.epoch, self.n_epochs, self.batch, self.batches_epoch))

        # Accumulate each loss and print its running mean over the current epoch.
        for i, loss_name in enumerate(losses.keys()):
            if loss_name not in self.losses:
                self.losses[loss_name] = losses[loss_name].item()
            else:
                self.losses[loss_name] += losses[loss_name].item()

            if (i + 1) == len(losses.keys()):
                sys.stdout.write('%s: %.4f -- ' % (loss_name, self.losses[loss_name] / self.batch))
            else:
                sys.stdout.write('%s: %.4f | ' % (loss_name, self.losses[loss_name] / self.batch))

        # Estimate remaining time from the mean time per batch so far.
        batches_done = self.batches_epoch * (self.epoch - 1) + self.batch
        batches_left = self.batches_epoch * (self.n_epochs - self.epoch) + self.batches_epoch - self.batch
        sys.stdout.write('ETA: %s' % (datetime.timedelta(seconds=batches_left * self.mean_period / batches_done)))

        # Draw images
        for image_name, tensor in images.items():
            if image_name not in self.image_windows:
                self.image_windows[image_name] = self.viz.image(tensor2image(tensor.data), opts={'title': image_name})
            else:
                self.viz.image(tensor2image(tensor.data), win=self.image_windows[image_name], opts={'title': image_name})

        # End of epoch
        if (self.batch % self.batches_epoch) == 0:
            # Plot losses
            for loss_name, loss in self.losses.items():
                if loss_name not in self.loss_windows:
                    self.loss_windows[loss_name] = self.viz.line(X=np.array([self.epoch]), Y=np.array([loss / self.batch]),
                                                                 opts={'xlabel': 'epochs', 'ylabel': loss_name, 'title': loss_name})
                else:
                    self.viz.line(X=np.array([self.epoch]), Y=np.array([loss / self.batch]), win=self.loss_windows[loss_name], update='append')
                # Reset losses for next epoch
                self.losses[loss_name] = 0.0

            self.epoch += 1
            self.batch = 1
            sys.stdout.write('\n')
        else:
            self.batch += 1
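
# Usage sketch (not part of the original file): inside a training loop, Logger.log
# is typically called once per batch with a dict of scalar loss tensors and a dict
# of image tensors in [-1, 1]; it needs a running visdom server to draw anything.
# All names and sizes below are assumed for illustration.
def _example_logger(generator_loss, discriminator_loss, real_batch, fake_batch):
    logger = Logger(n_epochs=200, batches_epoch=1000)
    logger.log(losses={'loss_G': generator_loss, 'loss_D': discriminator_loss},
               images={'real': real_batch, 'fake': fake_batch})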

class LambdaLR():
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert ((n_epochs - decay_start_epoch) > 0), "Decay must start before the training session ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        # Multiplicative LR factor: 1.0 until decay_start_epoch, then linear decay to 0 at n_epochs.
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)
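
# Usage sketch (not part of the original file): LambdaLR.step returns a
# multiplicative learning-rate factor, so it can be plugged into PyTorch's
# torch.optim.lr_scheduler.LambdaLR as the lr_lambda callable. The optimizer,
# learning rate, and epoch counts below are assumed for illustration.
def _example_lr_schedule():
    params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for model parameters
    optimizer = torch.optim.Adam(params, lr=0.0002)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=LambdaLR(n_epochs=200, offset=0, decay_start_epoch=100).step)
    for epoch in range(200):
        # ... one epoch of training ...
        scheduler.step()  # LR stays at 2e-4 until epoch 100, then decays linearly toward 0
    return optimizer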