Fix pruning for yolov4 (PaddlePaddle#313)
wanghaoshuang authored Jun 5, 2020
1 parent 44e359c commit b81f27a
Showing 5 changed files with 46 additions and 35 deletions.
18 changes: 9 additions & 9 deletions paddleslim/prune/criterion.py
@@ -43,19 +43,19 @@ def l1_norm(group, graph):
         list: A list of tuple storing l1-norm on given axis.
     """
     scores = []
-    for name, value, axis in group:
+    for name, value, axis, pruned_idx in group:

         reduce_dims = [i for i in range(len(value.shape)) if i != axis]
         score = np.sum(np.abs(value), axis=tuple(reduce_dims))
-        scores.append((name, axis, score))
+        scores.append((name, axis, score, pruned_idx))

     return scores


 @CRITERION.register
 def geometry_median(group, graph):
     scores = []
-    name, value, axis = group[0]
+    name, value, axis, _ = group[0]
     assert (len(value.shape) == 4)

     def get_distance_sum(value, out_idx):
@@ -73,8 +73,8 @@ def get_distance_sum(value, out_idx):

     tmp = np.array(dist_sum_list)

-    for name, value, axis in group:
-        scores.append((name, axis, tmp))
+    for name, value, axis, idx in group:
+        scores.append((name, axis, tmp, idx))
     return scores


@@ -97,7 +97,7 @@ def bn_scale(group, graph):
     assert (isinstance(graph, GraphWrapper))

     # step1: Get first convolution
-    conv_weight, value, axis = group[0]
+    conv_weight, value, axis, _ = group[0]
     param_var = graph.var(conv_weight)
     conv_op = param_var.outputs()[0]

@@ -111,12 +111,12 @@ def bn_scale(group, graph):

     # steps3: Find scale of bn
     score = None
-    for name, value, aixs in group:
+    for name, value, aixs, _ in group:
         if bn_scale_param == name:
             score = np.abs(value.reshape([-1]))

     scores = []
-    for name, value, axis in group:
-        scores.append((name, axis, score))
+    for name, value, axis, idx in group:
+        scores.append((name, axis, score, idx))

     return scores
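The criterion functions now consume and emit 4-tuples: each group entry carries a pass-through pruned_idx (an offset list) in addition to the parameter name, tensor value, and axis. Below is a minimal, self-contained sketch of the new l1_norm contract; the parameter name and tensor are made up, and the unused GraphWrapper argument of the real criterion is dropped for brevity.

```python
import numpy as np


def l1_norm(group):
    # Same logic as the patched criterion, minus the unused graph argument:
    # sum |w| over every axis except the pruned one, and pass pruned_idx through.
    scores = []
    for name, value, axis, pruned_idx in group:
        reduce_dims = [i for i in range(len(value.shape)) if i != axis]
        score = np.sum(np.abs(value), axis=tuple(reduce_dims))
        scores.append((name, axis, score, pruned_idx))
    return scores


# Toy group: one conv weight with 4 output channels, pruned on axis 0, offset [0].
group = [("conv1.weights", np.random.rand(4, 3, 3, 3), 0, [0])]
for name, axis, score, offsets in l1_norm(group):
    print(name, axis, score.shape, offsets)  # -> conv1.weights 0 (4,) [0]
```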
6 changes: 3 additions & 3 deletions paddleslim/prune/group_param.py
@@ -57,21 +57,21 @@ def collect_convs(params, graph, visited={}):
         conv_op = param.outputs()[0]
         walker = conv2d_walker(
             conv_op, pruned_params=pruned_params, visited=visited)
-        walker.prune(param, pruned_axis=0, pruned_idx=[])
+        walker.prune(param, pruned_axis=0, pruned_idx=[0])
         groups.append(pruned_params)
     visited = set()
     uniq_groups = []
     for group in groups:
         repeat_group = False
         simple_group = []
-        for param, axis, _ in group:
+        for param, axis, pruned_idx in group:
             param = param.name()
             if axis == 0:
                 if param in visited:
                     repeat_group = True
                 else:
                     visited.add(param)
-            simple_group.append((param, axis))
+            simple_group.append((param, axis, pruned_idx))
         if not repeat_group:
             uniq_groups.append(simple_group)

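The practical effect of the group_param.py change is that collect_convs keeps the pruned_idx element instead of dropping it, so each group entry becomes a (param_name, axis, pruned_idx) triple. A purely illustrative sketch of the returned structure follows; the parameter names are invented, not taken from any real model.

```python
# Shape of one group as returned by the patched collect_convs (names invented).
example_group = [
    ("conv1.weights", 0, [0]),   # the conv whose output channels are pruned
    ("conv1_bn.scale", 0, [0]),  # parameters that must follow the same channels
    ("conv2.weights", 1, [0]),   # the next conv, pruned along its input axis
]

for param_name, axis, pruned_idx in example_group:
    # Before this commit the tuple was just (param_name, axis); the walker-seeded
    # pruned_idx (now initialized to [0]) is threaded through to the criterion
    # and the index selector.
    print(param_name, axis, pruned_idx)
```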
7 changes: 4 additions & 3 deletions paddleslim/prune/idx_selector.py
@@ -52,16 +52,17 @@ def default_idx_selector(group, ratio):
         list: pruned indexes
     """
-    name, axis, score = group[
+    name, axis, score, _ = group[
         0]  # sort channels by the first convolution's score
     sorted_idx = score.argsort()

     pruned_num = int(round(len(sorted_idx) * ratio))
     pruned_idx = sorted_idx[:pruned_num]

     idxs = []
-    for name, axis, score in group:
-        idxs.append((name, axis, pruned_idx))
+    for name, axis, score, offsets in group:
+        r_idx = [i + offsets[0] for i in pruned_idx]
+        idxs.append((name, axis, r_idx))
     return idxs

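The new offsets element is what handles concatenated layouts such as YOLOv4's route/concat blocks: indices chosen from the first convolution's scores are shifted by each parameter's channel offset before being recorded. A small, self-contained sketch of that arithmetic; the scores, ratio, and 64-channel offset are made up for illustration.

```python
import numpy as np

score = np.array([0.9, 0.1, 0.5, 0.3])  # per-channel scores from the criterion
ratio = 0.5

sorted_idx = score.argsort()                      # [1, 3, 2, 0]
pruned_num = int(round(len(sorted_idx) * ratio))  # 2
pruned_idx = sorted_idx[:pruned_num]              # channels 1 and 3 score lowest

# Hypothetical group: the second parameter sits 64 channels into a concat output.
group = [("conv.weights", 0, score, [0]),
         ("after_concat.weights", 1, score, [64])]

for name, axis, _, offsets in group:
    r_idx = [i + offsets[0] for i in pruned_idx]
    print(name, axis, r_idx)  # -> [1, 3] for the first entry, [65, 67] for the second
```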
33 changes: 19 additions & 14 deletions paddleslim/prune/prune_walker.py
@@ -77,9 +77,10 @@ def _prune_op(self, op, var, pruned_axis, pruned_idx, visited=None):
         if op.type() in SKIP_OPS:
             _logger.warn("Skip operator [{}]".format(op.type()))
             return
-        _logger.warn(
-            "{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.".
-            format(op.type()))
+
+        # _logger.warn(
+        #     "{} op will be pruned by default walker to keep the shapes of input and output being same because its walker is not registered.".
+        #     format(op.type()))
         cls = PRUNE_WORKER.get("default_walker")
         _logger.debug("\nfrom: {}\nto: {}\npruned_axis: {}; var: {}".format(
             self.op, op, pruned_axis, var.name()))
@@ -263,26 +264,30 @@ def _prune(self, var, pruned_axis, pruned_idx):
                 if name == "Y":
                     actual_axis = pruned_axis - axis
                 in_var = self.op.inputs(name)[0]
+                if len(in_var.shape()) == 1 and in_var.shape()[0] == 1:
+                    continue
                 pre_ops = in_var.inputs()
                 for op in pre_ops:
                     self._prune_op(op, in_var, actual_axis, pruned_idx)

         else:
             if var in self.op.inputs("X"):
                 in_var = self.op.inputs("Y")[0]

-                if in_var.is_parameter():
-                    self.pruned_params.append(
-                        (in_var, pruned_axis - axis, pruned_idx))
-                pre_ops = in_var.inputs()
-                for op in pre_ops:
-                    self._prune_op(op, in_var, pruned_axis - axis, pruned_idx)
+                if not (len(in_var.shape()) == 1 and in_var.shape()[0] == 1):
+                    if in_var.is_parameter():
+                        self.pruned_params.append(
+                            (in_var, pruned_axis - axis, pruned_idx))
+                    pre_ops = in_var.inputs()
+                    for op in pre_ops:
+                        self._prune_op(op, in_var, pruned_axis - axis,
+                                       pruned_idx)
             elif var in self.op.inputs("Y"):
                 in_var = self.op.inputs("X")[0]
-                pre_ops = in_var.inputs()
-                pruned_axis = pruned_axis + axis
-                for op in pre_ops:
-                    self._prune_op(op, in_var, pruned_axis, pruned_idx)
+                if not (len(in_var.shape()) == 1 and in_var.shape()[0] == 1):
+                    pre_ops = in_var.inputs()
+                    pruned_axis = pruned_axis + axis
+                    for op in pre_ops:
+                        self._prune_op(op, in_var, pruned_axis, pruned_idx)

         out_var = self.op.outputs("Out")[0]
         self._visit(out_var, pruned_axis)
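The prune_walker.py change guards the elementwise walker against single-element inputs: a shape-[1] tensor is a broadcast scalar with no channel dimension, so the walker should neither record it for pruning nor recurse into its producers. A sketch of just that condition; the helper name is invented, since in the patch the check is written inline.

```python
def has_prunable_channels(shape):
    # Mirrors the inline check added in the patch: skip 1-D, single-element tensors.
    return not (len(shape) == 1 and shape[0] == 1)


print(has_prunable_channels([1]))             # False: broadcast scalar, leave untouched
print(has_prunable_channels([32]))            # True: per-channel parameter (e.g. a bias)
print(has_prunable_channels([32, 16, 3, 3]))  # True: a conv weight
```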
17 changes: 11 additions & 6 deletions paddleslim/prune/pruner.py
@@ -90,43 +90,48 @@ def prune(self,
         visited = {}
         pruned_params = []
         for param, ratio in zip(params, ratios):
+            _logger.info("pruning: {}".format(param))
             if graph.var(param) is None:
                 _logger.warn(
                     "Variable[{}] to be pruned is not in current graph.".
                     format(param))
                 continue
-            group = collect_convs([param], graph, visited)[0]  # [(name, axis)]
+            group = collect_convs([param], graph,
+                                  visited)[0]  # [(name, axis, pruned_idx)]
             if group is None or len(group) == 0:
                 continue
             if only_graph and self.idx_selector.__name__ == "default_idx_selector":

                 param_v = graph.var(param)
                 pruned_num = int(round(param_v.shape()[0] * ratio))
                 pruned_idx = [0] * pruned_num
-                for name, axis in group:
+                for name, axis, _ in group:
                     pruned_params.append((name, axis, pruned_idx))

             else:
                 assert ((not self.pruned_weights),
                         "The weights have been pruned once.")
                 group_values = []
-                for name, axis in group:
+                for name, axis, pruned_idx in group:
                     values = np.array(scope.find_var(name).get_tensor())
-                    group_values.append((name, values, axis))
+                    group_values.append((name, values, axis, pruned_idx))

-                scores = self.criterion(group_values,
-                                        graph)  # [(name, axis, score)]
+                scores = self.criterion(
+                    group_values, graph)  # [(name, axis, score, pruned_idx)]

                 pruned_params.extend(self.idx_selector(scores, ratio))

         merge_pruned_params = {}
         for param, pruned_axis, pruned_idx in pruned_params:
+            print("{}\t{}\t{}".format(param, pruned_axis, len(pruned_idx)))
             if param not in merge_pruned_params:
                 merge_pruned_params[param] = {}
             if pruned_axis not in merge_pruned_params[param]:
                 merge_pruned_params[param][pruned_axis] = []
             merge_pruned_params[param][pruned_axis].append(pruned_idx)

+        print("param name: stage.0.conv_layer.conv.weights; idx: {}".format(
+            merge_pruned_params["stage.0.conv_layer.conv.weights"][1]))
         for param_name in merge_pruned_params:
             for pruned_axis in merge_pruned_params[param_name]:
                 pruned_idx = np.concatenate(merge_pruned_params[param_name][
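At the end of Pruner.prune (the hunk above is truncated by the diff view), indices gathered for the same parameter and axis across different groups are merged before the tensors are actually shrunk; the two print calls added in this hunk look like temporary debugging output for the YOLOv4 case rather than part of the intended API. Below is a hedged sketch of that merge step with invented parameter names and indices.

```python
import numpy as np

# Invented pruning results: the same parameter/axis can be reached via two groups.
pruned_params = [
    ("conv1.weights", 0, [1, 3]),
    ("conv1.weights", 0, [5]),
    ("conv2.weights", 1, [1, 3]),
]

# Gather index lists per (param, axis), then concatenate them, as the pruner does.
merge_pruned_params = {}
for param, pruned_axis, pruned_idx in pruned_params:
    merge_pruned_params.setdefault(param, {}).setdefault(pruned_axis,
                                                         []).append(pruned_idx)

for param_name, axes in merge_pruned_params.items():
    for pruned_axis, idx_lists in axes.items():
        merged = np.concatenate(idx_lists)
        print(param_name, pruned_axis, merged)  # e.g. conv1.weights 0 [1 3 5]
```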
