Skip to content

Commit

Permalink
Fix bug when pruning depthwise convolution layers. (PaddlePaddle#399)
Browse files Browse the repository at this point in the history
* fix bug when pruning the depthwise convolution layer

* fix bug when pruning the depthwise convolution layer

* fix bug when pruning the depthwise convolution layer

* fix pruner when pruning the depthwise convolution layer

* remove print from unit test
  • Loading branch information
yukavio authored Aug 5, 2020
1 parent d00373a commit 36b38fc
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 17 deletions.
18 changes: 3 additions & 15 deletions paddleslim/core/graph_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,7 @@ def outputs(self, name):
"""
Get all the variables by the output name.
"""
return [
self._graph.var(var_name) for var_name in self._op.output(name)
]
return [self._graph.var(var_name) for var_name in self._op.output(name)]

def set_attr(self, key, value):
"""
Expand Down Expand Up @@ -354,16 +352,6 @@ def numel_params(self):
ret += np.product(param.shape())
return ret

def update_param_shape(self, scope):
"""
Update the shape of parameters in the graph according to tensors in scope.
It is used after loading pruned parameters from file.
"""
for param in self.all_parameters():
tensor_shape = np.array(
scope.find_var(param.name()).get_tensor()).shape
param.set_shape(tensor_shape)

def infer_shape(self):
"""
Update the groups of convolution layer according to current filters.
Expand All @@ -375,6 +363,6 @@ def infer_shape(self):

def update_groups_of_conv(self):
for op in self.ops():
if op.type() == 'depthwise_conv2d' or op.type(
) == 'depthwise_conv2d_grad':
if 'conv2d' in op.type() and op.attr('groups') >= op.inputs(
'Filter')[0].shape()[0]:
op.set_attr('groups', op.inputs('Filter')[0].shape()[0])
3 changes: 1 addition & 2 deletions paddleslim/prune/group_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ def collect_convs(params, graph, visited={}):
walker = conv2d_walker(
conv_op, pruned_params=pruned_params, visited=visited)
walker.prune(param, pruned_axis=0, pruned_idx=[0])
if len(pruned_params) > 0:
groups.append(pruned_params)
groups.append(pruned_params)
visited = set()
uniq_groups = []
for group in groups:
Expand Down
2 changes: 2 additions & 0 deletions tests/test_group_param.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ def test_prune(self):
conv6 = conv_bn_layer(conv5, 8, 3, "conv6")
groups = collect_convs(
["conv1_weights", "conv2_weights", "conv3_weights"], main_program)
while [] in groups:
groups.remove([])
self.assertTrue(len(groups) == 2)
self.assertTrue(len(groups[0]) == 18)
self.assertTrue(len(groups[1]) == 6)
Expand Down

0 comments on commit 36b38fc

Please sign in to comment.