-
Notifications
You must be signed in to change notification settings - Fork 1
/
display_network.py
99 lines (76 loc) · 3.09 KB
/
display_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import PIL
# This function visualizes the filters stored in matrix A. Each column of A
# is a filter. Each column is reshaped into a square image and drawn in its
# own cell of the visualization panel.
# The options below are set as constants inside the function body (they are
# not call parameters); usually you do not need to worry about them.
# opt_normalize: whether to normalize each filter so that all of them
# have similar contrast. Set to True.
# opt_graycolor: whether to use gray as the heat map. Set to True.
# opt_colmajor: switches the convention to row major for A, in which case
# each row of A is a filter. Not implemented here: A is always read column-wise.
def display_network(A, filename='output/weights.png'):
    """Tile the columns of ``A`` as square grayscale patches and save the grid.

    Each column of ``A`` is treated as one filter, reshaped to ``sz x sz``
    and written into its own cell of a panel whose cells are separated by a
    one-pixel border. The panel is saved with ``plt.imsave`` using the gray
    colormap.

    :param A: 2-D array-like; the number of rows must be a perfect square
        (``sz * sz``) and every column is one filter.
    :param filename: destination passed straight to ``plt.imsave``.
    :return: None
    :raises ValueError: if ``A`` is not a non-empty 2-D array or its row
        count is not a perfect square.
    """
    # Fixed options (kept as named constants; they are not call parameters).
    opt_normalize = True   # scale every filter by its own max magnitude
    opt_graycolor = True   # keep the background at 1.0 instead of 0.1

    A = np.asarray(A, dtype=float)
    if A.ndim != 2 or A.size == 0:
        raise ValueError("A must be a non-empty 2-D array of column filters")

    # Rescale: remove the global mean so values are centred around zero.
    A = A - np.average(A)

    # Compute rows & cols. np.ceil returns floats, so everything that is
    # later used as a shape or an index is converted to int right here.
    row, col = A.shape
    sz = int(np.ceil(np.sqrt(row)))
    if sz * sz != row:
        raise ValueError(
            "each filter must have a perfect-square length, got %d" % row
        )
    buf = 1
    n = int(np.ceil(np.sqrt(col)))   # cells per panel row
    m = int(np.ceil(col / n))        # number of panel rows
    cell = sz + buf

    image = np.ones(shape=(buf + m * cell, buf + n * cell))
    if not opt_graycolor:
        image *= 0.1

    # Only needed when per-filter normalisation is off; hoisted out of the loop.
    global_clim = np.max(np.abs(A))

    for k in range(col):
        i, j = divmod(k, n)          # panel row / column of filter k
        patch = A[:, k].reshape(sz, sz)
        clim = np.max(np.abs(patch)) if opt_normalize else global_clim
        if clim > 0:
            # An all-zero filter is left as zeros instead of becoming NaN.
            patch = patch / clim
        top = buf + i * cell
        left = buf + j * cell
        image[top:top + sz, left:left + sz] = patch

    plt.imsave(filename, image, cmap=matplotlib.cm.gray)
def display_color_network(A, filename='output/weights.png'):
    """Display receptive field(s) or basis vector(s) for color image patches.

    ``A`` is the basis, with patches as column vectors. Each column is split
    into three equal consecutive parts that are written to image channels
    0, 1 and 2, and every part is reshaped to ``dim x dim``. Each channel is
    scaled by its own maximum magnitude, the patches are tiled into a grid
    separated by one-pixel gaps, and the result is saved as an 8-bit image.

    In case the midpoint is not set at 0 (all values are non-negative), the
    data is shifted by its mean first.

    :param A: 2-D array-like whose row count is ``3 * dim * dim``.
    :param filename: destination passed to ``Image.save``.
    :return: 0
    :raises ValueError: if ``A`` is not a non-empty 2-D array or its row
        count is not three times a perfect square.
    """
    # Imported locally so the PIL.Image submodule is guaranteed to be loaded;
    # a bare ``import PIL`` does not import it by itself.
    from PIL import Image

    A = np.asarray(A, dtype=float)
    if A.ndim != 2 or A.size == 0:
        raise ValueError("A must be a non-empty 2-D array of column patches")

    if np.min(A) >= 0:
        A = A - np.mean(A)

    num_patches = A.shape[1]
    # Integer arithmetic throughout: these values are used as slice bounds,
    # reshape dimensions and array shapes, none of which accept floats.
    channel_size, leftover = divmod(A.shape[0], 3)
    dim = int(round(np.sqrt(channel_size)))
    if leftover or dim * dim != channel_size:
        raise ValueError(
            "A must have 3 * dim * dim rows, got %d" % A.shape[0]
        )
    dimp = dim + 1  # patch size plus the one-pixel gap
    cols = int(np.round(np.sqrt(num_patches)))
    rows = int(np.ceil(num_patches / cols))

    # Split into the three channel planes and normalise each independently.
    channels = []
    for c in range(3):
        plane = A[c * channel_size:(c + 1) * channel_size, :]
        peak = np.max(np.abs(plane))
        # An all-zero channel stays zero instead of turning into NaN.
        channels.append(plane / peak if peak > 0 else plane)

    # Initialization of the image; unused cells and gaps keep the value 1.
    image = np.ones(shape=(dim * rows + rows - 1, dim * cols + cols - 1, 3))

    # Iterate over the patches that exist rather than over every grid cell,
    # so a partially filled last row does not index past the last column.
    for k in range(num_patches):
        i, j = divmod(k, cols)
        top = i * dimp
        left = j * dimp
        for c, plane in enumerate(channels):
            image[top:top + dim, left:left + dim, c] = \
                plane[:, k].reshape(dim, dim)

    # Map [-1, 1] to [0, 1], then to 8-bit.
    image = (image + 1) / 2
    # A (H, W, 3) uint8 array is interpreted as RGB without an explicit mode.
    Image.fromarray(np.uint8(image * 255)).save(filename)

    return 0