# test_chat.py

import subprocess
import sys
from itertools import repeat
from pathlib import Path
from unittest.mock import MagicMock

import pytest
import torch


@pytest.mark.parametrize(
    ("generated", "stop_tokens", "expected"),
    [
        # no stop tokens: all 8 generated tokens are yielded one by one
        (repeat(1), (), [1] * 8),
        # a single-token stop sequence truncates the output
        ([1, 2, 3, 0], ([0],), [1, 2, 3]),
        # the entire output matches a stop sequence, so nothing is yielded
        ([1, 2, 3, 0], ([9], [2, 4], [1, 2, 3, 0]), []),
        # tokens held back on a partial stop-sequence match are flushed as one batched tensor
        ([1, 2, 3, 0, 0], ([0, 0, 0], [0, 0]), [1, 2, [3]]),
    ],
)
def test_generate(generated, stop_tokens, expected):
    import chat.base as chat

    input_idx = torch.tensor([5, 3])
    max_returned_tokens = len(input_idx) + 8
    model = MagicMock()
    model.config.block_size = 100
    model.max_seq_length = 100

    # stub out torch.multinomial so "sampling" deterministically yields the scripted token ids
    original_multinomial = torch.multinomial
    it = iter(generated)

    def multinomial(*_, **__):
        out = next(it)
        return torch.tensor([out])

    chat.torch.multinomial = multinomial
    try:
        actual = list(chat.generate(model, input_idx, max_returned_tokens, stop_tokens=stop_tokens))
    finally:
        # restore the real sampler even if generate raises
        chat.torch.multinomial = original_multinomial

    for t in actual:
        assert t.dtype == torch.long
    assert [t.tolist() for t in actual] == expected
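

# For reference, a minimal sketch (not litgpt's actual implementation) of the
# stop-token buffering behavior test_generate exercises: tokens that could
# still grow into a stop sequence are held back, flushed as one batched
# tensor when generation ends, and dropped when a stop sequence fully
# matches. ``generate_with_stops`` and ``token_stream`` are hypothetical
# names introduced here for illustration; stop sequences are assumed
# non-empty.
def generate_with_stops(token_stream, max_new_tokens, stop_tokens=()):
    buffer_length = max((len(st) for st in stop_tokens), default=1)
    tokens = []
    yield_i = -1  # index of the last token already yielded
    for t, token in zip(range(max_new_tokens), token_stream):
        tokens.append(token)
        # a full stop-sequence match ends generation; flush everything
        # that precedes the matched stop sequence as one batch
        for st in stop_tokens:
            if len(st) <= len(tokens) and tokens[-len(st):] == list(st):
                leftover = tokens[yield_i + 1 : len(tokens) - len(st)]
                if leftover:
                    yield torch.tensor(leftover)
                return
        # tokens older than the longest stop sequence can never be part
        # of a future match, so they are safe to yield individually
        if t - yield_i >= buffer_length:
            yield_i += 1
            yield torch.tensor(tokens[yield_i])
    # stream exhausted without hitting a stop sequence: flush the rest
    leftover = tokens[yield_i + 1 :]
    if leftover:
        yield torch.tensor(leftover)


# e.g. list(generate_with_stops(iter([1, 2, 3, 0, 0]), 8, ([0, 0, 0], [0, 0])))
# -> [tensor(1), tensor(2), tensor([3])], matching the last parametrized case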


def test_cli():
    cli_path = Path(__file__).parent.parent / "chat" / "base.py"
    output = subprocess.check_output([sys.executable, str(cli_path), "-h"])
    output = output.decode()
    assert "Starts a conversation" in output