forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_parallel_tool_calls.py
207 lines (163 loc) · 7.45 KB
/
test_parallel_tool_calls.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
# SPDX-License-Identifier: Apache-2.0
import json
from typing import Optional
import openai
import pytest
from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL,
WEATHER_TOOL, ServerConfig)
# Test: the model generates parallel tool calls when asked to, in both
# non-streaming and streaming modes. NOTE: not all models support this, so
# some exclusions may be added in the future; e.g. Llama 3.1 models are not
# designed to support parallel tool calls.
@pytest.mark.asyncio
async def test_parallel_tool_calls(client: openai.AsyncOpenAI,
                                   server_config: ServerConfig):
    """Check that parallel tool calls are generated, streamed and not.

    The same request (weather and search tools offered, temperature 0) is
    sent twice: once without streaming and once with ``stream=True``. The
    non-streamed response must contain exactly two calls to the weather
    tool, each with string ``city`` and ``state`` arguments. The streamed
    response is reassembled from its deltas and must match the non-streamed
    tool calls by name and by parsed arguments.

    The test is skipped when ``server_config`` sets ``supports_parallel``
    to a falsy value.
    """
    if not server_config.get("supports_parallel", True):
        pytest.skip("The {} model doesn't support parallel tool calls".format(
            server_config["model"]))

    models = await client.models.list()
    model_name: str = models.data[0].id
    chat_completion = await client.chat.completions.create(
        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
        temperature=0,
        max_completion_tokens=200,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False)

    choice = chat_completion.choices[0]
    stop_reason = choice.finish_reason
    non_streamed_tool_calls = choice.message.tool_calls

    # make sure 2 tool calls are present
    assert choice.message.role == "assistant"
    assert non_streamed_tool_calls is not None
    assert len(non_streamed_tool_calls) == 2

    for tool_call in non_streamed_tool_calls:
        # make sure the tool includes a function and ID
        assert tool_call.type == "function"
        assert tool_call.function is not None
        assert isinstance(tool_call.id, str)
        assert len(tool_call.id) >= 9

        # make sure the weather tool was called correctly
        assert tool_call.function.name == WEATHER_TOOL["function"]["name"]
        assert isinstance(tool_call.function.arguments, str)

        parsed_arguments = json.loads(tool_call.function.arguments)
        assert isinstance(parsed_arguments, dict)
        assert isinstance(parsed_arguments.get("city"), str)
        assert isinstance(parsed_arguments.get("state"), str)

    assert stop_reason == "tool_calls"

    # make the same request, streaming
    stream = await client.chat.completions.create(
        model=model_name,
        messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS,
        temperature=0,
        max_completion_tokens=200,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
        stream=True)

    role_name: Optional[str] = None
    finish_reason_count: int = 0

    tool_call_names: list[str] = []
    tool_call_args: list[str] = []
    tool_call_idx: int = -1
    tool_call_id_count: int = 0

    async for chunk in stream:
        streamed_choice = chunk.choices[0]
        delta = streamed_choice.delta

        # if there's a finish reason make sure it's tools
        if streamed_choice.finish_reason:
            finish_reason_count += 1
            assert streamed_choice.finish_reason == 'tool_calls'

        # if a role is being streamed make sure it is the assistant role and
        # that it wasn't already set to something else
        if delta.role:
            assert delta.role == 'assistant'
            assert not role_name or role_name == 'assistant'
            role_name = 'assistant'

        # if a tool call is streamed make sure there's exactly one
        # (based on the request parameters)
        streamed_tool_calls = delta.tool_calls

        if streamed_tool_calls:
            # make sure only one diff is present - correct even for parallel
            assert len(streamed_tool_calls) == 1
            tool_call = streamed_tool_calls[0]

            # if a new tool is being called, set up empty arguments
            if tool_call.index != tool_call_idx:
                tool_call_idx = tool_call.index
                tool_call_args.append("")

            # if a tool call ID is streamed, validate it and count it; the
            # total is checked after the stream ends
            if tool_call.id:
                tool_call_id_count += 1
                assert isinstance(tool_call.id, str)
                assert len(tool_call.id) >= 9

            # if parts of the function start being streamed
            if tool_call.function:
                # if the function name is defined, set it. it should be
                # streamed IN ENTIRETY, exactly one time.
                if tool_call.function.name:
                    assert isinstance(tool_call.function.name, str)
                    tool_call_names.append(tool_call.function.name)

                if tool_call.function.arguments:
                    # make sure they're a string and then add them to the
                    # argument buffer of the tool call they belong to
                    assert isinstance(tool_call.function.arguments, str)
                    tool_call_args[
                        tool_call.index] += tool_call.function.arguments

    assert finish_reason_count == 1
    assert role_name == 'assistant'

    assert (len(non_streamed_tool_calls) == len(tool_call_names) ==
            len(tool_call_args))

    # each tool call's ID should have been streamed exactly one time
    assert tool_call_id_count == len(non_streamed_tool_calls)

    # the three sequences were just asserted to have equal length, so zip
    # visits every tool call
    for expected_call, streamed_name, streamed_raw_args in zip(
            non_streamed_tool_calls, tool_call_names, tool_call_args):
        assert expected_call.function.name == streamed_name

        streamed_args = json.loads(streamed_raw_args)
        non_streamed_args = json.loads(expected_call.function.arguments)
        assert streamed_args == non_streamed_args
