forked from RockChinQ/LangBot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfuncschema.py
116 lines (97 loc) · 3.32 KB
/
funcschema.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import inspect
import re
import sys
import typing
# Python type name -> JSON Schema type name emitted in the generated schema.
_TYPE_NAME_MAPPING = {
    "str": "string",
    "int": "integer",
    "float": "number",
    "bool": "boolean",
    "list": "array",
    "dict": "object",
}

# Section headers that introduce the argument list of a Google-style docstring.
_ARGS_SECTION_HEADERS = ("args:", "arguments:", "parameters:", "params:")

# One argument entry: ``name (type): description``. The ``(type)`` part is
# optional, and spaces around it are tolerated (``name(type):desc`` also matches).
_ARG_LINE_RE = re.compile(r'^\s*(\w+)\s*(?:\(([^)]*)\))?\s*:\s*(.*)$')


def _split_sections(doc: str) -> list:
    """Split a dedented docstring into its blank-line-separated, non-empty sections."""
    return [section for section in re.split(r'\n\s*\n', doc) if section.strip()]


def _find_args_section(sections: list) -> str:
    """Return the argument section: the one with an ``Args:``-style header, else the second one."""
    for section in sections[1:]:
        if section.split('\n', 1)[0].strip().lower() in _ARGS_SECTION_HEADERS:
            return section
    # Fallback kept from the original behaviour: the second section is the argument list.
    return sections[1] if len(sections) > 1 else ""


def _parse_args_section(section: str) -> dict:
    """Map argument names to their descriptions, joining indented continuation lines."""
    args_doc = {}
    current = None
    base_indent = None
    for line in section.split('\n')[1:]:  # the first line is the section header
        if not line.strip():
            continue
        indent = len(line) - len(line.lstrip())
        match = _ARG_LINE_RE.match(line)
        if match and (base_indent is None or indent <= base_indent):
            if base_indent is None:
                base_indent = indent
            current = match.group(1)
            args_doc[current] = match.group(3).strip()
        elif current is not None:
            # A deeper-indented (or non-entry) line continues the previous description.
            args_doc[current] = f"{args_doc[current]} {line.strip()}".strip()
    return args_doc


def _annotation_type_name(annotation) -> str:
    """Return the bare Python type name of an annotation (e.g. ``list`` for ``list[int]``)."""
    if annotation is inspect.Parameter.empty:
        # Unannotated parameters are described as strings rather than the invalid "_empty".
        return "str"
    if isinstance(annotation, str):
        # String annotations, e.g. under ``from __future__ import annotations``.
        return annotation.split('[', 1)[0].strip()
    origin = typing.get_origin(annotation)
    if origin is not None:
        annotation = origin
    return getattr(annotation, '__name__', str(annotation))


def _array_item_type(annotation) -> str:
    """Return the JSON Schema type of a list annotation's element type (``string`` if unknown)."""
    if isinstance(annotation, str):
        match = re.match(r'\s*\w+\[\s*(\w+)', annotation)
        item = match.group(1) if match else None
    else:
        type_args = typing.get_args(annotation)
        item = _annotation_type_name(type_args[0]) if type_args else None
    if item is None:
        return "string"
    return _TYPE_NAME_MAPPING.get(item, item)


def get_func_schema(function: typing.Callable[..., typing.Any]) -> dict:
    """Build a function-calling schema from a function's signature and Google-style docstring.

    The first docstring paragraph becomes the description. Each parameter's
    description comes from the ``Args:`` section, and its type comes from the
    annotation, mapped to a JSON Schema type name.

    Example result::

        {
            "function": function,
            "description": "function description",
            "parameters": {
                "type": "object",
                "required": ["parameter_a", "parameter_b"],
                "properties": {
                    "parameter_a": {"type": "string", "description": "parameter_a description"},
                    "parameter_b": {"type": "integer", "description": "parameter_b description"},
                    "parameter_c": {
                        "type": "array",
                        "description": "parameter_c description",
                        "items": {"type": "string"},
                    },
                },
            },
        }

    Args:
        function: The callable to describe. It must have a docstring.

    Returns:
        A dict with keys ``function`` (the callable itself), ``description`` and
        ``parameters``. Parameters without a default are listed in ``required``.
        Undocumented parameters get an empty description, and unannotated ones
        are typed as ``string``.

    Raises:
        Exception: If ``function`` has no docstring.
    """
    func_doc = function.__doc__
    if func_doc is None:
        raise Exception("Function {} has no docstring.".format(function.__name__))

    # Dedent uniformly instead of deleting every space, so descriptions keep their words.
    sections = _split_sections(inspect.cleandoc(func_doc))
    desc = sections[0] if sections else ""
    args_doc = _parse_args_section(_find_args_section(sections))

    parameters = {
        "type": "object",
        "required": [],
        "properties": {},
    }
    for param in inspect.signature(function).parameters.values():
        # Skip `self` and `query`; presumably these are supplied by the caller,
        # not by the model. TODO confirm against call sites.
        if param.name in ('self', 'query'):
            continue
        # *args / **kwargs cannot be expressed as named schema properties.
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue

        type_name = _annotation_type_name(param.annotation)
        param_type = _TYPE_NAME_MAPPING.get(type_name, type_name)
        prop = {
            "type": param_type,
            "description": args_doc.get(param.name, ""),
        }
        if param_type == "array":
            prop["items"] = {"type": _array_item_type(param.annotation)}
        parameters["properties"][param.name] = prop

        if param.default is inspect.Parameter.empty:
            parameters["required"].append(param.name)

    return {
        "function": function,
        "description": desc,
        "parameters": parameters,
    }