-
Notifications
You must be signed in to change notification settings - Fork 139
/
Copy pathtraining_plots.py
68 lines (59 loc) · 2.28 KB
/
training_plots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
"""This module contains functions to plot image generated when training GAN."""
import matplotlib.pyplot as plt
import numpy as np
def upscale(image):
    """Scale an image from the [-1, 1] range to 8-bit [0, 255] pixel values.

    Each value ``v`` is mapped to ``v * 127.5 + 127.5``; the result is
    clipped to ``[0, 255]`` and then truncated to ``np.uint8``.

    Args:
        image: Array-like of numeric pixel values, nominally in ``[-1, 1]``.

    Returns:
        A ``np.ndarray`` of dtype ``uint8`` with the same shape as ``image``.
    """
    scaled = np.asarray(image) * 127.5 + 127.5
    # Clip before casting: values outside [-1, 1] would otherwise land
    # outside the uint8 range, where the float -> uint8 cast wraps around
    # (or is undefined) and produces speckle artefacts in the plot.
    return np.clip(scaled, 0, 255).astype(np.uint8)
def _show_image_grid(images):
    """Show up to nine images on a 3x3 grid of axis-less grayscale subplots."""
    for position, image in enumerate(images[:9], start=1):
        # subplot(3, 3, k) is the explicit form of the 331..339 shorthand.
        plt.subplot(3, 3, position)
        plt.axis('off')
        plt.imshow(upscale(np.squeeze(image)), cmap='gray')
    plt.show()


def generated_images_plot(original, noised_data, generator):
    """Plot 3x3 grids of noised, generated and original images during training.

    Three figures are shown in turn, each preceded by a printed heading:
    the first nine noised inputs, the generator's predictions for those
    inputs, and the first nine original images. If fewer than nine samples
    are supplied, only the available ones are drawn.

    Args:
        original: Indexable/sliceable collection of original images.
        noised_data: Sliceable batch of corrupted images fed to the generator.
        generator: Model exposing ``predict(batch, verbose=0)``.
    """
    batch = noised_data[:9]

    print('NOISED')
    _show_image_grid(batch)

    print('GENERATED')
    # Predict once for the whole batch rather than once per image; each
    # predict call carries noticeable fixed overhead. Skip the call entirely
    # for an empty batch, which predict cannot handle.
    generated = generator.predict(batch, verbose=0) if len(batch) else batch
    _show_image_grid(generated)

    print('ORIGINAL')
    _show_image_grid(original)
def plot_generated_images_combined(original, noised_data, generator):
    """Show originals, corrupted inputs and generator outputs as one mosaic.

    The first 48 samples (4 * 12) of ``original`` and ``noised_data`` are
    used, together with the generator's predictions for those noised
    samples. The reshapes require every sample to hold exactly 28 * 28
    values. The assembled mosaic is rescaled with ``upscale`` and displayed
    in a single 8x16 inch grayscale figure with the axes hidden.

    Args:
        original: Sliceable batch of original images.
        noised_data: Sliceable batch of corrupted images.
        generator: Model exposing ``predict(batch)``.
    """
    image_size = 28
    rows, cols = 4, 12
    num = rows * cols
    tile_shape = (image_size, image_size)

    predicted = generator.predict(noised_data[0:num])
    stacked = np.concatenate(
        [original[0:num], noised_data[0:num], predicted])

    # Lay the samples out as a (rows * 3, cols) grid of tiles, then regroup
    # the grid column-wise: split the columns into `rows` groups, stack the
    # groups vertically and fold back to rows * 3 tile rows.
    grid = stacked.reshape((rows * 3, cols) + tile_shape)
    column_groups = np.split(grid, rows, axis=1)
    grid = np.vstack(column_groups).reshape((rows * 3, -1) + tile_shape)

    # Join the tiles of each row side by side, then stack the row strips.
    strips = [np.hstack(tile_row) for tile_row in grid]
    mosaic = upscale(np.vstack(strips))

    plt.figure(figsize=(8, 16))
    plt.axis('off')
    plt.title('Original Images: top rows, '
              'Corrupted Input: middle rows, '
              'Generated Images: bottom rows')
    plt.imshow(mosaic, cmap='gray')
    plt.show()
def plot_training_loss(discriminator_losses, generator_losses):
    """Draw the discriminator and adversarial loss curves and show them.

    Both series are plotted against their own index (0, 1, 2, ...) on a
    new figure: the discriminator losses in red and the generator losses
    in blue, labelled 'Adversarial loss'.

    Args:
        discriminator_losses: Sized sequence of discriminator loss values.
        generator_losses: Sized sequence of generator loss values.
    """
    curves = (
        (discriminator_losses, 'red', 'Discriminator loss'),
        (generator_losses, 'blue', 'Adversarial loss'),
    )

    plt.figure()
    for losses, colour, caption in curves:
        steps = range(len(losses))
        plt.plot(steps, losses, color=colour, label=caption)

    plt.title('Discriminator and Adversarial loss')
    plt.xlabel('Iterations')
    plt.ylabel('Loss (Adversarial/Discriminator)')
    plt.legend()
    plt.show()