Exclude prompt from generated response (Lightning-AI#1485)
rasbt authored Jun 12, 2024
1 parent 8f65463 commit c0f7686
Showing 4 changed files with 10 additions and 5 deletions.
3 changes: 2 additions & 1 deletion litgpt/deploy/serve.py
```diff
@@ -113,7 +113,8 @@ def predict(self, inputs: torch.Tensor) -> Any:
             temperature=self.temperature,
             top_k=self.top_k,
             top_p=self.top_p,
-            eos_id=self.tokenizer.eos_id
+            eos_id=self.tokenizer.eos_id,
+            include_prompt=False
         )

         for block in self.model.transformer.h:
```
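With `include_prompt=False`, the `/predict` endpoint now returns only the model's continuation instead of echoing the prompt. A minimal client sketch of the new behavior (the host and port are assumptions for a server started locally with `litgpt serve`; adjust them to your setup):

```python
import requests

# Assumes a LitGPT server is running locally, e.g. started via `litgpt serve`;
# the URL below is an assumption and may differ in your setup.
response = requests.post(
    "http://127.0.0.1:8000/predict",
    json={"prompt": "Fix typos in the following sentence: Exampel input"},
)

# After this change, "output" contains only the generated completion,
# not the prompt text itself.
print(response.json()["output"])
```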
7 changes: 6 additions & 1 deletion litgpt/generate/base.py
```diff
@@ -86,6 +86,7 @@ def generate(
     top_k: Optional[int] = None,
     top_p: float = 1.0,
     eos_id: Optional[int] = None,
+    include_prompt: bool = True,
 ) -> torch.Tensor:
     """
     Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
@@ -112,6 +113,7 @@
             For more details, see https://arxiv.org/abs/1904.09751
             or https://huyenchip.com/2024/01/16/sampling.html#top_p
         eos_id: If specified, stop generating any more token once the <eos> token is triggered.
+        include_prompt: If true (default) prepends the prompt (after applying the prompt style) to the output.
     """
     T = prompt.size(0)
     assert max_returned_tokens > T
@@ -122,7 +124,10 @@
         raise NotImplementedError(f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}")

     device = prompt.device
-    tokens = [prompt]
+    if include_prompt:
+        tokens = [prompt]
+    else:
+        tokens = []
     input_pos = torch.tensor([T], device=device)
     token = next_token(
         model, torch.arange(0, T, device=device), prompt.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p
```
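For reference, a minimal sketch of what the new flag changes for callers of `generate` (it assumes `model` is a litgpt `GPT` with its KV cache already set up and `tokenizer` is the matching litgpt `Tokenizer`; neither is constructed here):

```python
from litgpt.generate.base import generate

# Sketch only: `model` (a litgpt GPT with its KV cache initialized) and
# `tokenizer` (the matching litgpt Tokenizer) are assumed to exist.
device = next(model.parameters()).device
prompt_ids = tokenizer.encode("Hello world", device=device)

# Default behavior: the returned tensor starts with the prompt tokens.
full = generate(model, prompt_ids, max_returned_tokens=50)

# New behavior used by the server: only the generated continuation.
completion = generate(model, prompt_ids, max_returned_tokens=50, include_prompt=False)
print(tokenizer.decode(completion))
```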
2 changes: 1 addition & 1 deletion tests/test_serve.py
```diff
@@ -39,4 +39,4 @@ def test_simple(tmp_path):
     response = client.post("/predict", json={"prompt": "Hello world"})
     # Model is a small random model, not trained, hence the gibberish.
     # We are just testing that the server works.
-    assert response.json()["output"][:19] == "Hello world statues"
+    assert response.json()["output"][:19] == " statues CAD pierci"
```
3 changes: 1 addition & 2 deletions tutorials/deploy.md
````diff
@@ -44,6 +44,5 @@ print(response.json()["output"])
 Executing the code above prints the following output:

 ```
-Instruct: Fix typos in the following sentence: Exampel input
-Output: Example input.
+Example input.
 ```
````
