Feat: add init option for the related weights
When re-initializing pruned filters, add an option to initialize the related
weights of the following layer either randomly or with all zeros.
siahuat0727 committed Apr 27, 2019
1 parent 8d94fc8 commit 4e1564e
Showing 1 changed file with 13 additions and 4 deletions.
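
To make the two options concrete, here is a minimal self-contained sketch of the idea (shapes and names are hypothetical, and the conv-branch stdv, which the hunks below take from code outside this diff, is assumed to follow the usual 1/sqrt(fan_in) convention; this is not the exact main.py code):

import math
import torch

# A pruned conv layer with 16 filters feeds a conv layer whose weight W
# has shape (out_channels=32, in_channels=16, kH=3, kW=3).
W = torch.nn.Conv2d(16, 32, 3).weight
pruned = torch.zeros(16, dtype=torch.bool)
pruned[[2, 5, 11]] = True  # input channels fed by the pruned filters

zero_init = False  # mirrors the new --zero_init flag
with torch.no_grad():
    if zero_init:
        W[:, pruned] = 0  # silence the dead inputs entirely
    else:
        # redraw from U(-stdv, stdv); assign back explicitly, since
        # advanced indexing returns a copy rather than a view
        stdv = 1. / math.sqrt(W.size(1) * W.size(2) * W.size(3))
        W[:, pruned] = torch.empty_like(W[:, pruned]).uniform_(-stdv, stdv)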
main.py (17 changes: 13 additions & 4 deletions)
@@ -31,6 +31,7 @@
 parser.add_argument('--save_model', type=str, default='best.pt', help="path to save model")
 parser.add_argument('--prune_ratio', type=float, default=0.3, help="prune ratio")
 parser.add_argument('--comment', type=str, default='', help="tag for tensorboardX event name")
+parser.add_argument('--zero_init', action='store_true', help="whether to initialize with zero")
 
 def train(train_loader, criterion, optimizer, epoch, model, writer, mask, args, conv_weights):
     batch_time = AverageMeter()
@@ -172,7 +173,7 @@ def pruning(conv_weights, prune_ratio):
     test_filter_sparsity(conv_weights)
     return prune, mask, drop_filters
 
-def reinitialize(mask, drop_filters, conv_weights, fc_weights):
+def reinitialize(mask, drop_filters, conv_weights, fc_weights, zero_init):
     print('Reinitializing...')
     with torch.no_grad():
         prev_layer_name = None
@@ -201,9 +202,16 @@ def reinitialize(mask, drop_filters, conv_weights, fc_weights):
             # mask channels of prev-layer-pruned-filters' outputs
             if prev_layer_name is not None:
                 if W.dim() == 4:  # conv
-                    W.data[:, mask[prev_layer_name]] = 0
+                    if zero_init:
+                        W.data[:, mask[prev_layer_name]] = 0
+                    else:
+                        W.data[:, mask[prev_layer_name]].uniform_(-stdv, stdv)
                 elif W.dim() == 2:  # fc
-                    W.view(W.size(0), prev_num_filters, -1).data[:, mask[prev_layer_name]] = 0
+                    if zero_init:
+                        W.view(W.size(0), prev_num_filters, -1).data[:, mask[prev_layer_name]] = 0
+                    else:
+                        stdv = 1. / math.sqrt(W.size(1))
+                        W.view(W.size(0), prev_num_filters, -1).data[:, mask[prev_layer_name]].uniform_(-stdv, stdv)
             prev_layer_name, prev_num_filters = name, W.size(0)
     test_filter_sparsity(conv_weights)
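
For reference, the stdv used in the fc branch mirrors PyTorch's default nn.Linear initialization, which draws weights from U(-stdv, stdv) with stdv = 1/sqrt(fan_in); randomly re-initialized entries therefore stay on the same scale as the untouched ones. A quick sanity check (layer sizes are hypothetical):

import math
import torch

fc = torch.nn.Linear(512, 10)             # hypothetical sizes
stdv = 1. / math.sqrt(fc.weight.size(1))  # 1/sqrt(fan_in)
# Default-initialized weights already lie within [-stdv, stdv].
assert fc.weight.abs().max().item() <= stdv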

@@ -277,7 +285,7 @@ def main():
         prune_map.append(np.concatenate(list(prune.values())))
         # check if the end of S2 stage
         if any(epoch == s for s in range(args.S1+args.S2, args.epochs, args.S1+args.S2)):
-            reinitialize(mask, drop_filters, conv_weights, fc_weights)
+            reinitialize(mask, drop_filters, conv_weights, fc_weights, args.zero_init)
         train(trainloader, criterion, optimizer, epoch, model, writer, mask, args, conv_weights)
         acc = validate(testloader, criterion, model, writer, args, epoch, best_acc)
         best_acc = max(best_acc, acc)
@@ -288,6 +296,7 @@
 
     # Shows which filters turn off as training progresses
     prune_map = np.array(prune_map).transpose()
+    print(prune_map)
     plt.matshow(prune_map.astype(np.int), cmap=ListedColormap(['k', 'w']))
     plt.xticks(np.arange(prune_map.shape[1]))
     plt.yticks(np.arange(prune_map.shape[0]))
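
With this change the behavior is selected on the command line; for example (the --prune_ratio and --zero_init flags appear in this diff, the values are illustrative):

# default: re-initialize the related weights randomly
python main.py --prune_ratio 0.3

# new: zero out the related weights instead
python main.py --prune_ratio 0.3 --zero_init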