Skip to content

Commit

Permalink
Parametrize CLI tests (Lightning-AI#1093)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored and awaelchli committed Mar 15, 2024
1 parent 7e901ba commit a7f41ef
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 18 deletions.
11 changes: 8 additions & 3 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,13 @@ def test_main(mocked_input, stop_iteration, fake_checkpoint_dir, monkeypatch, te
assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4" in err.getvalue()


@pytest.mark.parametrize("mode", ["file", "entrypoint"])
def test_cli(mode):
    """Smoke-test the chat CLI help screen via both invocation paths.

    ``file`` runs the module file directly with the current interpreter;
    ``entrypoint`` uses the installed ``litgpt`` console script.
    Raises ``subprocess.CalledProcessError`` if the CLI exits non-zero.
    """
    if mode == "file":
        # Resolve litgpt/chat/base.py relative to this test file so the test
        # works regardless of the current working directory.
        cli_path = Path(__file__).parent.parent / "litgpt/chat/base.py"
        args = [sys.executable, cli_path, "-h"]
    else:
        args = ["litgpt", "chat", "-h"]
    output = subprocess.check_output(args)
    output = str(output.decode())
    # The help text's summary line comes from the chat function's docstring.
    assert "Starts a conversation" in output
11 changes: 8 additions & 3 deletions tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,14 @@ def test_main(fake_checkpoint_dir, monkeypatch, tensor_like):
assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4" in err.getvalue()


@pytest.mark.parametrize("mode", ["file", "entrypoint"])
def test_cli(mode):
    """Smoke-test the base generation CLI help screen via both invocation paths.

    ``file`` runs the module file directly with the current interpreter;
    ``entrypoint`` uses the installed ``litgpt`` console script.
    Raises ``subprocess.CalledProcessError`` if the CLI exits non-zero.
    """
    if mode == "file":
        # Resolve litgpt/generate/base.py relative to this test file so the
        # test works regardless of the current working directory.
        cli_path = Path(__file__).parent.parent / "litgpt/generate/base.py"
        args = [sys.executable, cli_path, "-h"]
    else:
        args = ["litgpt", "generate", "base", "-h"]
    output = subprocess.check_output(args)
    output = str(output.decode())
    # The help text's summary line comes from the generate function's docstring.
    assert "Generates text samples" in output

Expand Down
11 changes: 8 additions & 3 deletions tests/test_generate_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,13 @@ def test_main(fake_checkpoint_dir, monkeypatch, version, tensor_like):


@pytest.mark.parametrize("version", ("", "_v2"))
@pytest.mark.parametrize("mode", ["file", "entrypoint"])
def test_cli(version, mode):
    """Smoke-test the adapter generation CLI help screen for both adapter versions.

    ``version`` selects adapter v1 ("") or v2 ("_v2"); ``mode`` selects running
    the module file directly versus the installed ``litgpt`` console script.
    Raises ``subprocess.CalledProcessError`` if the CLI exits non-zero.
    """
    if mode == "file":
        # Resolve litgpt/generate/adapter{version}.py relative to this test
        # file so the test works regardless of the current working directory.
        cli_path = Path(__file__).parent.parent / f"litgpt/generate/adapter{version}.py"
        args = [sys.executable, cli_path, "-h"]
    else:
        args = ["litgpt", "generate", f"adapter{version}", "-h"]
    output = subprocess.check_output(args)
    output = str(output.decode())
    # The help text's summary line comes from the adapter function's docstring.
    assert "Generates a response" in output
12 changes: 9 additions & 3 deletions tests/test_generate_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from unittest import mock
from unittest.mock import ANY, Mock, call

import pytest
import torch


Expand Down Expand Up @@ -55,8 +56,13 @@ def test_main(fake_checkpoint_dir, monkeypatch, tensor_like):
assert "'padded_vocab_size': 512, 'n_layer': 2, 'n_head': 4, 'head_size': 2, 'n_embd': 8" in err.getvalue()


@pytest.mark.parametrize("mode", ["file", "entrypoint"])
def test_cli(mode):
    """Smoke-test the LoRA generation CLI help screen via both invocation paths.

    ``file`` runs the module file directly with the current interpreter;
    ``entrypoint`` uses the installed ``litgpt`` console script.
    Raises ``subprocess.CalledProcessError`` if the CLI exits non-zero.
    """
    if mode == "file":
        # Resolve litgpt/generate/lora.py relative to this test file so the
        # test works regardless of the current working directory.
        cli_path = Path(__file__).parent.parent / "litgpt/generate/lora.py"
        args = [sys.executable, cli_path, "-h"]
    else:
        args = ["litgpt", "generate", "lora", "-h"]
    output = subprocess.check_output(args)
    output = str(output.decode())
    # The help text's summary line comes from the LoRA generate docstring.
    assert "Generates a response" in output
11 changes: 8 additions & 3 deletions tests/test_generate_sequentially.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,13 @@ def test_base_with_sequentially(tmp_path):
assert base_stdout == sequential_stdout


@pytest.mark.parametrize("mode", ["file", "entrypoint"])
def test_cli(mode):
    """Smoke-test the sequential-generation CLI help screen via both invocation paths.

    ``file`` runs the module file directly with the current interpreter;
    ``entrypoint`` uses the installed ``litgpt`` console script.
    Raises ``subprocess.CalledProcessError`` if the CLI exits non-zero.
    """
    if mode == "file":
        # Resolve litgpt/generate/sequentially.py relative to this test file
        # so the test works regardless of the current working directory.
        cli_path = Path(__file__).parent.parent / "litgpt/generate/sequentially.py"
        args = [sys.executable, cli_path, "-h"]
    else:
        args = ["litgpt", "generate", "sequentially", "-h"]
    output = subprocess.check_output(args)
    output = str(output.decode())
    # The help text's summary line comes from the generate function's docstring.
    assert "Generates text samples" in output
11 changes: 8 additions & 3 deletions tests/test_generate_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,13 @@ def test_tp(tmp_path):
assert tp_stdout.startswith("What food do llamas eat?")


@pytest.mark.parametrize("mode", ["file", "entrypoint"])
def test_cli(mode):
    """Smoke-test the tensor-parallel generation CLI help screen via both invocation paths.

    ``file`` runs the module file directly with the current interpreter;
    ``entrypoint`` uses the installed ``litgpt`` console script.
    Raises ``subprocess.CalledProcessError`` if the CLI exits non-zero.
    """
    if mode == "file":
        # Resolve litgpt/generate/tp.py relative to this test file so the
        # test works regardless of the current working directory.
        cli_path = Path(__file__).parent.parent / "litgpt/generate/tp.py"
        args = [sys.executable, cli_path, "-h"]
    else:
        args = ["litgpt", "generate", "tp", "-h"]
    output = subprocess.check_output(args)
    output = str(output.decode())
    # The help text's summary line comes from the generate function's docstring.
    assert "Generates text samples" in output

0 comments on commit a7f41ef

Please sign in to comment.