# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Helpers for detecting structural patterns (shortcut, multi-head attention,
feed-forward network) in a static program and collecting distillation nodes. """

import logging
import os
import sys
import numpy as np
import warnings
import paddle

from ..core import GraphWrapper
from .patterns_common import *

__all__ = ['find_final_nodes', 'get_patterns']


def find_final_nodes(program):
    """ Find the outputs of the final ops with weights in the program. """
    final_nodes = []
    graph = GraphWrapper(program)
    for op in sorted(graph.ops()):
        if op.type() in ALL_WEIGHT_OP and is_output_weight_ops(op, graph):
            n_op = has_bias(op, graph)
            if n_op is not None:
                # If a bias-add op follows the weight op, take its outputs instead.
                final_nodes.extend(n_op.all_outputs())
            else:
                if op.type() == 'batch_norm':
                    out_var = op.outputs('Y')
                else:
                    out_var = op.all_outputs()
                final_nodes.extend(out_var)
    return final_nodes
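
# A minimal usage sketch for ``find_final_nodes`` (hedged: ``infer_program`` is
# a hypothetical, already-loaded static ``Program``):
#
#     final_outs = find_final_nodes(infer_program)
#     final_names = [v.name() for v in final_outs]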


def _is_mha(pattern_ops, pattern_ops_type, skip_quant_tensor_list=None):
    """ Judge whether this pattern is a multi-head attention block. """
    # Avoid a mutable default argument; callers may pass a list that collects
    # the tensor names to skip during quantization.
    if skip_quant_tensor_list is None:
        skip_quant_tensor_list = []
    if pattern_ops_type.count('softmax') != 1 or pattern_ops_type.count(
            'fetch') > 0:
        return False

    matmul_num = 0
    for op in pattern_ops:
        if op.type() in ['matmul', 'matmul_v2']:
            if not is_dynamic_weight_op(op):
                # Activation-activation matmuls: record their inputs so that
                # quantization can skip them.
                skip_quant_tensor_list.extend(op._op.input('X'))
                matmul_num += 1
    # A multi-head attention block contains exactly two such matmuls.
    if matmul_num == 2:
        return True
    return False


def _is_ffn(pattern_ops, pattern_ops_type):
    """ Judge whether this pattern is a feed-forward network (FFN) block. """
    if pattern_ops_type.count('layer_norm') != 1:
        return False

    linear_num = 0
    act_num = 0
    for op in pattern_ops:
        if op.type() in ['mul', 'matmul', 'matmul_v2']:
            if is_dynamic_weight_op(op):
                linear_num += 1
        if op.type() in ['relu', 'gelu']:
            act_num += 1
    # An FFN block is two linear layers with a single activation in between.
    if linear_num == 2 and act_num == 1:
        return True
    return False


def get_patterns(program, only_final_node=True):
    """ Distinguish the patterns in the program and collect distillation nodes. """
    distill_node = []
    skip_quant_tensor_list = []
    patterns = {}
    graph = GraphWrapper(program)
    block_num = 0
    model_type = None
    for op in graph.ops():
        # Skip ops that belong to the teacher model (their input variable
        # names contain 'teacher').
        belonged_teacher = False
        for inp in op.all_inputs():
            if 'teacher' in inp._var.name:
                belonged_teacher = True
                break
        if belonged_teacher:
            continue

        if op.type() == 'elementwise_add':
            # An elementwise_add whose inputs are both activations
            # (non-persistable) may close a shortcut connection.
            inp1, inp2 = op.all_inputs()[0], op.all_inputs()[1]
            if (not inp1._var.persistable) and (not inp2._var.persistable):
                sc_path = []
                shortcut_start_op = []
                is_sc = is_shortcut(op, graph, sc_path, shortcut_start_op)
                if is_sc:
                    out_var_name = op.all_outputs()[0]._var.name

                    shortcut_start_op = shortcut_start_op[0]
                    pattern_ops, pattern_ops_type = traversal_ops(
                        shortcut_start_op, graph, op.idx())

                    pattern_name = shortcut_start_op.type() + '$' + str(
                        op.idx())

                    if _is_mha(pattern_ops, pattern_ops_type,
                               skip_quant_tensor_list):
                        model_type = 'transformer'
                        pattern_name = 'MHA$' + str(block_num)

                    if model_type == 'transformer' and _is_ffn(
                            pattern_ops, pattern_ops_type):
                        pattern_name = 'FFN$' + str(block_num)
                        # One transformer block is counted once its FFN is found.
                        block_num += 1
                        if not only_final_node:
                            distill_node.append('teacher_' + out_var_name)
                            distill_node.append(out_var_name)

                    if model_type == 'transformer' and (
                            'fetch' in pattern_ops_type or
                            pattern_ops_type[-1] == 'scale'):
                        # Patterns that reach a fetch op or end with a scale op
                        # are treated as the input-mask branch; keep its first op.
                        if 'input_mask' not in patterns:
                            patterns['input_mask'] = pattern_ops[0]._op

                    if 'fetch' in pattern_ops_type or pattern_ops_type[
                            -1] == 'scale':
                        # Do not record such branches as regular patterns.
                        continue

                    patterns[pattern_name] = pattern_ops

                    # For non-transformer models, every shortcut output can be
                    # a distillation node.
                    if model_type != 'transformer' and (not only_final_node):
                        distill_node.append('teacher_' + out_var_name)
                        distill_node.append(out_var_name)

    # Add the outputs of the final weighted ops to the distillation nodes.
    final_weight_node = find_final_nodes(program)
    for out_var in final_weight_node:
        distill_node.append('teacher_' + out_var.name())
        distill_node.append(out_var.name())

    # Mark ops that consume the recorded attention tensors so that
    # quantization passes skip the activation-activation matmuls in attention.
    if model_type == 'transformer':
        for block_id in range(len(program.blocks)):
            for op in program.blocks[block_id].ops:
                for inp_name in op.input_arg_names:
                    if inp_name in skip_quant_tensor_list:
                        op._set_attr("op_namescope", "skip_quant")

    return patterns, distill_node, model_type
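

if __name__ == '__main__':
    # A minimal, hedged usage sketch: it assumes a static-graph inference model
    # saved under the hypothetical path prefix './inference_model/model'. In
    # practice this module is imported from the PaddleSlim package rather than
    # run directly (the relative imports above require the package context).
    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())
    [program, feed_names, fetch_targets] = paddle.static.load_inference_model(
        './inference_model/model', exe)

    patterns, distill_node, model_type = get_patterns(
        program, only_final_node=True)
    print('model type:', model_type)            # e.g. 'transformer' or None
    print('patterns:', list(patterns.keys()))   # e.g. ['MHA$0', 'FFN$0', ...]
    # distill_node alternates teacher/student variable names used for the
    # distillation loss, e.g. ['teacher_x', 'x', ...].
    print('distillation nodes:', distill_node)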