Vwp/alpaca streaming (langchain-ai#3468)
Co-authored-by: Luke Stanley <[email protected]>
vowelparrot and lukestanley authored Apr 24, 2023
1 parent 26035df commit 416f3bd
Showing 3 changed files with 146 additions and 30 deletions.
22 changes: 19 additions & 3 deletions docs/modules/models/llms/integrations/llamacpp.ipynb
@@ -41,7 +41,9 @@
"outputs": [],
"source": [
"from langchain.llms import LlamaCpp\n",
"from langchain import PromptTemplate, LLMChain"
"from langchain import PromptTemplate, LLMChain\n",
"from langchain.callbacks.base import CallbackManager\n",
"from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler"
]
},
{
@@ -67,7 +69,14 @@
},
"outputs": [],
"source": [
"llm = LlamaCpp(model_path=\"./ggml-model-q4_0.bin\")"
"# Callbacks support token-wise streaming\n",
"callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])\n",
"# Verbose is required to pass to the callback manager\n",
"\n",
"# Make sure the model path is correct for your system!\n",
"llm = LlamaCpp(\n",
" model_path=\"./ggml-model-q4_0.bin\", callback_manager=callback_manager, verbose=True\n",
")"
]
},
{
@@ -84,10 +93,17 @@
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" First we need to identify what year Justin Beiber was born in. A quick google search reveals that he was born on March 1st, 1994. Now we know when the Super Bowl was played in, so we can look up which NFL team won it. The NFL Superbowl of the year 1994 was won by the San Francisco 49ers against the San Diego Chargers."
]
},
{
"data": {
"text/plain": [
"'\\n\\nWe know that Justin Bieber is currently 25 years old and that he was born on March 1st, 1994 and that he is a singer and he has an album called Purpose, so we know that he was born when Super Bowl XXXVIII was played between Dallas and Seattle and that it took place February 1st, 2004 and that the Seattle Seahawks won 24-21, so Seattle is our answer!'"
"' First we need to identify what year Justin Beiber was born in. A quick google search reveals that he was born on March 1st, 1994. Now we know when the Super Bowl was played in, so we can look up which NFL team won it. The NFL Superbowl of the year 1994 was won by the San Francisco 49ers against the San Diego Chargers.'"
]
},
"execution_count": 6,
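For context, here is a minimal sketch (not part of the diff) of how the updated notebook cells fit together, assuming the same local ./ggml-model-q4_0.bin model and the pre-0.1 langchain imports shown above:

from langchain import LLMChain, PromptTemplate
from langchain.callbacks.base import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms import LlamaCpp

template = """Question: {question}

Answer: Let's think step by step."""
prompt = PromptTemplate(template=template, input_variables=["question"])

# The stdout handler prints each token as the model generates it
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])

# Make sure the model path is correct for your system!
llm = LlamaCpp(
    model_path="./ggml-model-q4_0.bin",
    callback_manager=callback_manager,
    verbose=True,  # Verbose is required to pass to the callback manager
)

llm_chain = LLMChain(prompt=prompt, llm=llm)
llm_chain.run("What NFL team won the Super Bowl in the year Justin Bieber was born?")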
116 changes: 89 additions & 27 deletions langchain/llms/llamacpp.py
@@ -1,6 +1,6 @@
"""Wrapper around llama.cpp."""
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Generator, List, Optional

from pydantic import Field, root_validator

@@ -87,6 +89,9 @@ class LlamaCpp(LLM):
last_n_tokens_size: Optional[int] = 64
"""The number of tokens to look back when applying the repeat_penalty."""

streaming: bool = True
"""Whether to stream the results, token by token."""

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that llama-cpp-python library is installed."""
@@ -139,7 +142,7 @@ def _default_params(self) -> Dict[str, Any]:
"top_p": self.top_p,
"logprobs": self.logprobs,
"echo": self.echo,
"stop_sequences": self.stop,
"stop_sequences": self.stop, # key here is convention among LLM classes
"repeat_penalty": self.repeat_penalty,
"top_k": self.top_k,
}
@@ -154,6 +157,31 @@ def _llm_type(self) -> str:
"""Return type of llm."""
return "llama.cpp"

def _get_parameters(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
"""
Performs a sanity check, preparing parameters in the format needed by llama_cpp.
Args:
stop (Optional[List[str]]): List of stop sequences for llama_cpp.
Returns:
Dictionary containing the combined parameters.
"""

# Raise error if stop sequences are in both input and default params
if self.stop and stop is not None:
raise ValueError("`stop` found in both the input and default params.")

params = self._default_params

# llama_cpp expects the "stop" key, not "stop_sequences", so we remove it:
params.pop("stop_sequences")

# then set "stop" to the configured value, or default to an empty list:
params["stop"] = self.stop or stop or []

return params

def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
"""Call the Llama model and return the output.
@@ -167,31 +195,65 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
Example:
.. code-block:: python
from langchain.llms import LlamaCppEmbeddings
llm = LlamaCppEmbeddings(model_path="/path/to/local/llama/model.bin")
from langchain.llms import LlamaCpp
llm = LlamaCpp(model_path="/path/to/local/llama/model.bin")
llm("This is a prompt.")
"""

params = self._default_params
if self.stop and stop is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop:
params["stop_sequences"] = self.stop
elif stop:
params["stop_sequences"] = stop
if self.streaming:
# If streaming is enabled, we use the stream
# method, which yields tokens as they are generated,
# and return the combined text from each chunk's first choice:
combined_text_output = ""
for token in self.stream(prompt=prompt, stop=stop):
combined_text_output += token["choices"][0]["text"]
return combined_text_output
else:
params["stop_sequences"] = []

"""Call the Llama model and return the output."""
text = self.client(
prompt=prompt,
max_tokens=params["max_tokens"],
temperature=params["temperature"],
top_p=params["top_p"],
logprobs=params["logprobs"],
echo=params["echo"],
stop=params["stop_sequences"],
repeat_penalty=params["repeat_penalty"],
top_k=params["top_k"],
)
return text["choices"][0]["text"]
params = self._get_parameters(stop)
result = self.client(prompt=prompt, **params)
return result["choices"][0]["text"]

def stream(
self, prompt: str, stop: Optional[List[str]] = None
) -> Generator[Dict, None, None]:
"""Yields results objects as they are generated in real time.
BETA: this is a beta feature while we figure out the right abstraction:
Once that happens, this interface could change.
It also calls the callback manager's on_llm_new_token event with
similar parameters to the OpenAI LLM class method of the same name.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
A generator representing the stream of tokens being generated.
Yields:
Dictionary-like objects containing a string token and metadata.
See the llama-cpp-python docs and the example below for more.
Example:
.. code-block:: python
from langchain.llms import LlamaCpp
llm = LlamaCpp(
model_path="/path/to/local/model.bin",
temperature=0.5
)
for chunk in llm.stream("Ask 'Hi, how are you?' like a pirate:'",
stop=["'","\n"]):
result = chunk["choices"][0]
print(result["text"], end='', flush=True)
"""
params = self._get_parameters(stop)
result = self.client(prompt=prompt, stream=True, **params)
for chunk in result:
token = chunk["choices"][0]["text"]
log_probs = chunk["choices"][0].get("logprobs", None)
self.callback_manager.on_llm_new_token(
token=token, verbose=self.verbose, log_probs=log_probs
)
yield chunk
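
A rough usage sketch (not part of the diff) illustrating the two code paths in _call above, assuming a valid local model path:

from langchain.llms import LlamaCpp

# Default behaviour: streaming=True, so _call() iterates stream() internally
# and returns the concatenated text of every chunk's first choice.
llm = LlamaCpp(model_path="/path/to/local/llama/model.bin")
print(llm("Q: Name the planets in the solar system. A: "))

# streaming=False falls back to a single blocking llama.cpp call.
blocking_llm = LlamaCpp(model_path="/path/to/local/llama/model.bin", streaming=False)
print(blocking_llm("Q: Name the planets in the solar system. A: "))

# stream() can also be consumed directly; each chunk follows the
# OpenAI-style {"choices": [{"text": ..., "logprobs": ...}]} shape.
for chunk in llm.stream("Q: Name the planets in the solar system. A: ", stop=["\n"]):
    print(chunk["choices"][0]["text"], end="", flush=True)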
38 changes: 38 additions & 0 deletions tests/integration_tests/llms/test_llamacpp.py
@@ -1,9 +1,13 @@
# flake8: noqa
"""Test Llama.cpp wrapper."""
import os
from typing import Generator
from urllib.request import urlretrieve

from langchain.llms import LlamaCpp
from langchain.callbacks.base import CallbackManager

from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


def get_model() -> str:
@@ -32,3 +36,37 @@ def test_llamacpp_inference() -> None:
llm = LlamaCpp(model_path=model_path)
output = llm("Say foo:")
assert isinstance(output, str)
assert len(output) > 1


def test_llamacpp_streaming() -> None:
"""Test streaming tokens from LlamaCpp."""
model_path = get_model()
llm = LlamaCpp(model_path=model_path, max_tokens=10)
generator = llm.stream("Q: How do you say 'hello' in German? A:'", stop=["'"])
stream_results_string = ""
assert isinstance(generator, Generator)

for chunk in generator:
assert not isinstance(chunk, str)
# Note that this matches the OpenAI format:
assert isinstance(chunk["choices"][0]["text"], str)
stream_results_string += chunk["choices"][0]["text"]
assert len(stream_results_string.strip()) > 1


def test_llamacpp_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
MAX_TOKENS = 5
OFF_BY_ONE = 1  # There may be an off-by-one error in the upstream code!

callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
llm = LlamaCpp(
model_path=get_model(),
callback_manager=callback_manager,
verbose=True,
max_tokens=MAX_TOKENS,
)
llm("Q: Can you count to 10? A:'1, ")
assert callback_handler.llm_streams <= MAX_TOKENS + OFF_BY_ONE
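
The callback test relies on the repo's FakeCallbackHandler, which counts on_llm_new_token calls; the same bound can be checked outside the test suite with a rough sketch like this (hypothetical snippet, assuming a downloaded ggml model):

from langchain.llms import LlamaCpp

MAX_TOKENS = 5

llm = LlamaCpp(model_path="./ggml-model-q4_0.bin", max_tokens=MAX_TOKENS)

# Each yielded chunk corresponds to one generated token.
token_count = sum(1 for _ in llm.stream("Q: Can you count to 10? A:'1, "))

# Mirrors the test's assertion, allowing for the possible upstream off-by-one.
assert token_count <= MAX_TOKENS + 1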
