diff --git a/ci_scripts/check_api_parameters.py b/ci_scripts/check_api_parameters.py index a4602bf1279..283a55537d3 100644 --- a/ci_scripts/check_api_parameters.py +++ b/ci_scripts/check_api_parameters.py @@ -19,6 +19,7 @@ import re import sys import inspect +import ast import paddle @@ -61,15 +62,10 @@ def _check_params_in_description(rstfilename, paramstr): flag = True params_intitle = [] if paramstr: - _params_intitle = paramstr.split( - ', ' - ) # is there any parameter with default value of type list/tuple? may break this. - for s in _params_intitle: - if ':' in s: # annotations - pname = s.split(':') - params_intitle.append(pname[0].strip()) - else: - params_intitle.append(s.strip()) + fake_func = ast.parse(f'def fake_func({paramstr}): pass') + for arg in fake_func.body[0].args.args: + params_intitle.append(arg.arg) + funcdescnode = extract_params_desc_from_rst_file(rstfilename) if funcdescnode: items = funcdescnode.children[1].children[0].children