Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
siahuat0727 committed May 21, 2019
1 parent b0f46f4 commit 88a976f
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,21 +182,22 @@ def reinitialize(mask, drop_filters, conv_weights, fc_weights, zero_init):
if W.dim() == 4 and drop_filters[name] is not None: # conv weights
# find null space
size = W.size()
stdv = 1. / math.sqrt(size[1]*size[2]*size[3]) # https://github.com/pytorch/pytorch/blob/08891b0a4e08e2c642deac2042a02238a4d34c67/torch/nn/modules/conv.py#L40-L47
W2d = W.view(size[0], -1).cpu().numpy()
null_space = qr_null(np.vstack((drop_filters[name], W2d)))
null_space = torch.from_numpy(null_space).cuda()
null_space = null_space.transpose(0, 1).view(-1, size[1], size[2], size[3])

# https://github.com/pytorch/pytorch/blob/08891b0a4e08e2c642deac2042a02238a4d34c67/torch/nn/modules/conv.py#L40-L47
stdv = 1. / math.sqrt(size[1]*size[2]*size[3])

null_count = 0
for mask_idx in mask[name]:
if null_count < null_space.size(0):
W.data[mask_idx] = null_space.data[null_count].clamp_(-stdv, stdv)
null_count += 1
else:
W.data[mask_idx].uniform_(-stdv, stdv)
if null_space.size == 0:
W.data[mask[name]].uniform_(-stdv, stdv)
else:
null_space = null_space.transpose(0, 1).view(-1, size[1], size[2], size[3])
null_count = 0
for mask_idx in mask[name]:
if null_count < null_space.size(0):
W.data[mask_idx] = null_space.data[null_count].clamp_(-stdv, stdv)
null_count += 1
else:
W.data[mask_idx].uniform_(-stdv, stdv)

# mask channels of prev-layer-pruned-filters' outputs
if prev_layer_name is not None:
Expand Down

0 comments on commit 88a976f

Please sign in to comment.