forked from shelhamer/fcn.berkeleyvision.org
-
Notifications
You must be signed in to change notification settings - Fork 0
/
surgery.py
46 lines (42 loc) · 1.65 KB
/
surgery.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from __future__ import division
import caffe
import numpy as np
def transplant(new_net, net, suffix=''):
    """Copy parameter blobs from ``net`` into ``new_net``, layer by layer.

    Every layer name ``p`` in ``net.params`` is looked up as ``p + suffix``
    in ``new_net.params``.  One line per decision is written to stdout:

    * ``dropping <layer>``: ``new_net`` has no layer of the matching name.
    * ``dropping <layer> <i>``: the matching layer has no blob at index
      ``i``; the remaining source blobs of that layer are skipped too.
    * ``coercing ...``: the two blobs differ in shape; the data is still
      copied through the arrays' flat iterators.
    * ``copying ...``: the two blobs have the same shape.

    Args:
        new_net: destination net; ``new_net.params[name][i].data`` is
            overwritten in place.
        net: source net; only read.
        suffix: string appended to each source layer name to form the
            destination layer name.
    """
    for name in net.params:
        new_name = name + suffix
        if new_name not in new_net.params:
            print('dropping {}'.format(name))
            continue

        src_blobs = net.params[name]
        dst_blobs = new_net.params[new_name]
        n_dst = len(dst_blobs)
        for i in range(len(src_blobs)):
            if i >= n_dst:
                # The destination layer has fewer blobs than the source.
                print('dropping {} {}'.format(name, i))
                break
            src = src_blobs[i].data
            dst = dst_blobs[i].data
            if src.shape != dst.shape:
                print('coercing {} {} from {} to {}'.format(
                    name, i, src.shape, dst.shape))
            else:
                print('copying {}  ->  {} {}'.format(name, new_name, i))
            # Flat assignment ignores shape, which is what makes the
            # "coercing" case possible; element counts are not checked here.
            dst.flat = src.flat
def expand_score(new_net, new_layer, net, layer):
    """Copy an existing score layer into the leading slots of a larger one.

    The leading-axis count of the source weight blob (its ``num``) decides
    how many leading entries of the destination weights and bias are
    overwritten; the remaining destination entries are left untouched.

    Args:
        new_net: destination net, modified in place.
        new_layer: name of the destination layer in ``new_net.params``.
        net: source net; only read.
        layer: name of the source layer in ``net.params``.
    """
    src_blobs = net.params[layer]
    dst_blobs = new_net.params[new_layer]
    n_old = src_blobs[0].num

    # Weights: the first n_old entries along the leading axis.
    dst_weights = dst_blobs[0].data[:n_old]
    dst_weights[...] = src_blobs[0].data

    # Bias: indexed with four subscripts, so the destination bias blob is
    # expected to have four axes with the entries along the last one.
    dst_bias = dst_blobs[1].data[0, 0, 0, :n_old]
    dst_bias[...] = src_blobs[1].data
def upsample_filt(size):
    """Return a ``(size, size)`` bilinear interpolation kernel.

    The kernel is the product of two identical 1-D triangular profiles, one
    along the rows and one along the columns, each peaking at the centre of
    the window.

    Args:
        size: side length of the square kernel.

    Returns:
        A 2-D float array of shape ``(size, size)``.
    """
    factor = (size + 1) // 2
    # Odd sizes have a true centre pixel; even sizes centre between two.
    center = factor - 1 if size % 2 == 1 else factor - 0.5
    rows, cols = np.ogrid[:size, :size]
    row_weights = 1 - abs(rows - center) / factor
    col_weights = 1 - abs(cols - center) / factor
    # (size, 1) * (1, size) broadcasts to the full 2-D kernel.
    return row_weights * col_weights
def interp(net, layers):
    """Overwrite the weights of each listed layer with a bilinear kernel.

    For every layer name in ``layers`` the first parameter blob must have a
    4-axis shape ``(m, k, h, w)``.  The kernel from ``upsample_filt(h)`` is
    written to the ``(i, i)`` channel pairs when ``m == k``, or to the
    ``(i, 0)`` pairs when ``k == 1``; other entries are left unchanged.

    Layers are checked and filled one at a time, so layers listed before an
    invalid one have already been modified when the error is raised.

    Args:
        net: net whose ``net.params[name][0].data`` arrays are modified in
            place.
        layers: iterable of layer names to fill.

    Raises:
        ValueError: if ``m != k`` and ``k != 1``, or if the filter is not
            square (``h != w``).
    """
    for name in layers:
        weights = net.params[name][0].data
        m, k, h, w = weights.shape
        if m != k and k != 1:
            raise ValueError(
                'input + output channels need to be the same or |output| == 1')
        if h != w:
            raise ValueError('filters need to be square')
        filt = upsample_filt(h)
        # Paired integer indices select one (h, w) slice per channel pair;
        # a single-element second index broadcasts against the first.
        weights[range(m), range(k), :, :] = filt