forked from ggerganov/llama.cpp
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
server : add more test cases (ggerganov#10569)
* server : add split model test * add test speculative * add invalid cases
- Loading branch information
Showing
6 changed files
with
186 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import pytest | ||
from utils import * | ||
|
||
# We use a F16 MOE gguf as main model, and q4_0 as draft model | ||
|
||
server = ServerPreset.stories15m_moe() | ||
|
||
MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf" | ||
|
||
def create_server(): | ||
global server | ||
server = ServerPreset.stories15m_moe() | ||
# download draft model file if needed | ||
file_name = MODEL_DRAFT_FILE_URL.split('/').pop() | ||
model_draft_file = f'../../../{file_name}' | ||
if not os.path.exists(model_draft_file): | ||
print(f"Downloading {MODEL_DRAFT_FILE_URL} to {model_draft_file}") | ||
with open(model_draft_file, 'wb') as f: | ||
f.write(requests.get(MODEL_DRAFT_FILE_URL).content) | ||
print(f"Done downloading draft model file") | ||
# set default values | ||
server.model_draft = model_draft_file | ||
server.draft_min = 4 | ||
server.draft_max = 8 | ||
|
||
|
||
@pytest.fixture(scope="module", autouse=True) | ||
def fixture_create_server(): | ||
return create_server() | ||
|
||
|
||
def test_with_and_without_draft(): | ||
global server | ||
server.model_draft = None # disable draft model | ||
server.start() | ||
res = server.make_request("POST", "/completion", data={ | ||
"prompt": "I believe the meaning of life is", | ||
"temperature": 0.0, | ||
"top_k": 1, | ||
}) | ||
assert res.status_code == 200 | ||
content_no_draft = res.body["content"] | ||
server.stop() | ||
|
||
# create new server with draft model | ||
create_server() | ||
server.start() | ||
res = server.make_request("POST", "/completion", data={ | ||
"prompt": "I believe the meaning of life is", | ||
"temperature": 0.0, | ||
"top_k": 1, | ||
}) | ||
assert res.status_code == 200 | ||
content_draft = res.body["content"] | ||
|
||
assert content_no_draft == content_draft | ||
|
||
|
||
def test_different_draft_min_draft_max(): | ||
global server | ||
test_values = [ | ||
(1, 2), | ||
(1, 4), | ||
(4, 8), | ||
(4, 12), | ||
(8, 16), | ||
] | ||
last_content = None | ||
for draft_min, draft_max in test_values: | ||
server.stop() | ||
server.draft_min = draft_min | ||
server.draft_max = draft_max | ||
server.start() | ||
res = server.make_request("POST", "/completion", data={ | ||
"prompt": "I believe the meaning of life is", | ||
"temperature": 0.0, | ||
"top_k": 1, | ||
}) | ||
assert res.status_code == 200 | ||
if last_content is not None: | ||
assert last_content == res.body["content"] | ||
last_content = res.body["content"] | ||
|
||
|
||
@pytest.mark.parametrize("n_slots,n_requests", [ | ||
(1, 2), | ||
(2, 2), | ||
]) | ||
def test_multi_requests_parallel(n_slots: int, n_requests: int): | ||
global server | ||
server.n_slots = n_slots | ||
server.start() | ||
tasks = [] | ||
for _ in range(n_requests): | ||
tasks.append((server.make_request, ("POST", "/completion", { | ||
"prompt": "I believe the meaning of life is", | ||
"temperature": 0.0, | ||
"top_k": 1, | ||
}))) | ||
results = parallel_function_calls(tasks) | ||
for res in results: | ||
assert res.status_code == 200 | ||
assert match_regex("(wise|kind|owl|answer)+", res.body["content"]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters