Add generate.py and prepare_shakespeare.py tests (Lightning-AI#42)
carmocca authored Mar 28, 2023
1 parent 98e09b5 commit db1bb28
Showing 10 changed files with 195 additions and 45 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/cpu-tests.yml
@@ -36,9 +36,9 @@ jobs:

- name: Install dependencies
run: |
pip install pytest .
pip install pytest . -r requirements.txt
pip list
- name: Run tests
run: |
pytest -v --durations=10
pytest -v --durations=10 --disable-pytest-warnings --strict-markers
1 change: 1 addition & 0 deletions .gitignore
@@ -2,6 +2,7 @@ __pycache__
.idea
.DS_Store
*.egg-info
build

# data
data
26 changes: 13 additions & 13 deletions generate.py
@@ -1,6 +1,7 @@
import os
import sys
import time
from pathlib import Path
from typing import Optional

import lightning as L
@@ -47,7 +48,7 @@ def generate(

# forward
logits = model(idx_cond)
logits = logits[:, -1, :] / temperature
logits = logits[:, -1] / temperature

# optionally crop the logits to only the top k options
if top_k is not None:
@@ -58,7 +59,7 @@ def generate(
idx_next = torch.multinomial(probs, num_samples=1)

# concatenate the new column
idx[:, t] = idx_next
idx[:, t:] = idx_next

return idx

@@ -73,8 +74,8 @@ def main(
# compilation fails as it does not support torch.complex64 for RoPE
# compile: bool = False,
accelerator: str = "auto",
checkpoint_path: Optional[str] = None,
tokenizer_path: Optional[str] = None,
checkpoint_path: Optional[Path] = None,
tokenizer_path: Optional[Path] = None,
model_size: str = "7B",
quantize: bool = False,
) -> None:
@@ -95,12 +96,11 @@ def main(
quantize: Whether to quantize the model using the `LLM.int8()` method
"""
if not checkpoint_path:
checkpoint_path = f"./checkpoints/lit-llama/{model_size}/state_dict.pth"
checkpoint_path = Path(f"./checkpoints/lit-llama/{model_size}/state_dict.pth")
if not tokenizer_path:
tokenizer_path = "./checkpoints/lit-llama/tokenizer.model"

assert os.path.isfile(checkpoint_path)
assert os.path.isfile(tokenizer_path)
tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")
assert checkpoint_path.is_file()
assert tokenizer_path.is_file()

fabric = L.Fabric(accelerator=accelerator, devices=1)

@@ -128,8 +128,8 @@ def main(
model = fabric.setup_module(model)

tokenizer = Tokenizer(tokenizer_path)
encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False).to(fabric.device)
encoded_prompt = encoded_prompt[None, :]
encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False, device=fabric.device)
encoded_prompt = encoded_prompt[None, :] # add batch dimension

L.seed_everything(1234)
t0 = time.time()
@@ -141,8 +141,8 @@ def main(
model.config.block_size, # type: ignore[union-attr,arg-type]
temperature=temperature,
top_k=top_k,
)
print(tokenizer.decode(y[0]))
)[0] # unpack batch dimension
print(tokenizer.decode(y))

print(f"Time for inference: {time.time() - t0:.02f} seconds", file=sys.stderr)
print(f"Memory used (GB): {torch.cuda.max_memory_reserved() / 1e9:.02f}", file=sys.stderr)
20 changes: 10 additions & 10 deletions lit_llama/model.py
@@ -70,14 +70,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.scale * x_normed


llama_configs = {
"7B": dict(n_layer=32, n_head=32, n_embd=4096),
"13B": dict(n_layer=40, n_head=40, n_embd=5120),
"30B": dict(n_layer=60, n_head=52, n_embd=6656),
"65B": dict(n_layer=80, n_head=64, n_embd=8192),
}


@dataclass
class LLaMAConfig:
block_size: int = 4096
@@ -88,7 +80,15 @@ class LLaMAConfig:

@classmethod
def from_name(cls, name: str) -> Self:
return cls(**llama_configs[name])
return llama_configs[name]


llama_configs = {
"7B": LLaMAConfig(n_layer=32, n_head=32, n_embd=4096),
"13B": LLaMAConfig(n_layer=40, n_head=40, n_embd=5120),
"30B": LLaMAConfig(n_layer=60, n_head=52, n_embd=6656),
"65B": LLaMAConfig(n_layer=80, n_head=64, n_embd=8192),
}


class CausalSelfAttention(nn.Module):
@@ -206,7 +206,7 @@ def forward(self, idx: torch.Tensor) -> torch.Tensor:
x = block(x)
x = self.transformer.ln_f(x)

logits = self.lm_head(x)
logits = self.lm_head(x) # (b, t, vocab_size)

return logits
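
The from_name change above means the classmethod now returns the pre-built entry stored in llama_configs rather than constructing a new object from keyword arguments, so repeated calls hand back the same shared instance. A small usage sketch (assuming LLaMAConfig is imported from lit_llama.model):

    from lit_llama.model import LLaMAConfig

    config = LLaMAConfig.from_name("7B")
    assert config.n_layer == 32 and config.n_head == 32 and config.n_embd == 4096
    # the same object is returned on every lookup, not a fresh copy
    assert LLaMAConfig.from_name("7B") is config
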

13 changes: 9 additions & 4 deletions lit_llama/tokenizer.py
@@ -1,13 +1,16 @@
import os
from pathlib import Path
from typing import Optional

import torch
from sentencepiece import SentencePieceProcessor, SentencePieceTrainer


class Tokenizer:
"""Tokenizer for LLaMA."""

def __init__(self, model_path: str) -> None:
self.processor = SentencePieceProcessor(model_file=model_path)
def __init__(self, model_path: Path) -> None:
self.processor = SentencePieceProcessor(model_file=str(model_path))
self.bos_id = self.processor.bos_id()
self.eos_id = self.processor.eos_id()
self.pad_id = self.processor.pad_id()
@@ -16,13 +19,15 @@ def __init__(self, model_path: str) -> None:
def vocab_size(self) -> int:
return self.processor.vocab_size()

def encode(self, string: str, bos: bool = True, eos: bool = False) -> torch.Tensor:
def encode(
self, string: str, bos: bool = True, eos: bool = False, device: Optional[torch.device] = None
) -> torch.Tensor:
tokens = self.processor.encode(string)
if bos:
tokens = [self.bos_id] + tokens
if eos:
tokens = tokens + [self.eos_id]
return torch.tensor(tokens, dtype=torch.int)
return torch.tensor(tokens, dtype=torch.int, device=device)

def decode(self, tokens: torch.Tensor) -> str:
return self.processor.decode(tokens.tolist())
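
The new device argument lets callers build the token tensor directly on the target device instead of encoding on CPU and moving it afterwards, which is how generate.py now uses it. A minimal usage sketch (the tokenizer path is the default used elsewhere in this commit):

    import torch
    from pathlib import Path

    from lit_llama.tokenizer import Tokenizer

    tokenizer = Tokenizer(Path("./checkpoints/lit-llama/tokenizer.model"))
    ids = tokenizer.encode("Hello world", bos=True, eos=False, device=torch.device("cpu"))
    print(ids.dtype)              # torch.int32, per the dtype=torch.int above
    print(tokenizer.decode(ids))  # round-trips back to text
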
2 changes: 1 addition & 1 deletion requirements.txt
@@ -3,5 +3,5 @@ lightning>=2.0.0
sentencepiece
tqdm # convert_checkpoint.py
numpy # train.py dataset memmap
jsonargparse # generate.py, convert_checkpoint.py CLI
jsonargparse[signatures] # generate.py, convert_checkpoint.py CLI
bitsandbytes # quantization.py
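
The [signatures] extra gives jsonargparse what it needs to build a command-line parser from a function's type-annotated signature and docstring, which appears to be the pattern the CLI(...) calls in this repository rely on. A rough sketch of that pattern (reduced from generate.py's main; the flag names are derived from the parameter names):

    from jsonargparse import CLI

    def main(model_size: str = "7B", temperature: float = 0.8) -> None:
        """Generates text samples."""
        print(model_size, temperature)

    if __name__ == "__main__":
        CLI(main)  # exposes --model_size and --temperature, with -h text taken from the docstring
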
30 changes: 17 additions & 13 deletions scripts/prepare_shakespeare.py
@@ -19,17 +19,20 @@
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import os
import sys
import requests
from pathlib import Path

import numpy as np
import requests


def prepare(destination_path: str = "data/shakespeare") -> None:
os.makedirs(destination_path, exist_ok=True)
def prepare(destination_path: Path = Path("data/shakespeare")) -> None:
"""Prepare the "Tiny Shakespeare" dataset."""
destination_path.mkdir(parents=True, exist_ok=True)

# download the tiny shakespeare dataset
input_file_path = os.path.join(destination_path, "input.txt")
if not os.path.exists(input_file_path):
input_file_path = destination_path / "input.txt"
if not input_file_path.exists():
data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
with open(input_file_path, "w") as f:
f.write(requests.get(data_url).text)
@@ -40,10 +43,10 @@ def prepare(destination_path: str = "data/shakespeare") -> None:
train_data = data[: int(n * 0.9)]
val_data = data[int(n * 0.9) :]

from tokenizer import Tokenizer
from lit_llama.tokenizer import Tokenizer

Tokenizer.train(input=input_file_path, destination=destination_path)
tokenizer = Tokenizer(os.path.join(destination_path, "tokenizer.model"))
tokenizer = Tokenizer(destination_path / "tokenizer.model")
train_ids = tokenizer.encode(train_data)
val_ids = tokenizer.encode(val_data)
print(f"train has {len(train_ids):,} tokens")
@@ -52,13 +55,14 @@ def prepare(destination_path: str = "data/shakespeare") -> None:
# export to bin files
train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)
train_ids.tofile(os.path.join(destination_path, "train.bin"))
val_ids.tofile(os.path.join(destination_path, "val.bin"))
train_ids.tofile(destination_path / "train.bin")
val_ids.tofile(destination_path / "val.bin")


if __name__ == "__main__":
wd = os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(wd)
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))

from jsonargparse import CLI
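
Because the script builds its command-line interface with jsonargparse (the CLI import above), the destination_path parameter presumably doubles as a --destination_path flag; it can also be driven directly from Python, which is what the new test below does. A minimal sketch (run from the repository root, with the lit_llama package importable, e.g. after pip install .):

    import sys
    from pathlib import Path

    sys.path.append("scripts")  # the script is not installed as a package
    import prepare_shakespeare

    # downloads input.txt, trains a tokenizer, and writes train.bin / val.bin
    prepare_shakespeare.prepare(Path("data/shakespeare"))
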

2 changes: 0 additions & 2 deletions tests/test_basic_functionality.py

This file was deleted.

119 changes: 119 additions & 0 deletions tests/test_generate.py
@@ -0,0 +1,119 @@
import functools
import subprocess
import sys
from contextlib import redirect_stdout
from io import StringIO
from pathlib import Path
from unittest import mock
from unittest.mock import Mock, PropertyMock, call, ANY

import pytest
import torch

wd = Path(__file__).parent.parent.absolute()


@functools.lru_cache(maxsize=1)
def load_generate_script():
sys.path.append(str(wd))

import generate

return generate


@pytest.mark.parametrize("B", (1, 2))
def test_generate(B):
generate = load_generate_script()

T, C = 5, 3
logits = torch.randn(B, T, C)
input_idx = torch.randint(10, size=(B, T))

model = Mock(return_value=logits)
max_new_tokens = 20

multinomial_results = []
original_multinomial = torch.multinomial

def multinomial(*args, **kwargs):
out = original_multinomial(*args, **kwargs)
multinomial_results.append(out)
return out

with mock.patch("torch.multinomial", multinomial):
out = generate.generate(model, input_idx, max_new_tokens, max_seq_length=10)

assert out.shape == (B, T + max_new_tokens)
multinomial_results = torch.hstack(multinomial_results)
expected = torch.cat((input_idx, multinomial_results), dim=1)
assert out.shape == expected.shape
torch.testing.assert_close(out, expected)


def test_main(tmp_path, monkeypatch):
generate = load_generate_script()

checkpoint_path = tmp_path / "ckpt"
checkpoint_path.touch()
tokenizer_path = tmp_path / "tokenizer"
tokenizer_path.touch()

class FabricMock(PropertyMock):
@property
def device(self):
return torch.device("cpu")

fabric_mock = FabricMock()
monkeypatch.setattr(generate.L, "Fabric", fabric_mock)
model_mock = Mock()
monkeypatch.setattr(generate.LLaMA, "from_name", model_mock)
load_mock = Mock()
monkeypatch.setattr(generate.torch, "load", load_mock)
tokenizer_mock = Mock()
tokenizer_mock.return_value.encode.return_value = torch.tensor([[1, 2, 3]])
tokenizer_mock.return_value.decode.return_value = "foo bar baz"
monkeypatch.setattr(generate, "Tokenizer", tokenizer_mock)
generate_mock = Mock()
generate_mock.return_value = torch.tensor([[3, 2, 1]])
monkeypatch.setattr(generate, "generate", generate_mock)

num_samples = 2
out = StringIO()
with redirect_stdout(out):
generate.main(
checkpoint_path=checkpoint_path,
tokenizer_path=tokenizer_path,
model_size="1T",
accelerator="litpu",
temperature=2.0,
top_k=2,
num_samples=num_samples,
)

model_mock.assert_called_once_with("1T")
load_mock.assert_called_once_with(checkpoint_path)
tokenizer_mock.assert_called_once_with(tokenizer_path)
assert len(tokenizer_mock.return_value.decode.mock_calls) == num_samples
assert torch.allclose(tokenizer_mock.return_value.decode.call_args[0][0], generate_mock.return_value)
model = model_mock.return_value
assert fabric_mock.mock_calls == [
call(accelerator="litpu", devices=1),
call().device.__enter__(),
call().device.__exit__(None, None, None),
call().setup_module(model),
]
model = fabric_mock.return_value.setup_module.return_value
assert (
generate_mock.mock_calls
== [call(model, ANY, 50, model.config.block_size, temperature=2.0, top_k=2)] * num_samples
)
# only the generated result is printed to stdout
assert out.getvalue() == "foo bar baz\n" * num_samples


def test_cli():
cli_path = wd / "generate.py"
output = subprocess.check_output([sys.executable, cli_path, "-h"])
output = str(output.decode())
assert "Generates text samples" in output
23 changes: 23 additions & 0 deletions tests/test_prepare_shakespeare.py
@@ -0,0 +1,23 @@
import os
import subprocess
import sys
from pathlib import Path

wd = (Path(__file__).parent.parent / "scripts").absolute()


def test_prepare(tmp_path):
sys.path.append(str(wd))

import prepare_shakespeare

prepare_shakespeare.prepare(tmp_path)

assert set(os.listdir(tmp_path)) == {"train.bin", "tokenizer.model", "tokenizer.vocab", "input.txt", "val.bin"}


def test_cli():
cli_path = wd / "prepare_shakespeare.py"
output = subprocess.check_output([sys.executable, cli_path, "-h"])
output = str(output.decode())
assert 'Prepare the "Tiny Shakespeare"' in output
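
These new tests are what the updated CI workflow runs; locally, the same commands from cpu-tests.yml above should work:

    pip install pytest . -r requirements.txt
    pytest -v --durations=10 --disable-pytest-warnings --strict-markers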
