Commit

limit prune ratio in transformer pruner (PaddlePaddle#1087)
ceci3 authored May 9, 2022
1 parent dede4a8, commit e19de50
Showing 2 changed files with 22 additions and 2 deletions.
paddleslim/auto_compression/config_helpers.py (3 additions, 0 deletions)
@@ -39,6 +39,9 @@ def load_config(config_path):
     else:
         train_config = None
 
+    if len(compress_config) == 0:
+        compress_config = None
+
     return compress_config, train_config


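For context on this first change: an empty compression config now comes back from load_config as None instead of an empty container, so callers can test for it directly. A minimal usage sketch, not part of the commit; the YAML path is hypothetical and the import assumes the module shown above:

    # Illustrative only, not part of the commit.
    from paddleslim.auto_compression.config_helpers import load_config

    # Hypothetical config path; the file follows PaddleSlim's auto-compression YAML layout.
    compress_config, train_config = load_config("./configs/strategy.yaml")

    if compress_config is None:
        # With this commit an empty compression section is returned as None
        # rather than an empty dict, so the check can be explicit.
        print("no compression strategy configured")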
paddleslim/auto_compression/transformer_pruner.py (19 additions, 2 deletions)
@@ -19,6 +19,7 @@
 from ..common import get_logger
 from ..common.recover_program import recover_inference_program
 from ..common.transformer_pattern import preprocess_transformer_patterns
+from ..common.patterns_common import is_dynamic_weight_op

 _logger = get_logger(__name__, level=logging.INFO)

@@ -228,13 +229,21 @@ def __init__(self, exe, places, inference_program, patterns, label_info,
         self.graph = GraphWrapper(inference_program)
         self.patterns = patterns
         self.label_info = label_info
-        self.width_mult = width_mult
         self.fetch_targets = fetch_targets
         self.dataloader = dataloader

         self.scope = paddle.static.global_scope()
         input_mask_op, layer_num, head_num, mha_weight, ffn_weight = self._preprocess_patterns(
             patterns, self.graph)
+
+        ### the prune ratio * head_num need to be an integer.
+        pruned_head = round(width_mult * head_num)
+        self.width_mult = float(pruned_head) / head_num
+        if self.width_mult != width_mult:
+            _logger.info(
+                "the prune ratio * head_num need to be an integer. so change prune ratio from {} to {}".
+                format(str(1.0 - width_mult), str(1.0 - self.width_mult)))
+
         self.input_mask_op = input_mask_op
         self.mha_weight = mha_weight
         self.ffn_weight = ffn_weight
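The hunk above is the core of the commit: width_mult is snapped to the nearest value for which width_mult * head_num is an integer, and the adjusted prune ratio is logged. A standalone sketch of that arithmetic, not part of the commit, with made-up values:

    # Illustrative numbers only.
    head_num = 12        # attention heads per layer
    width_mult = 0.7     # requested keep ratio

    pruned_head = round(width_mult * head_num)            # round(8.4) == 8 heads kept
    adjusted_width_mult = float(pruned_head) / head_num   # 8 / 12 == 0.666...

    if adjusted_width_mult != width_mult:
        # e.g. the prune ratio goes from 0.3 to roughly 0.333
        print("prune ratio changed from {} to {}".format(
            1.0 - width_mult, 1.0 - adjusted_width_mult))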
@@ -247,7 +256,15 @@ def _preprocess_patterns(self, patterns, graph):
         """ Preprocess pattern of the program, get some info need by reorder"""
         input_mask_op = patterns['input_mask']
         layer_num = int((len(patterns) - 1) / 2)
-        head_num = len(input_mask_op.input_arg_names)
+
+        ### get real head number
+        head_num = -1
+        tmp_mha_ops = patterns['MHA$0']
+        for op in tmp_mha_ops:
+            if op.type() in ['matmul', 'matmul_v2'] and (
+                    not is_dynamic_weight_op(op)) and head_num == -1:
+                inp_var = op.inputs("X")
+                head_num = inp_var[0].shape()[1]

         mha_weight, ffn_weight = preprocess_transformer_patterns(patterns,
                                                                   graph)
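The new head-number detection relies on the usual multi-head attention layout: the activations feeding the non-weight matmul (the QK^T product) are shaped [batch, num_heads, seq_len, head_dim], so axis 1 of the matmul's X input gives the head count. A NumPy sketch of that shape convention, with dimensions that are illustrative and not taken from the commit:

    # Illustrative shapes only.
    import numpy as np

    batch, num_heads, seq_len, head_dim = 4, 12, 128, 64
    q = np.zeros((batch, num_heads, seq_len, head_dim))   # query tensor fed to QK^T

    detected_head_num = q.shape[1]   # the axis the pruner reads via inp_var[0].shape()[1]
    assert detected_head_num == num_heads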
