Add tests for generate/adapter and generate/adapter_v2 (Lightning-AI#103)

Co-authored-by: Carlos Mocholí <[email protected]>
agmo1993 and carmocca authored Jun 6, 2023
1 parent 25fb51e commit a90a105
Showing 4 changed files with 81 additions and 12 deletions.
4 changes: 4 additions & 0 deletions generate/adapter_v2.py
@@ -8,6 +8,10 @@
import lightning as L
import torch

# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from generate.base import generate
from lit_parrot import Tokenizer
from lit_parrot.adapter import Parrot, Config
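For context on why the two added lines matter: when a script in generate/ is executed directly (e.g. python generate/adapter_v2.py), Python puts only the script's own directory on sys.path, so the absolute imports of the repo-root packages (generate, lit_parrot) fail unless the project has been installed. A minimal sketch of the pattern, assuming only the repository layout implied by this diff:

# sketch: the "run without installing" pattern added to generate/adapter_v2.py
import sys
from pathlib import Path

wd = Path(__file__).parent.parent.resolve()  # repository root (parent of generate/)
sys.path.append(str(wd))                     # repo-root packages become importable

# with the root on sys.path, these imports work without pip-installing the project
from generate.base import generate
from lit_parrot import Tokenizer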
13 changes: 13 additions & 0 deletions tests/conftest.py
@@ -1,3 +1,4 @@
import os
import sys
from pathlib import Path

@@ -10,3 +11,15 @@
def add_wd_to_path():
    # this adds support for running tests without the package installed
    sys.path.append(str(wd))


@pytest.fixture()
def fake_checkpoint_dir(tmp_path):
    os.chdir(tmp_path)
    checkpoint_dir = tmp_path / "checkpoints" / "tmp"
    checkpoint_dir.mkdir(parents=True)
    (checkpoint_dir / "lit_model.pth").touch()
    (checkpoint_dir / "lit_config.json").touch()
    (checkpoint_dir / "tokenizer.json").touch()
    (checkpoint_dir / "tokenizer_config.json").touch()
    return checkpoint_dir
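Because the fixture now lives in tests/conftest.py, any test module under tests/ can request it simply by naming it as a parameter; a minimal hypothetical example (not part of this commit) of how a test consumes it:

def test_checkpoint_layout(fake_checkpoint_dir):
    # pytest resolves fake_checkpoint_dir from conftest.py and injects the returned Path
    assert fake_checkpoint_dir.name == "tmp"
    assert (fake_checkpoint_dir / "lit_config.json").is_file()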
12 changes: 0 additions & 12 deletions tests/test_generate.py
@@ -12,18 +12,6 @@
import torch


@pytest.fixture()
def fake_checkpoint_dir(tmp_path):
    os.chdir(tmp_path)
    checkpoint_dir = tmp_path / "checkpoints" / "tmp"
    checkpoint_dir.mkdir(parents=True)
    (checkpoint_dir / "lit_model.pth").touch()
    (checkpoint_dir / "lit_config.json").touch()
    (checkpoint_dir / "tokenizer.json").touch()
    (checkpoint_dir / "tokenizer_config.json").touch()
    return checkpoint_dir


@pytest.mark.parametrize("max_seq_length", (10, 20 + 5))
def test_generate(max_seq_length):
    import generate.base as generate
64 changes: 64 additions & 0 deletions tests/test_generate_adapter.py
@@ -0,0 +1,64 @@
import json
import subprocess
import sys
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO
from pathlib import Path
from unittest import mock
from unittest.mock import Mock, call, ANY

import pytest
import torch


@mock.patch("torch.cuda.is_bf16_supported", return_value=False)
@pytest.mark.parametrize("version", ("v1", "v2"))
def test_main(_, fake_checkpoint_dir, monkeypatch, version):
    if version == "v1":
        import generate.adapter as generate
    else:
        import generate.adapter_v2 as generate

    config_path = fake_checkpoint_dir / "lit_config.json"
    config = {"block_size": 16, "vocab_size": 50, "n_layer": 2, "n_head": 4, "n_embd": 8, "rotary_percentage": 1}
    config_path.write_text(json.dumps(config))

    class FabricMock(Mock):
        @property
        def device(self):
            return torch.device("cpu")

    monkeypatch.setattr(generate.L, "Fabric", FabricMock)
    load_mock = Mock()
    load_mock.return_value = load_mock
    load_mock.__enter__ = Mock()
    load_mock.__exit__ = Mock()
    monkeypatch.setattr(generate, "lazy_load", load_mock)
    tokenizer_mock = Mock()
    tokenizer_mock.return_value.encode.return_value = torch.tensor([[1, 2, 3]])
    tokenizer_mock.return_value.decode.return_value = "### Response:foo bar baz"
    monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock)
    generate_mock = Mock()
    generate_mock.return_value = torch.tensor([[3, 2, 1]])
    monkeypatch.setattr(generate, "generate", generate_mock)

    num_samples = 1
    out, err = StringIO(), StringIO()
    with redirect_stdout(out), redirect_stderr(err):
        generate.main(temperature=2.0, top_k=2, checkpoint_dir=fake_checkpoint_dir)

    assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples
    assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value)
    assert generate_mock.mock_calls == [call(ANY, ANY, ANY, max_seq_length=101, temperature=2.0, top_k=2, eos_id=ANY)] * num_samples
    # only the generated result is printed to stdout
    assert out.getvalue() == "foo bar baz\n" * num_samples

    assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4, 'n_embd': 8" in err.getvalue()


@pytest.mark.parametrize("version", ("", "_v2"))
def test_cli(version):
    cli_path = Path(__file__).parent.parent / "generate" / f"adapter{version}.py"
    output = subprocess.check_output([sys.executable, cli_path, "-h"])
    output = str(output.decode())
    assert "Generates a response" in output
