forked from PaddlePaddle/Paddle
-
Notifications
You must be signed in to change notification settings - Fork 0
/
check_api_compatible.py
196 lines (177 loc) · 6.42 KB
/
check_api_compatible.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import inspect
import logging
import re
import sys
# Configure the root logger (getLogger() with no name) used by this script.
logger = logging.getLogger()
if not logger.handlers:
    # No handler installed yet: attach one that writes to stderr.
    console = logging.StreamHandler(sys.stderr)
    logger.addHandler(console)
else:
    # A handler already exists; assume the first one is the one to configure.
    console = logger.handlers[0]
console.setFormatter(
    logging.Formatter(
        "%(asctime)s - %(funcName)s:%(lineno)d - %(levelname)s - %(message)s"
    )
)
def _check_compatible(args_o, args_n, defaults_o, defaults_n):
# 如果参数减少了,需要提醒关注
if len(args_o) > len(args_n):
logger.debug("args num less then previous: %s vs %s", args_o, args_n)
return False
# 参数改名了,也要提醒关注
for idx in range(min(len(args_o), len(args_n))):
if args_o[idx] != args_n[idx]:
logger.debug(
"args's %d parameter diff with previous: %s vs %s",
idx,
args_o,
args_n,
)
return False
# 新增加了参数,必须提供默认值。以及不能减少默认值数量
if (len(args_n) - len(defaults_n)) > (len(args_o) - len(defaults_o)):
logger.debug(
"defaults num less then previous: %s vs %s", defaults_o, defaults_n
)
return False
# 默认值必须相等
for idx in range(min(len(defaults_o), len(defaults_n))):
nidx_o = -1 - idx
nidx_n = -1 - idx - (len(args_n) - len(args_o))
if defaults_o[nidx_o] != defaults_n[nidx_n]:
logger.debug(
"defaults's %d value diff with previous: %s vs %s",
nidx_n,
defaults_o,
defaults_n,
)
return False
return True
def check_compatible(old_api_spec, new_api_spec):
    """
    Check whether ``new_api_spec`` is backward compatible with ``old_api_spec``.

    Args:
        old_api_spec: ``inspect.FullArgSpec`` of the previous version.
        new_api_spec: ``inspect.FullArgSpec`` of the new version.

    Returns:
        bool: result of ``_check_compatible`` on ``args`` and ``defaults``;
        False (with a warning) if either argument is not a FullArgSpec.
        ``varargs``, ``varkw`` and keyword-only parameters are not compared.
    """
    specs = (old_api_spec, new_api_spec)
    if not all(isinstance(spec, inspect.FullArgSpec) for spec in specs):
        logger.warning(
            "new_api_spec or old_api_spec is not instance of inspect.FullArgSpec"
        )
        return False

    def _defaults_of(spec):
        # FullArgSpec.defaults is None when no parameter has a default.
        return [] if spec.defaults is None else spec.defaults

    return _check_compatible(
        old_api_spec.args,
        new_api_spec.args,
        _defaults_of(old_api_spec),
        _defaults_of(new_api_spec),
    )
def _split_defaults_text(defaults_text):
    """
    Split the inside of a ``defaults=(...)`` repr into per-default strings.

    Args:
        defaults_text: text between the parentheses of ``defaults=(...)``,
            or None when the spec reads ``defaults=None``.

    Returns:
        list[str]: one repr string per default value; ``[]`` when there are
        no defaults.

    NOTE(review): splitting on ", " still mis-splits a default whose repr
    itself contains ", " (e.g. a tuple or string default).
    """
    if defaults_text is None:
        # ``defaults=None`` means no parameter has a default; it must not be
        # treated as one default whose value is the text "None".
        return []
    if defaults_text.endswith(','):
        # A one-element tuple reprs as "(x,)"; drop the trailing comma so the
        # value compares equal to the same value inside a longer tuple.
        defaults_text = defaults_text[:-1]
    return defaults_text.split(', ') if defaults_text else []


def check_compatible_str(old_api_spec_str, new_api_spec_str):
    """
    Check compatibility of two APIs given the text form of their FullArgSpec.

    This is the fallback for specs that could not be evaluated into
    ``inspect.FullArgSpec`` objects. Positional names come from ``args=[...]``.
    Default values are compared as their repr text from ``defaults=(...)``.

    Args:
        old_api_spec_str: text containing ``args=..., varargs=...,
            defaults=..., kwonlyargs=...`` for the previous version.
        new_api_spec_str: the same for the new version.

    Returns:
        bool: True if the new signature is compatible. False otherwise,
        including when either string cannot be parsed (a warning is logged).
    """
    patArgSpec = re.compile(
        r'args=(.*), varargs=.*defaults=(None|\((.*)\)), kwonlyargs=.*'
    )
    mo_o = patArgSpec.search(old_api_spec_str)
    mo_n = patArgSpec.search(new_api_spec_str)
    if not (mo_o and mo_n):
        # Unparseable spec text: report both and treat as incompatible.
        logger.warning("old_api_spec_str: %s", old_api_spec_str)
        logger.warning("new_api_spec_str: %s", new_api_spec_str)
        return False
    try:
        # The args group is a list of str literals; literal_eval parses it
        # without executing arbitrary code taken from the spec text.
        args_o = ast.literal_eval(mo_o.group(1))
        args_n = ast.literal_eval(mo_n.group(1))
    except (ValueError, SyntaxError):
        logger.warning("old_api_spec_str: %s", old_api_spec_str)
        logger.warning("new_api_spec_str: %s", new_api_spec_str)
        return False
    # Group 3 is the text inside "(...)", or None for "defaults=None".
    defaults_o = _split_defaults_text(mo_o.group(3))
    defaults_n = _split_defaults_text(mo_n.group(3))
    return _check_compatible(args_o, args_n, defaults_o, defaults_n)
