Allow multiline prompts (Lightning-AI#1279)
Co-authored-by: Luca Antiga <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
3 people authored May 10, 2024
1 parent 20383ed commit d317902
Showing 3 changed files with 74 additions and 25 deletions.
90 changes: 65 additions & 25 deletions litgpt/chat/base.py
@@ -110,6 +110,51 @@ def decode(fabric: L.Fabric, tokenizer: Tokenizer, token_stream: Iterator[torch.
    return tokens_generated


+def process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature, top_k, top_p, stop_tokens):
+    prompt = prompt_style.apply(prompt=prompt)
+    encoded_prompt = tokenizer.encode(prompt, device=fabric.device)
+    y = generate(
+        model, encoded_prompt, model.max_seq_length, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=stop_tokens
+    )
+    fabric.print(">> Reply: ", end="")
+    t0 = time.perf_counter()
+    tokens_generated = decode(fabric, tokenizer, y)
+    t = time.perf_counter() - t0
+    for block in model.transformer.h:
+        block.attn.kv_cache.reset_parameters()
+    fabric.print(
+        f"\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec,"
+        f" {tokens_generated} tokens",
+        file=sys.stderr,
+    )
+    fabric.print()


+def interact(multiline, model, tokenizer, prompt_style, fabric, temperature, top_k, top_p, stop_tokens):
+    while True:
+        try:
+            if not multiline:
+                prompt = input(">> Prompt: ")
+            else:
+                print(">> Prompt: (Type '!submit' on a new line to end input).")
+                prompt_lines = []
+                while True:
+                    line = input()
+                    if line.strip().lower() in ("!submit", "!quit", "!exit"):
+                        break
+                    prompt_lines.append(line)
+                prompt = "\n".join(prompt_lines)
+
+        except KeyboardInterrupt:
+            break
+
+        prompt = prompt.lower().strip()
+        if not prompt or prompt in ("!quit", "!exit"):
+            break
+
+        process_prompt(prompt, model, tokenizer, prompt_style, fabric, temperature, top_k, top_p, stop_tokens)


@torch.inference_mode()
def main(
    *,
@@ -120,6 +165,7 @@ def main(
    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"]] = None,
    precision: Optional[str] = None,
    compile: bool = False,
+    multiline: bool = False,
) -> None:
    """Starts a conversation with a tuned GPT model.

@@ -148,6 +194,7 @@ def main(
            for more details, see https://github.com/Lightning-AI/litgpt/blob/main/tutorials/quantize.md
        precision: Indicates the Fabric precision setting to use.
        compile: Whether to use compilation to speed up token generation. Will increase startup time.
+        multiline: Whether to support multiline input prompts.
    """
    precision = precision or get_default_supported_precision(training=False)

@@ -193,29 +240,22 @@
    )
    stop_tokens = prompt_style.stop_tokens(tokenizer)

-    print(f"Now chatting with {config.name}.\nTo exit, press 'Enter' on an empty prompt.\n")
+    if multiline:
+        exit_instruction = "To exit, enter '!quit' or '!exit' on an empty prompt and press 'Enter'."
+    else:
+        exit_instruction = "To exit, press 'Enter' on an empty prompt."
+
+    print(f"Now chatting with {config.name}.\n{exit_instruction}\n")
    L.seed_everything(1234)
-    while True:
-        try:
-            prompt = input(">> Prompt: ")
-        except KeyboardInterrupt:
-            break
-        if prompt.lower().strip() in ("", "quit", "exit"):
-            break
-        prompt = prompt_style.apply(prompt=prompt)
-        encoded_prompt = tokenizer.encode(prompt, device=fabric.device)
-        y = generate(
-            model, encoded_prompt, model.max_seq_length, temperature=temperature, top_k=top_k, top_p=top_p, stop_tokens=stop_tokens
-        )
-        fabric.print(">> Reply: ", end="")
-        t0 = time.perf_counter()
-        tokens_generated = decode(fabric, tokenizer, y)
-        t = time.perf_counter() - t0
-        for block in model.transformer.h:
-            block.attn.kv_cache.reset_parameters()
-        fabric.print(
-            f"\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec,"
-            f" {tokens_generated} tokens",
-            file=sys.stderr,
-        )
-        fabric.print()

+    interact(
+        multiline=multiline,
+        model=model,
+        tokenizer=tokenizer,
+        prompt_style=prompt_style,
+        fabric=fabric,
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+        stop_tokens=stop_tokens
+    )
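
In multiline mode, the new `interact` function collects input lines until a sentinel is typed on its own line: `!submit` sends the collected prompt, while `!quit` or `!exit` (or an empty prompt) ends the chat. A minimal standalone sketch of that input loop, assuming the same sentinels (the `read_multiline_prompt` helper name is hypothetical and not part of this commit):

```python
def read_multiline_prompt() -> str:
    """Collect input lines until a sentinel is typed on its own line.

    Mirrors the loop added in interact(); this helper itself is an
    illustration, not part of litgpt.
    """
    print(">> Prompt: (Type '!submit' on a new line to end input).")
    lines = []
    while True:
        line = input()
        # The sentinel line itself is not kept, so entering '!quit' or
        # '!exit' alone yields an empty prompt, which ends the chat loop.
        if line.strip().lower() in ("!submit", "!quit", "!exit"):
            break
        lines.append(line)
    return "\n".join(lines)
```

Note that `interact` lowercases and strips the joined prompt before dispatching it to `process_prompt`, which is also what lets the exit-sentinel check match mixed-case input.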
5 changes: 5 additions & 0 deletions tutorials/0_to_litgpt.md
@@ -219,7 +219,12 @@ Time for inference: 1.26 sec total, 27.81 tokens/sec, 35 tokens
>> Prompt:
```
&nbsp;

+> [!TIP]
+> Use `--multiline true` to support prompts that require multiple input lines.
+<br>
+
+&nbsp;
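
For illustration (not part of the diff), a multiline exchange would look roughly like this; the prompt text is arbitrary and the model's reply is elided:

```
>> Prompt: (Type '!submit' on a new line to end input).
Summarize the following passage:
The quick brown fox jumps over the lazy dog.
!submit
>> Reply: ...
```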
**More information and additional resources**
4 changes: 4 additions & 0 deletions tutorials/inference.md
@@ -27,6 +27,10 @@ litgpt chat --checkpoint_dir checkpoints/stabilityai/stablelm-tuned-alpha-3b
This script can work with any checkpoint. For the best chat-like experience, we recommend using it with checkpoints
fine-tuned for chatting such as `stabilityai/stablelm-tuned-alpha-3b` or `togethercomputer/RedPajama-INCITE-Chat-3B-v1`.

+> [!TIP]
+> Use `--multiline true` to work with inputs that span multiple lines.
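
As a usage sketch, the new flag simply extends the chat command shown at the top of this tutorial:

```bash
litgpt chat --checkpoint_dir checkpoints/stabilityai/stablelm-tuned-alpha-3b --multiline true
```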

## Run a large model on one smaller device

Check out our [quantization tutorial](quantize.md).