# Test: providing parallel tool call results back to the model to get a
# response, in both non-streaming and streaming modes.
@pytest.mark.asyncio
async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI,
                                                server_config: ServerConfig):
    """Check the model answers in text once tool results are supplied.

    A conversation that already contains parallel tool responses is sent
    first without streaming and then with ``stream=True``. Both replies
    must be plain assistant messages with no tool calls; the non-streamed
    content must mention "98" and "78", and the concatenated streamed
    content must equal the non-streamed content.

    The test is skipped when ``server_config`` sets ``supports_parallel``
    to a falsy value.
    """
    if not server_config.get("supports_parallel", True):
        pytest.skip("The {} model doesn't support parallel tool calls".format(
            server_config["model"]))

    listed_models = await client.models.list()
    model_name: str = listed_models.data[0].id

    # both requests share everything except the stream flag
    request_kwargs = dict(
        messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE,
        temperature=0,
        max_completion_tokens=200,
        model=model_name,
        tools=[WEATHER_TOOL, SEARCH_TOOL],
        logprobs=False,
    )

    # --- non-streaming request ---
    completion = await client.chat.completions.create(**request_kwargs)
    choice = completion.choices[0]
    reply = choice.message

    assert choice.finish_reason != "tool_calls"  # "stop" or "length"
    assert reply.role == "assistant"
    assert not reply.tool_calls
    assert reply.content is not None
    assert "98" in reply.content  # Dallas temp in tool response
    assert "78" in reply.content  # Orlando temp in tool response

    # --- streaming request ---
    stream = await client.chat.completions.create(**request_kwargs,
                                                  stream=True)

    content_parts: list[str] = []
    finish_reason_count = 0
    role_sent: bool = False

    async for chunk in stream:
        streamed_choice = chunk.choices[0]
        delta = streamed_choice.delta

        # the role may be announced only once and must be the assistant
        if delta.role:
            assert not role_sent
            assert delta.role == "assistant"
            role_sent = True

        if delta.content:
            content_parts.append(delta.content)

        # a finish reason must agree with the non-streamed one
        if streamed_choice.finish_reason is not None:
            finish_reason_count += 1
            assert streamed_choice.finish_reason == choice.finish_reason

        # no tool calls are expected in any delta
        assert not delta.tool_calls

    assert role_sent
    assert finish_reason_count == 1
    assert content_parts
    assert "".join(content_parts) == reply.content