
Commit

Adding cached KVs (Lightning-AI#51)
Co-authored-by: Carlos Mocholí <[email protected]>
gkroiz and carmocca authored May 12, 2023
1 parent 788211b commit 0b5620d
Showing 8 changed files with 210 additions and 91 deletions.
chat.py: 64 changes (45 additions, 19 deletions)
@@ -1,6 +1,7 @@
import json
import re
import sys
import time
import warnings
from pathlib import Path
from typing import Optional, Tuple, List
@@ -14,9 +15,11 @@

@torch.no_grad()
def generate(
model: torch.nn.Module,
model: Parrot,
idx: torch.Tensor,
max_seq_length: int,
max_new_tokens: int,
*,
max_seq_length: Optional[int] = None,
temperature: float = 1.0,
top_k: Optional[int] = None,
stop_tokens: Tuple[List[int], ...] = tuple(),
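The bare `*` in the new signature makes every following argument keyword-only, and `max_seq_length` may now be omitted so that `generate()` can pick a sensible default itself. A sketch of how a caller might use the updated chat.py `generate()` under these assumptions; the sampling values are illustrative, and `model`, `encoded_prompt`, and `tokenizer` stand in for objects built elsewhere in chat.py:

```python
# Hypothetical call site: max_seq_length is left to its default, which
# generate() derives from the prompt length, max_new_tokens, and
# model.config.block_size.
stream = generate(
    model,
    encoded_prompt,                      # 1-D tensor of prompt token ids
    100,                                 # max_new_tokens, still positional
    temperature=0.8,                     # keyword-only from here on
    top_k=200,
    stop_tokens=([tokenizer.eos_id],),
)
for token in stream:
    print(tokenizer.decode(token), end="", flush=True)
```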
@@ -26,51 +29,69 @@ def generate(
Args:
model: The model to use.
idx: Tensor of shape (T) with indices of the prompt sequence.
max_new_tokens: The number of new tokens to generate.
max_seq_length: The maximum sequence length allowed.
temperature: Scales the predicted logits by 1 / temperature.
top_k: If specified, only sample among the tokens with the k highest probabilities.
stop_tokens: If specified, stop generating any more tokens once one of these sequences is generated.
"""
stop_tokens = [torch.tensor(tokens, device=idx.device) for tokens in stop_tokens]
T = yield_i = idx.size(0)
assert max_seq_length > T
buffer = max((len(tokens) for tokens in stop_tokens), default=0)
T = idx.size(0)
T_new = T + max_new_tokens
if max_seq_length is None:
max_seq_length = min(T_new, model.config.block_size)
# otherwise this would use more memory than necessary
assert max_seq_length <= T_new

device = idx.device
stop_tokens = [torch.tensor(tokens, device=device) for tokens in stop_tokens]
input_pos = torch.arange(0, T, device=device)

# buffer holds the tokens that haven't been yielded yet
buffer_length = max((len(tokens) for tokens in stop_tokens), default=1)
buffer = torch.full((buffer_length,), -999, device=device) # fill with non-existing token

if idx.device.type == "xla":
import torch_xla.core.xla_model as xm

xm.mark_step()

for t in range(T, max_seq_length):
yield_i = -1
for t in range(max_new_tokens):
# forward
logits = model(idx.view(1, -1))
logits = model(idx.view(1, -1), max_seq_length, input_pos)
logits = logits[0, -1] / temperature

# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[[-1]]] = -float("Inf")
logits = torch.where(logits < v[[-1]], -float("Inf"), logits)

probs = torch.nn.functional.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx = torch.multinomial(probs, num_samples=1)

# concatenate the new generation
idx = torch.cat((idx, idx_next), dim=-1)
# advance
input_pos = input_pos[-1:] + 1

if idx.device.type == "xla":
xm.mark_step()

# concatenate the new generation
buffer[min(t, buffer_length - 1)] = idx

# check the stop condition
for tokens in stop_tokens:
l = len(tokens)
if torch.equal(idx[-l:], tokens):
if torch.equal(buffer[-l:], tokens):
# stop token hit, yield any leftovers that aren't part of it
last = t - l + 1
if last > yield_i: # avoid an empty yield
yield idx[yield_i:last]
if buffer_length > l: # avoid an empty yield
yield buffer[:-l]
return
if t - yield_i >= buffer:
# if the buffer is full
if t - yield_i >= buffer_length:
# we know this idx is not part of stop tokens, safe to yield
yield idx[yield_i]
yield buffer[0]
# roll once to the left, as next generation will be put at the end
buffer = torch.roll(buffer, -1, 0)
yield_i += 1
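The loop above streams tokens while holding back just enough of a tail to detect multi-token stop sequences before anything gets printed. A standalone sketch of the same idea, using a plain Python deque instead of the commit's fixed-size tensor buffer; the function name and the end-of-stream flush are illustrative choices, not part of the commit:

```python
from collections import deque

def stream_until_stop(token_iter, stop_sequences):
    """Yield token ids from token_iter, withholding just enough of a tail to
    detect and swallow any of the given stop sequences."""
    stops = [list(s) for s in stop_sequences]
    hold = max((len(s) for s in stops), default=1)
    withheld = deque()  # tokens generated but not yet yielded

    for tok in token_iter:
        withheld.append(int(tok))
        for s in stops:
            if len(withheld) >= len(s) and list(withheld)[-len(s):] == s:
                # stop sequence hit: flush whatever precedes it, then finish
                for leftover in list(withheld)[:-len(s)]:
                    yield leftover
                return
        if len(withheld) > hold:
            # the oldest withheld token can no longer complete a stop sequence
            yield withheld.popleft()
    # stream ended without a stop sequence: flush the remaining tail
    while withheld:
        yield withheld.popleft()
```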


@@ -125,15 +146,20 @@ def main(
y = generate(
model,
encoded_prompt,
model.config.block_size, # type: ignore[union-attr,arg-type]
max_new_tokens=model.config.block_size, # type: ignore[union-attr,arg-type]
temperature=temperature,
top_k=top_k,
stop_tokens=stop_tokens,
)
print(f">> Reply: ", end="")
try:
tokens_generated = 0
t0 = time.perf_counter()
for token in y:
print(tokenizer.decode(token), end="", flush=True)
tokens_generated += 1
t = time.perf_counter() - t0
print(f"Time for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", file=sys.stderr)
except KeyboardInterrupt:
# support stopping generation
pass
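The new timing code measures wall-clock time around the streaming loop and reports tokens per second on stderr. A small self-contained sketch of the same pattern; the function and parameter names here are illustrative, not from the commit:

```python
import sys
import time

def print_stream_with_timing(token_iter, decode):
    """Print decoded tokens as they arrive, then report throughput on stderr."""
    tokens_generated = 0
    t0 = time.perf_counter()
    for token in token_iter:
        print(decode(token), end="", flush=True)
        tokens_generated += 1
    t = time.perf_counter() - t0
    print(
        f"Time for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec",
        file=sys.stderr,
    )
```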
generate.py: 52 changes (27 additions, 25 deletions)
@@ -14,10 +14,11 @@

@torch.no_grad()
def generate(
model: torch.nn.Module,
model: Parrot,
idx: torch.Tensor,
max_new_tokens: int,
max_seq_length: int,
*,
max_seq_length: Optional[int] = None,
temperature: float = 1.0,
top_k: Optional[int] = None,
eos_id: Optional[int] = None,
@@ -35,45 +36,53 @@ def generate(
top_k: If specified, only sample among the tokens with the k highest probabilities.
eos_id: If specified, stop generating any more tokens once the <eos> token is triggered.
"""
# create an empty tensor of the expected final shape and fill in the current tokens
T = idx.size(0)
T_new = T + max_new_tokens
empty = torch.empty(T_new, dtype=idx.dtype, device=idx.device)
if max_seq_length is None:
max_seq_length = min(T_new, model.config.block_size)
# otherwise this would use more memory than necessary
assert max_seq_length <= T_new

device, dtype = idx.device, idx.dtype
# create an empty tensor of the expected final shape and fill in the current tokens
empty = torch.empty(T_new, dtype=dtype, device=device)
empty[:T] = idx
idx = empty
input_pos = torch.arange(0, T, device=device)

if idx.device.type == "xla":
import torch_xla.core.xla_model as xm

xm.mark_step()

# generate max_new_tokens tokens
for t in range(T, T_new):
# ignore the not-filled-yet tokens
idx_cond = idx[:t]
# if the sequence context is growing too long we must crop it at max_seq_length
idx_cond = idx_cond if t <= max_seq_length else idx_cond[-max_seq_length:]
for _ in range(max_new_tokens):
x = idx.index_select(0, input_pos).view(1, -1)

# forward
logits = model(idx_cond.view(1, -1))
logits = model(x, max_seq_length, input_pos)
logits = logits[0, -1] / temperature

# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < v[[-1]]] = -float("Inf")
logits = torch.where(logits < v[[-1]], -float("Inf"), logits)

probs = torch.nn.functional.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx_next = torch.multinomial(probs, num_samples=1).to(dtype=dtype)

# advance
input_pos = input_pos[-1:] + 1

if idx.device.type == "xla":
xm.mark_step()

# concatenate the new generation
idx[t] = idx_next
idx = idx.index_copy(0, input_pos, idx_next)

# if <eos> token is triggered, return the output (stop generation)
if idx_next == eos_id:
return idx[: t + 1] # include the EOS token

if idx.device.type == "xla":
xm.mark_step()
return idx[:input_pos] # include the EOS token

return idx
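This is the core of the change: instead of re-running the model over the growing sequence and cropping it at max_seq_length, each step now feeds only the newly generated token together with its position, and the model resolves attention against its cached keys and values indexed by `input_pos`. A minimal sketch of that decode pattern, assuming a model with the `(x, max_seq_length, input_pos)` call signature introduced here; greedy argmax selection is used purely for brevity and is not what the commit does:

```python
import torch

@torch.no_grad()
def greedy_decode_with_kv_cache(model, prompt, max_new_tokens, max_seq_length):
    """Sketch of single-token decoding against a model that takes
    (x, max_seq_length, input_pos) and maintains its own KV cache."""
    device, dtype = prompt.device, prompt.dtype
    T = prompt.size(0)
    # pre-allocate the full output and copy the prompt in
    out = torch.empty(T + max_new_tokens, dtype=dtype, device=device)
    out[:T] = prompt
    input_pos = torch.arange(0, T, device=device)  # prefill: all prompt positions

    for _ in range(max_new_tokens):
        # after the first step this selects exactly one token: the last one generated
        x = out.index_select(0, input_pos).view(1, -1)
        logits = model(x, max_seq_length, input_pos)      # cache is indexed by input_pos
        next_tok = logits[0, -1].argmax().to(dtype).view(1)  # greedy pick for brevity
        input_pos = input_pos[-1:] + 1                    # a single position from now on
        out = out.index_copy(0, input_pos, next_tok)      # place the token in the pre-allocated output
    return out
```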

@@ -129,14 +138,7 @@ def main(
L.seed_everything(1234)
for i in range(num_samples):
t0 = time.perf_counter()
y = generate(
model,
encoded,
max_new_tokens,
model.config.block_size, # type: ignore[union-attr,arg-type]
temperature=temperature,
top_k=top_k,
)
y = generate(model, encoded, max_new_tokens, temperature=temperature, top_k=top_k)
t = time.perf_counter() - t0

print(tokenizer.decode(y))
generate_adapter.py: 8 changes (1 addition, 7 deletions)
@@ -76,13 +76,7 @@ def main(

t0 = time.perf_counter()
y = generate(
model,
idx=encoded,
max_seq_length=max_new_tokens,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
eos_id=tokenizer.eos_id,
model, idx=encoded, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k, eos_id=tokenizer.eos_id
)
t = time.perf_counter() - t0

Expand Down
howto/tpus.md: 6 changes (3 additions, 3 deletions)
@@ -34,11 +34,11 @@ Since you created a new machine, you'll probably need to download the weights. Y
Generation works out-of-the-box with TPUs:

```shell
python3 generate.py --prompt "Hello, my name is" --num_samples 2
python3 generate.py --prompt "Hello, my name is" --num_samples 3
```

This command will take a long time as XLA needs to compile the graph (~13 min) before running the model.
In fact, you'll notice that the second sample takes considerably less time (~12 sec).
This command will take a long time as XLA needs to compile the graph: ~20s for the first generation.
In fact, you'll notice that the second sample takes considerably less time (~9s), and later samples take ~2s.

## Finetuning

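The timings in the updated text (slow first generation, fast later ones) follow the usual XLA pattern: operations are traced lazily and compiled at step boundaries, and once the shapes in the per-token loop are stable, later steps reuse the compiled graph. A hedged sketch of the `xm.mark_step()` usage that appears throughout this commit; the toy loop and function name are illustrative:

```python
import torch

def run_steps_on_xla(model, token_batches):
    """Toy loop showing where graph boundaries are cut on an XLA device."""
    on_xla = token_batches[0].device.type == "xla"
    if on_xla:
        import torch_xla.core.xla_model as xm
        xm.mark_step()  # flush anything traced so far before entering the loop
    outputs = []
    for x in token_batches:
        outputs.append(model(x))
        if on_xla:
            # cut the graph once per step; identical steps then hit the XLA
            # compilation cache instead of recompiling
            xm.mark_step()
    return outputs
```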