Fix generation (Lightning-AI#9)
carmocca authored Mar 24, 2023
1 parent d055a43 commit bf2f243
Showing 7 changed files with 72 additions and 505 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -1,5 +1,7 @@
__pycache__
data
.idea
/scripts/llama_model.py
.DS_Store

# downloaded by scripts/compare.py
llama_model.py
153 changes: 0 additions & 153 deletions compare.py

This file was deleted.

59 changes: 47 additions & 12 deletions generate.py
@@ -1,14 +1,46 @@
# adapted from karpathy/minGPT
import os
import torch
import models.llama as llama
from model import LLaMA, LLaMAConfig
from tokenizer import Tokenizer
import lightning as L


def generate(
@torch.inference_mode()
def generate(model, idx, max_new_tokens, max_seq_length, temperature=1.0, top_k=None):
"""
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
Args:
idx: Tensor of shape (B, T) with indices of the prompt sequence.
max_new_tokens: The number of new tokens to generate.
max_seq_length: The maximum sequence length allowed.
temperature: Scales the predicted logits by 1 / temperature
top_k: If specified, only sample among the tokens with the k highest probabilities.
The implementation of this function is modified from A. Karpathy's nanoGPT.
"""
for _ in range(max_new_tokens):
# if the sequence context is growing too long we must crop it at max_seq_length
idx_cond = idx if idx.size(1) <= max_seq_length else idx[:, -max_seq_length:]
logits = model(idx_cond)
logits = logits[:, -1, :] / temperature

# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float("Inf")

probs = torch.nn.functional.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx = torch.cat((idx, idx_next), dim=1)

return idx


def main(
prompt: str = "Hello, my name is",
*,
num_samples: int = 1,
steps: int = 20,
max_new_tokens: int = 20,
top_k: int = 200,
temperature: float = 0.8,
compile: bool = False,
@@ -21,7 +53,7 @@ def generate(
Args:
prompt: The prompt string to use for generating the samples.
num_samples: The number of text samples to generate.
steps: The number of generation steps to take.
max_new_tokens: The number of generation steps to take.
top_k: The number of top most probable tokens to consider in the sampling process.
temperature: A value controlling the randomness of the sampling process. Higher values result in more random
samples.
@@ -35,27 +67,30 @@

fabric = L.Fabric(accelerator=accelerator, precision=precision, devices=1)

checkpoint = torch.load("/srv/data/checkpoints/llama/converted_meta/7B/state_dict.pt")
llama_config = llama.LLAMA_CONFIG_DICT["7B"]

checkpoint_path = "/srv/data/checkpoints/llama/converted_meta/7B/state_dict.pt"
assert os.path.isfile(checkpoint_path)
llama_config = LLaMAConfig()
# initialize the model directly on the device
with fabric.device:
model = llama.LLaMA(llama_config)
model = LLaMA(llama_config)
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint)
model.eval()
if compile:
model = torch.compile(model)
model = fabric.setup_module(model, move_to_device=False)

tokenizer = llama.Tokenizer("/srv/data/checkpoints/llama/converted_meta/tokenizer.model")
tokenizer = Tokenizer("/srv/data/checkpoints/llama/converted_meta/tokenizer.model")
encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False).to(fabric.device)
encoded_prompt = encoded_prompt[None, :]
for k in range(num_samples):
y = model.generate(encoded_prompt, steps, temperature=temperature, top_k=top_k)
for _ in range(num_samples):
y = generate(
model, encoded_prompt, max_new_tokens, model.params.max_seq_length, temperature=temperature, top_k=top_k
)
print(tokenizer.decode(y[0]))


if __name__ == "__main__":
from jsonargparse import CLI

CLI()
CLI(main)
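
For reference, a minimal standalone sketch of the temperature/top-k sampling step used by the new generate() function above; the vocabulary size and logits are made-up toy values, not taken from this commit.

import torch

# Toy logits for a batch of one sequence over a made-up vocabulary of 8 tokens.
logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -0.3, -1.0, -2.0, -3.0]])
temperature = 0.8
top_k = 3

# Scale by 1 / temperature: lower temperature sharpens the distribution.
logits = logits / temperature

# Keep only the k largest logits; anything below the k-th value becomes -inf
# and therefore gets zero probability after the softmax.
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[:, [-1]]] = -float("Inf")

probs = torch.nn.functional.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)  # sample one token id per sequence
print(probs, idx_next)
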
24 changes: 7 additions & 17 deletions model.py
@@ -88,14 +88,6 @@ def __init__(self, config, rope_cache):
# regularization
self.n_head = config.n_head
self.n_embd = config.n_embd
# flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
self.flash = False
if not self.flash:
# print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))

self.rope_cache = rope_cache

def forward(self, x):
@@ -113,15 +105,8 @@ def forward(self, x):
k = apply_rope(k, self.rope_cache)

# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
if self.flash:
# efficient attention using Flash Attention CUDA kernels
y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
else:
# manual implementation of attention
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
# efficient attention using Flash Attention CUDA kernels
y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

# output projection
@@ -220,3 +205,8 @@ def forward(self, idx):
logits = self.lm_head(x)

return logits

def step(self, idx, targets):
logits = self(idx)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
return loss
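
As a side note, a minimal sanity-check sketch (not part of the commit) of why the manual attention path could be removed: on PyTorch >= 2.0, F.scaled_dot_product_attention with is_causal=True computes the same result as the deleted masked-softmax implementation. The shapes below are arbitrary toy values.

import math
import torch
import torch.nn.functional as F

B, nh, T, hs = 1, 2, 4, 8  # batch, heads, sequence length, head size (made up)
q, k, v = (torch.randn(B, nh, T, hs) for _ in range(3))

# Fused kernel path kept by the commit.
y_fused = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)

# Manual implementation removed by the commit.
bias = torch.tril(torch.ones(T, T)).view(1, 1, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(bias[:, :, :T, :T] == 0, float("-inf"))
att = F.softmax(att, dim=-1)
y_manual = att @ v

print(torch.allclose(y_fused, y_manual, atol=1e-6))  # expected: True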