fix act recover params and pattern recognition (PaddlePaddle#1695)
* fix

* update
ceci3 authored Mar 27, 2023
1 parent b735a39 commit 0333133
Showing 2 changed files with 13 additions and 9 deletions.
2 changes: 2 additions & 0 deletions paddleslim/common/recover_program.py
@@ -71,6 +71,8 @@ def _recover_param_attr(program, startup_program):
             if param.persistable is True and param.name != 'feed' and param.name != 'fetch']
     with paddle.static.program_guard(program, startup_program):
         for w in all_weights:
+            if w.dtype not in [paddle.float32]:
+                continue
             new_w = paddle.create_parameter(
                 shape=w.shape, dtype=w.dtype, name=w.name)
             new_w.set_value(w.get_value())
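
For context, a minimal runnable sketch of what the new dtype guard prevents, assuming Paddle 2.x static mode (the program and variable names below are hypothetical, not part of this commit): _recover_param_attr re-creates persistable variables as parameters, which only makes sense for float32 weights, so non-float buffers such as int64 counters are now skipped.

import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()

with paddle.static.program_guard(main_prog, startup_prog):
    # A float32 weight: safe to re-create via paddle.create_parameter.
    w = paddle.create_parameter(shape=[4, 4], dtype='float32', name='fc_w')
    # An int64 persistable buffer (e.g. a global step): must be skipped,
    # since re-creating it as a trainable parameter would be wrong.
    step = paddle.static.create_global_var(
        shape=[1], value=0, dtype='int64', persistable=True, name='step')

# The same filtering as the patched loop in _recover_param_attr.
for var in main_prog.list_vars():
    if not var.persistable or var.name in ('feed', 'fetch'):
        continue
    if var.dtype not in [paddle.float32]:
        print('skip', var.name)      # 'step' hits the new guard
        continue
    print('recover', var.name)       # only 'fc_w' would be re-created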
20 changes: 11 additions & 9 deletions paddleslim/common/transformer_pattern.py
@@ -25,13 +25,15 @@ def _find_gemm_op(op, graph):
     return op


-def _append_transformer_prune_params(op, graph, block_num, params_dict):
-    for next_op in graph.next_ops(op):
+def _append_transformer_prune_params(op_lists, graph, block_num, params_dict):
+    first_op = op_lists[0]
+    for next_op in graph.next_ops(first_op):
         if next_op.type() == 'elementwise_add':
             continue
         next_op = _find_gemm_op(next_op, graph)
-        if next_op.type() in ['mul', 'matmul', 'matmul_v2'
-                              ] and has_trainable_var(next_op):
+        if next_op.type() in [
+                'mul', 'matmul', 'matmul_v2'
+        ] and has_trainable_var(next_op) and next_op in op_lists:
             if block_num not in params_dict:
                 params_dict[block_num] = {}
             params_dict[block_num]['P1'] = [get_weight(next_op)]
@@ -41,7 +43,7 @@ def _append_transformer_prune_params(op, graph, block_num, params_dict):
                 get_weight(has_bias(next_op, graph)))
             op = next_op
             next_op = _find_gemm_op(find_weight_op(op, graph), graph)
-            if next_op:
+            if next_op and next_op in op_lists:
                 params_dict[block_num]['P2'] = [get_weight(next_op)]
                 params_dict[block_num]['P2'].append(
                     get_weight(has_bias(next_op, graph)))
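
The new next_op in op_lists checks are the pattern-recognition fix: traversal via graph.next_ops and find_weight_op can walk past the boundary of the detected pattern and pick up a gemm op that belongs to a neighboring block. A framework-free sketch of the guard, with a hypothetical SimpleOp standing in for PaddleSlim's GraphWrapper ops:

class SimpleOp:
    """Hypothetical stand-in for a GraphWrapper op, not the PaddleSlim API."""

    def __init__(self, name, op_type, succ=None):
        self.name = name
        self._type = op_type
        self.succ = succ or []   # successor ops, like graph.next_ops(self)

    def type(self):
        return self._type

# Two adjacent transformer blocks: a walk that starts in block 0 can keep
# following successors into block 1.
qkv1 = SimpleOp('block1_qkv', 'matmul_v2')
qkv0 = SimpleOp('block0_qkv', 'matmul_v2', succ=[qkv1])
ln0 = SimpleOp('block0_ln', 'layer_norm', succ=[qkv0])

op_lists = [ln0, qkv0]   # ops detected for block 0's pattern only

def matched_gemms(first_op):
    for nxt in first_op.succ:
        # Same shape as the patched condition: op type plus membership.
        if nxt.type() in ['mul', 'matmul', 'matmul_v2'] and nxt in op_lists:
            yield nxt

print([op.name for op in matched_gemms(ln0)])   # ['block0_qkv']
# Without the membership check, a deeper walk (qkv0 -> qkv1) could match
# block1_qkv and file block 1's weights under block 0.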
@@ -57,14 +59,14 @@ def preprocess_transformer_patterns(patterns, graph):
             continue
         block_num = int(pattern_name.split('$')[-1])
         if 'MHA' in pattern_name:
-            mha_weight = _append_transformer_prune_params(pattern_ops[0], graph,
-                                                          block_num, mha_weight)
+            mha_weight = _append_transformer_prune_params(
+                pattern_ops, graph, block_num, mha_weight)
             mha_weight[block_num]['reshape_op'] = []
             for op in pattern_ops:
                 if op.type() in ['reshape', 'reshape2']:
                     mha_weight[block_num]['reshape_op'].append(op)
         elif 'FFN' in pattern_name:
-            ffn_weight = _append_transformer_prune_params(pattern_ops[0], graph,
-                                                          block_num, ffn_weight)
+            ffn_weight = _append_transformer_prune_params(
+                pattern_ops, graph, block_num, ffn_weight)
 
     return mha_weight, ffn_weight
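
For reference, a sketch of the shape of the dictionaries this now returns, keyed by block number; the key names P1, P2 and reshape_op come from the code above, while the weight names are placeholders rather than real output of this commit:

# P1 holds the first gemm's weight (plus bias when has_bias finds one),
# P2 the following gemm's; MHA entries also record reshape/reshape2 ops.
mha_weight = {
    0: {
        'P1': ['qkv_weight', 'qkv_bias'],            # placeholder names
        'P2': ['out_proj_weight', 'out_proj_bias'],
        'reshape_op': [],                            # reshape/reshape2 ops
    },
}
ffn_weight = {
    0: {
        'P1': ['fc1_weight', 'fc1_bias'],
        'P2': ['fc2_weight', 'fc2_bias'],
    },
}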
