-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
72 lines (55 loc) · 2.21 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import numpy as np
import torch
from PIL import Image
from old.model import device
def is_image_file(filename):
    """Return True when *filename* ends in a supported image extension.

    Supported extensions are ``.png``, ``.jpg`` and ``.jpeg``. The comparison
    is case-sensitive, so e.g. ``photo.JPG`` is not matched.
    """
    supported_suffixes = (".png", ".jpg", ".jpeg")
    # str.endswith accepts a tuple and returns True if any suffix matches.
    return filename.endswith(supported_suffixes)
def load_img(filepath, size=(256, 256)):
    """Open an image file, convert it to RGB and resize it.

    Args:
        filepath: Path of the image file to read.
        size: Target ``(width, height)`` passed to ``Image.resize``. Defaults
            to ``(256, 256)``, the previously hard-coded size.

    Returns:
        A new ``PIL.Image.Image`` in ``'RGB'`` mode, resized with bicubic
        interpolation.
    """
    # Use a context manager so the underlying file handle is closed as soon
    # as the pixels are read. convert() returns a new, fully loaded image, so
    # it stays valid after the source image is closed.
    with Image.open(filepath) as src:
        rgb = src.convert('RGB')
    return rgb.resize(size, Image.BICUBIC)
def save_img(image_tensor, filename):
    """Write a 3-D channel-first image tensor to *filename* and print a notice.

    Pixel values are mapped with ``(x + 1) / 2 * 255``, so inputs are expected
    in ``[-1, 1]``. Anything outside that range is clipped to ``[0, 255]``
    before the cast to ``uint8``.

    Args:
        image_tensor: Tensor shaped ``(C, H, W)``. It may live on any device
            and may require grad; it is detached and copied to the CPU here.
            A single-channel tensor is saved as a grayscale image.
        filename: Destination path. The image format is inferred by PIL from
            the extension.
    """
    # detach()/cpu() are no-ops for a plain CPU tensor. They let callers pass
    # GPU or autograd tensors, which would otherwise make .numpy() raise.
    image_numpy = image_tensor.detach().cpu().float().numpy()
    # CHW -> HWC, then rescale [-1, 1] to [0, 255].
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    image_numpy = image_numpy.clip(0, 255)
    image_numpy = image_numpy.astype(np.uint8)
    # PIL cannot build an image from an (H, W, 1) array, so drop a singleton
    # channel axis to get a 2-D grayscale array.
    if image_numpy.shape[2] == 1:
        image_numpy = image_numpy[:, :, 0]
    image_pil = Image.fromarray(image_numpy)
    image_pil.save(filename)
    print("Image saved as {}".format(filename))
class AverageMeter:
    """Running statistics for a scalar metric.

    Attributes:
        val: The most recent value passed to ``update``.
        sum: Weighted total of all values (each value times its ``n``).
        count: Total weight (sum of every ``n`` passed to ``update``).
        avg: ``sum / count`` as of the latest ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set val, avg, sum and count back to 0."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record *val* as the latest reading, counted as *n* observations.

        Args:
            val: The new metric value.
            n: Weight of this value (e.g. the batch size it was averaged
                over). Defaults to 1.
        """
        self.val = val
        # Accumulate the weighted total and the weight, then refresh the mean.
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def save_checkpoint(epoch, epochs_since_improvement, model, optimizer, hmean, is_best,
                    filename='checkpoint.tar', best_filename='BEST_checkpoint.tar'):
    """Serialize the current training state with ``torch.save``.

    The saved dict has the keys ``'epoch'``, ``'epochs_since_improvement'``,
    ``'hmean'``, ``'model'`` and ``'optimizer'``. The model and optimizer
    objects themselves are pickled, not their ``state_dict()``s.
    NOTE(review): loading them back needs the defining classes to be
    importable. On PyTorch versions where ``torch.load`` defaults to
    ``weights_only=True``, it also needs ``weights_only=False`` — confirm
    against the loader.

    Args:
        epoch: Current epoch number.
        epochs_since_improvement: Epochs since ``hmean`` last improved.
        model: Model object to store.
        optimizer: Optimizer object to store.
        hmean: Metric value recorded with this checkpoint.
        is_best: If true, also write the same state to *best_filename*.
        filename: Path of the rolling checkpoint. It is overwritten on every
            call.
        best_filename: Path of the best-so-far checkpoint.
    """
    state = {'epoch': epoch,
             'epochs_since_improvement': epochs_since_improvement,
             'hmean': hmean,
             'model': model,
             'optimizer': optimizer}
    torch.save(state, filename)
    # Keep a separate copy of the best checkpoint so that later, worse
    # checkpoints written to `filename` do not overwrite it.
    if is_best:
        torch.save(state, best_filename)
def visualize(model, dataloader, out_dir='images'):
    """Run *model* on the first batch of *dataloader* and save sample images.

    For each sample ``i`` in that batch, three JPEGs are written to
    *out_dir*:

    - ``{i}_out.jpg``: the model output for ``data[0]``.
    - ``{i}_real.jpg``: ``data[1]``.
    - ``{i}_img.jpg``: ``data[0]``, the model input.

    Existing files with those names are overwritten. Does nothing if the
    dataloader yields no batches.

    Args:
        model: Callable applied to the input batch. It runs with autograd
            disabled; its train/eval mode is left unchanged.
        dataloader: Iterable of batches where ``data[0]`` and ``data[1]`` are
            tensors (presumably input/target pairs — confirm with the
            dataset).
        out_dir: Output directory, created if missing. Defaults to
            ``'images'``.
    """
    import os

    try:
        data = next(iter(dataloader))
    except StopIteration:
        return  # empty dataloader: nothing to visualize

    img_a, img_b = data[0].to(device), data[1].to(device)
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        fake_b = model(img_a)

    # PIL's save() fails if the target directory does not exist.
    os.makedirs(out_dir, exist_ok=True)
    for i, fake_img in enumerate(fake_b):
        save_img(fake_img.cpu().detach(), os.path.join(out_dir, '{0}_out.jpg'.format(i)))
        save_img(img_b[i].cpu().detach(), os.path.join(out_dir, '{0}_real.jpg'.format(i)))
        save_img(img_a[i].cpu().detach(), os.path.join(out_dir, '{0}_img.jpg'.format(i)))