Tests for chat/base.py: decode and main functions (Lightning-AI#787)

Co-authored-by: Carlos Mocholí <[email protected]>
Andrei-Aksionov and carmocca authored Dec 5, 2023
1 parent dbd614c commit 2907c08
Showing 2 changed files with 70 additions and 1 deletion.
1 change: 1 addition & 0 deletions chat/base.py
@@ -122,6 +122,7 @@ def main(
- gptq.int4: 4-bit quantization from GPTQ
for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md
precision: Indicates the Fabric precision setting to use.
compile: Whether to use compilation to speed up token generation. Will increase startup time.
"""
precision = precision or get_default_supported_precision(training=False)

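The one-line addition above documents a new `compile` flag on `main()`. For context, here is a minimal sketch of how such a flag is typically gated; the helper name `maybe_compile` is hypothetical and the exact wiring in chat/base.py may differ:

import torch
import torch.nn as nn

def maybe_compile(model: nn.Module, compile: bool) -> nn.Module:
    # Hypothetical helper: torch.compile traces and optimizes the forward
    # pass on first use, which is the startup-time cost the docstring warns
    # about; subsequent token generation then runs the optimized graph.
    if compile:
        model = torch.compile(model)
    return model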
70 changes: 69 additions & 1 deletion tests/test_chat.py
@@ -1,8 +1,11 @@
import json
import subprocess
import sys
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from itertools import repeat
from pathlib import Path
from unittest.mock import MagicMock
from unittest.mock import ANY, MagicMock, Mock, call, patch

import pytest
import torch
@@ -47,6 +50,71 @@ def multinomial(*_, **__):
assert torch.cat(actual).tolist() == expected


@pytest.mark.parametrize("tokenizer_backend", ["huggingface", "sentencepiece"])
def test_decode(tokenizer_backend):
from lightning.fabric import Fabric

import chat.base as chat

class Tokenizer:
backend = tokenizer_backend
id2token = {1: "foo ", 2: "bar ", 3: "baz "}

def decode(self, tensor: torch.Tensor) -> str:
tensor = [tensor] if tensor.ndim == 0 else tensor
return "".join(self.id2token[int(value)] for value in tensor)

tokenizer_mock = Tokenizer()

fabric = Fabric(devices=1, accelerator="cpu")

token_stream = torch.tensor([3, 2, 1])
out, err = StringIO(), StringIO()
with redirect_stdout(out), redirect_stderr(err):
chat.decode(fabric, tokenizer_mock, token_stream)

assert out.getvalue() == "baz bar foo "
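
Judging purely by these assertions, `chat.decode` streams each token's decoded text straight to stdout with no separator (hence the trailing spaces baked into the stub's id2token). A hypothetical reconstruction, assuming it simply forwards each piece to `fabric.print`:

def decode_sketch(fabric, tokenizer, token_stream):
    # Inferred behavior, not the actual implementation: print each decoded
    # token as soon as it arrives so the reply appears incrementally, and
    # return how many tokens were emitted.
    tokens_generated = 0
    for token in token_stream:
        fabric.print(tokenizer.decode(token), end="", flush=True)
        tokens_generated += 1
    return tokens_generated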


@patch("chat.base.input")
@pytest.mark.parametrize("stop_iteration", [KeyboardInterrupt, ""])
def test_main(mocked_input, stop_iteration, fake_checkpoint_dir, monkeypatch, tensor_like):
import chat.base as chat

# these values are returned one at a time by successive `input()` calls
mocked_input.side_effect = ["Hello", stop_iteration]

config_path = fake_checkpoint_dir / "lit_config.json"
config = {"block_size": 128, "vocab_size": 50, "n_layer": 2, "n_head": 4, "n_embd": 8, "rotary_percentage": 1}
config_path.write_text(json.dumps(config))

load_mock = Mock()
load_mock.return_value = load_mock
monkeypatch.setattr(chat, "load_checkpoint", load_mock)
tokenizer_mock = Mock()
tokenizer_mock.return_value.backend = "sentencepiece"
tokenizer_mock.return_value.eos_id = 0
tokenizer_mock.return_value.encode.return_value = torch.tensor([1, 2, 3])
tokenizer_mock.return_value.decode.return_value = "foo bar baz"
monkeypatch.setattr(chat, "Tokenizer", tokenizer_mock)
generate_mock = Mock()
generate_mock.return_value = torch.tensor([3, 2, 1])
monkeypatch.setattr(chat, "generate", generate_mock)

out, err = StringIO(), StringIO()
with redirect_stdout(out), redirect_stderr(err):
chat.main(temperature=2.0, top_k=2, checkpoint_dir=fake_checkpoint_dir)

# decoding is done once per generated token
assert len(tokenizer_mock.return_value.decode.mock_calls) == generate_mock.return_value.numel()
assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value)
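# note: `tensor_like` is a conftest helper; from its use here it presumably
# compares equal to any tensor, letting the encoded prompt match by position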
assert generate_mock.mock_calls == [call(ANY, tensor_like, 128, temperature=2.0, top_k=2, stop_tokens=([0],))]
# only the generated result is printed to stdout
assert out.getvalue() == ">> Reply: foo bar baz\n"

assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4" in err.getvalue()


def test_cli():
cli_path = Path(__file__).parent.parent / "chat" / "base.py"
output = subprocess.check_output([sys.executable, cli_path, "-h"])
