Minimal Python example (Lightning-AI#1410)
rasbt authored May 13, 2024
1 parent 62a491c commit 5dabf5f
Showing 3 changed files with 143 additions and 0 deletions.
30 changes: 30 additions & 0 deletions tutorials/examples/minimal-generate-scripts/README.md
@@ -0,0 +1,30 @@
## Minimal LitGPT Generate Examples in Python



The scripts in this folder provide minimal examples showing how to use LitGPT from within Python without the CLI.

- `generate.py` is a minimal script that uses the `main` function from LitGPT's `generate` utilities.
- `generate-step-by-step.py` is a lower-level script that calls LitGPT's utility functions directly instead of relying on the `main` function mentioned above.

Assuming you downloaded the checkpoint files via

```bash
litgpt download --repo_id EleutherAI/pythia-1b
```
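
this creates a local checkpoint directory that both scripts expect. As a rough sketch (the exact file list can vary with the model and LitGPT version), it should contain the model weights, config, and tokenizer files:

```bash
ls checkpoints/EleutherAI/pythia-1b
# lit_model.pth  model_config.yaml  tokenizer.json  tokenizer_config.json
```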

You can then run the scripts as follows:


```bash
python generate-step-by-step.py
```

or

```bash
python generate.py
```
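
Both scripts are the programmatic counterpart of LitGPT's command-line generation. For comparison, the CLI equivalent looks roughly like this (the exact subcommand and flag names depend on your LitGPT version):

```bash
litgpt generate base \
  --checkpoint_dir checkpoints/EleutherAI/pythia-1b \
  --prompt "What food do llamas eat?" \
  --max_new_tokens 50
```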



84 changes: 84 additions & 0 deletions tutorials/examples/minimal-generate-scripts/generate-step-by-step.py
@@ -0,0 +1,84 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

from pathlib import Path

import lightning as L
import torch

from litgpt.prompts import PromptStyle
from litgpt.tokenizer import Tokenizer
from litgpt.utils import load_checkpoint, get_default_supported_precision
from litgpt.generate.base import generate
from litgpt.model import GPT
from litgpt.config import Config


def use_model():

    ###################
    # Load model
    ###################

    # run `litgpt download --repo_id EleutherAI/pythia-1b` to download the checkpoint first
    checkpoint_dir = Path("checkpoints") / "EleutherAI" / "pythia-1b"
    config = Config.from_file(checkpoint_dir / "model_config.yaml")

    precision = get_default_supported_precision(training=False)
    device = torch.device("cuda")  # assumes a CUDA-capable GPU; change to "cpu" to run on CPU

    fabric = L.Fabric(
        accelerator=device.type,
        devices=1,
        precision=precision,
    )

    checkpoint_path = checkpoint_dir / "lit_model.pth"
    tokenizer = Tokenizer(checkpoint_dir)

    prompt_style = PromptStyle.from_config(config)

    # instantiate the model with empty weights, then allocate the KV cache
    with fabric.init_module(empty_init=True):
        model = GPT(config)
    with fabric.init_tensor():
        model.set_kv_cache(batch_size=1)

    model.eval()
    model = fabric.setup_module(model)
    load_checkpoint(fabric, model, checkpoint_path)

    device = fabric.device

    ###################
    # Predict
    ###################

    prompt = "What do Llamas eat?"
    max_new_tokens = 50

    prompt = prompt_style.apply(prompt)
    encoded = tokenizer.encode(prompt, device=device)

    prompt_length = encoded.size(0)
    max_returned_tokens = prompt_length + max_new_tokens

    torch.manual_seed(123)

    y = generate(
        model,
        encoded,
        max_returned_tokens,
        temperature=0.5,  # <1.0 sharpens the sampling distribution
        top_k=200,        # sample only from the 200 most likely tokens
        top_p=1.0,        # 1.0 keeps the full distribution (nucleus sampling disabled)
        eos_id=tokenizer.eos_id,
    )

    # reset the KV cache so the model can be reused for another prompt
    for block in model.transformer.h:
        block.attn.kv_cache.reset_parameters()

    decoded_output = tokenizer.decode(y)
    print(decoded_output)


if __name__ == "__main__":
    use_model()
29 changes: 29 additions & 0 deletions tutorials/examples/minimal-generate-scripts/generate.py
@@ -0,0 +1,29 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

from pathlib import Path
import torch
from litgpt.generate.base import main
from litgpt.utils import get_default_supported_precision


def use_model():

    # run `litgpt download --repo_id EleutherAI/pythia-1b` to download the checkpoint first
    checkpoint_dir = Path("checkpoints") / "EleutherAI" / "pythia-1b"

    torch.manual_seed(123)

    # `main` handles model loading, tokenization, generation, and decoding in one call
    main(
        prompt="What food do llamas eat?",
        max_new_tokens=50,
        temperature=0.5,
        top_k=200,
        top_p=1.0,
        checkpoint_dir=checkpoint_dir,
        precision=get_default_supported_precision(training=False),
        compile=False,
    )


if __name__ == "__main__":
    use_model()
