Skip to content

Commit

Permalink
surgery: transplant weights with an optional suffix
Browse files Browse the repository at this point in the history
  • Loading branch information
shelhamer committed May 20, 2016
1 parent c909969 commit 93f5e93
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions surgery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,21 @@
import caffe
import numpy as np

def transplant(new_net, net):
def transplant(new_net, net, suffix=''):
for p in net.params:
if p not in new_net.params:
p_new = p + suffix
if p_new not in new_net.params:
print 'dropping', p
continue
for i in range(len(net.params[p])):
if i > (len(new_net.params[p]) - 1):
if i > (len(new_net.params[p_new]) - 1):
print 'dropping', p, i
break
if net.params[p][i].data.shape != new_net.params[p][i].data.shape:
print 'coercing', p, i, 'from', net.params[p][i].data.shape, 'to', new_net.params[p][i].data.shape
if net.params[p][i].data.shape != new_net.params[p_new][i].data.shape:
print 'coercing', p, i, 'from', net.params[p][i].data.shape, 'to', new_net.params[p_new][i].data.shape
else:
print 'copying', p, i
new_net.params[p][i].data.flat = net.params[p][i].data.flat
print 'copying', p, ' -> ', p_new, i
new_net.params[p_new][i].data.flat = net.params[p][i].data.flat

def expand_score(new_net, new_layer, net, layer):
old_cl = net.params[layer][0].num
Expand Down

0 comments on commit 93f5e93

Please sign in to comment.