Tests for generation with LoRA (Lightning-AI#294)
Co-authored-by: Adrian Wälchli <[email protected]>
Andrei-Aksionov and awaelchli authored Jul 31, 2023
1 parent 2818b81 commit ae56f4e
Showing 2 changed files with 82 additions and 2 deletions.
5 changes: 3 additions & 2 deletions generate/lora.py
@@ -75,7 +75,7 @@ def main(
     check_valid_checkpoint_dir(checkpoint_dir)

     with open(checkpoint_dir / "lit_config.json") as fp:
-        config = Config(
+        config_params = dict(
             r=lora_r,
             alpha=lora_alpha,
             dropout=lora_dropout,
@@ -85,8 +85,9 @@ def main(
             to_projection=lora_projection,
             to_mlp=lora_mlp,
             to_head=lora_head,
-            **json.load(fp),
         )
+        config_params.update(**json.load(fp))
+    config = Config(**config_params)

     if quantize is not None and devices > 1:
         raise NotImplementedError
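
Why this change: with the old code, any key present both as an explicit keyword and inside json.load(fp) makes the Config(...) call raise TypeError ("got multiple values for keyword argument"), so a lit_config.json that pins LoRA flags (like the "to_query": False entry in the new test's config below) could not be loaded at all. Building a plain dict first and merging the JSON into it with .update() lets the checkpoint's values override the CLI defaults instead. A minimal, standalone sketch of the two patterns (the keys here are illustrative, not from the repository):

def build(**kwargs):
    return kwargs

from_json = {"r": 4}  # pretend this came from lit_config.json

# Old pattern: duplicate keys between explicit kwargs and ** unpacking crash.
try:
    build(r=8, alpha=16, **from_json)
except TypeError as exc:
    print(exc)  # build() got multiple values for keyword argument 'r'

# New pattern: merge into a dict first; the JSON value takes precedence.
params = dict(r=8, alpha=16)
params.update(**from_json)
print(build(**params))  # {'r': 4, 'alpha': 16}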
79 changes: 79 additions & 0 deletions tests/test_generate_lora.py
@@ -0,0 +1,79 @@
import json
import subprocess
import sys
from contextlib import redirect_stdout, redirect_stderr
from io import StringIO
from pathlib import Path
from unittest import mock
from unittest.mock import Mock, call, ANY

import pytest
import torch


def test_main(fake_checkpoint_dir, monkeypatch):
    import generate.lora as generate

    config_path = fake_checkpoint_dir / "lit_config.json"
    config = {
        "block_size": 16,
        "vocab_size": 50,
        "n_layer": 2,
        "n_head": 4,
        "n_embd": 8,
        "rotary_percentage": 1,
        "to_query": False,
        "to_value": False,
        "to_projection": True,
    }
    config_path.write_text(json.dumps(config))

    load_mock = Mock()
    load_mock.return_value = load_mock
    load_mock.__enter__ = Mock()
    load_mock.__exit__ = Mock()
    monkeypatch.setattr(generate, "lazy_load", load_mock)
    tokenizer_mock = Mock()
    tokenizer_mock.return_value.encode.return_value = torch.tensor([[1, 2, 3]])
    tokenizer_mock.return_value.decode.return_value = "### Response:foo bar baz"
    monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock)
    generate_mock = Mock()
    generate_mock.return_value = torch.tensor([[3, 2, 1]])
    monkeypatch.setattr(generate, "generate", generate_mock)

    num_samples = 1
    out, err = StringIO(), StringIO()
    with redirect_stdout(out), redirect_stderr(err):
        generate.main(temperature=2.0, top_k=2, checkpoint_dir=fake_checkpoint_dir)

    assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples
    assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value)
    # max_seq_length == 101: the mocked encode returns a tensor whose first dimension is 1,
    # so the prompt length (1) plus the script's default max_new_tokens (100) gives 101
    assert (
        generate_mock.mock_calls
        == [call(ANY, ANY, ANY, max_seq_length=101, temperature=2.0, top_k=2, eos_id=ANY)] * num_samples
    )
    # only the generated result is printed to stdout
    assert out.getvalue() == "foo bar baz\n" * num_samples

    # the config summary goes to stderr; vocab_size=50 is padded up to the next multiple of 512
    assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4, 'n_embd': 8" in err.getvalue()


def test_lora_variables_exist():
    import generate.lora as generate

    for lora_argument in ("r", "alpha", "dropout", "query", "key", "value", "projection", "mlp", "head"):
        assert getattr(generate, f"lora_{lora_argument}", None) is not None


def test_lora_is_enabled():
    import generate.lora as generate

    lora_arguments = ("query", "key", "value", "projection", "mlp", "head")
    assert any(getattr(generate, f"lora_{lora_argument}") for lora_argument in lora_arguments)


def test_cli():
    cli_path = Path(__file__).parent.parent / "generate" / "lora.py"
    output = subprocess.check_output([sys.executable, cli_path, "-h"])
    output = str(output.decode())
    assert "Generates a response" in output