def read_argspec_from_file(specfile):
    """
    Read API argspecs from an open spec file.

    Lines of interest look like
    ``paddle.xxx (ArgSpec(...), ('document', '<32 hex chars>'))``. Lines that
    do not match, and entries whose spec is the empty ``ArgSpec()``, are
    skipped.

    Args:
        specfile: an iterable of text lines, e.g. an open text file.

    Returns:
        dict: API name -> ``inspect.FullArgSpec`` when the ArgSpec text
        evaluates cleanly. Otherwise the value is the
        ``"inspect.FullArgSpec(...)"`` source string itself, for the
        string-based comparison.
    """
    res_dict = {}
    patArgSpec = re.compile(
        r'^(paddle[^,]+)\s+\((ArgSpec.*),\s\(\'document\W*([0-9a-z]{32})'
    )
    # "inspect.Full" + "ArgSpec(...)" -> "inspect.FullArgSpec(...)".
    fullargspec_prefix = 'inspect.Full'
    for line in specfile:
        mo = patArgSpec.search(line)
        if not mo or mo.group(2) == 'ArgSpec()':
            continue
        api_name, argspec_text = mo.group(1), mo.group(2)
        logger.debug("%s argspec: %s", api_name, argspec_text)
        spec_src = fullargspec_prefix + argspec_text
        try:
            # NOTE(review): eval executes text read from the spec file; that
            # file is assumed to be generated by trusted tooling — confirm.
            res_dict[api_name] = eval(spec_src)
        except Exception:
            # Default reprs such as "<function ...>" are not valid Python
            # (SyntaxError, NameError, ...). Keep the text for the
            # string-based check. Interrupts and SystemExit still propagate.
            res_dict[api_name] = spec_src
    return res_dict
# Extra optional flags, each given as (flags, dest, type, default, help).
arguments = [
    # flags, dest, type, default, help
]


def parse_args():
    """
    Build the command-line parser and parse ``sys.argv``.

    Positional arguments ``prev`` and ``post`` are opened for reading by
    ``argparse.FileType('r')``. When no command-line argument is given at
    all, the help text is printed and the process exits with status 1.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description='check api compatible across versions'
    )
    parser.add_argument('--debug', dest='debug', action="store_true")

    positional_specs = (
        ('prev', 'the previous version (the version from develop branch)'),
        ('post', 'the post version (the version from PullRequest)'),
    )
    for positional_name, help_text in positional_specs:
        parser.add_argument(
            positional_name,
            type=argparse.FileType('r'),
            help=help_text,
        )

    for item in arguments:
        flags, dest, value_type, default, help_text = item[:5]
        parser.add_argument(
            flags, dest=dest, help=help_text, type=value_type, default=default
        )

    if len(sys.argv) <= 1:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args()
if __name__ == '__main__':
    args = parse_args()
    # --debug enables the per-API debug messages logged by the checks.
    if args.debug:
        logger.setLevel(logging.DEBUG)
    else:
        logger.setLevel(logging.INFO)
    if args.prev and args.post:
        # argparse.FileType opened both spec files; close them once parsed.
        with args.prev, args.post:
            prev_spec = read_argspec_from_file(args.prev)
            post_spec = read_argspec_from_file(args.post)
        diff_api_names = []
        # Only APIs present in the post (new) spec are examined, so APIs that
        # exist only in the previous spec are not reported by this loop.
        for as_post_name, as_post in post_spec.items():
            as_prev = prev_spec.get(as_post_name)
            if as_prev is None:  # newly added API: nothing to compare against
                continue
            if isinstance(as_prev, str) or isinstance(as_post, str):
                # At least one side could not be evaluated into a
                # FullArgSpec, so compare the textual form of both.
                as_prev_str = (
                    as_prev if isinstance(as_prev, str) else repr(as_prev)
                )
                as_post_str = (
                    as_post if isinstance(as_post, str) else repr(as_post)
                )
                compatible = check_compatible_str(as_prev_str, as_post_str)
            else:
                compatible = check_compatible(as_prev, as_post)
            if not compatible:
                diff_api_names.append(as_post_name)
        # Names of incompatible APIs go to stdout, one per line.
        if diff_api_names:
            print('\n'.join(diff_api_names))