
feat: ✨ Implement async functionality in BedrockConverse (run-llama…
AndreCNF authored Jul 19, 2024
1 parent cb9e4a8 commit 7343f28
Showing 4 changed files with 241 additions and 26 deletions.
37 changes: 37 additions & 0 deletions docs/docs/examples/llm/bedrock_converse.ipynb
@@ -405,6 +405,43 @@
"for s in response.sources:\n",
" print(f\"Name: {s.tool_name}, Input: {s.raw_input}, Output: {str(s)}\")"
]
},
{
"cell_type": "markdown",
"id": "a7bee3bc",
"metadata": {},
"source": [
"## Async"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f4c144f",
"metadata": {},
"outputs": [],
"source": [
"from llama_index.llms.bedrock_converse import BedrockConverse\n",
"\n",
"llm = BedrockConverse(\n",
" model=\"anthropic.claude-3-haiku-20240307-v1:0\",\n",
" aws_access_key_id=\"AWS Access Key ID to use\",\n",
" aws_secret_access_key=\"AWS Secret Access Key to use\",\n",
" aws_session_token=\"AWS Session Token to use\",\n",
" region_name=\"AWS Region to use, eg. us-east-1\",\n",
")\n",
"resp = await llm.acomplete(\"Paul Graham is \")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cd72e3a0",
"metadata": {},
"outputs": [],
"source": [
"print(resp)"
]
}
],
"metadata": {
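
A minimal follow-on sketch for the streaming side of the new async API (not part of this commit's notebook; it assumes AWS credentials are resolved from the environment):

```python
from llama_index.llms.bedrock_converse import BedrockConverse

llm = BedrockConverse(
    model="anthropic.claude-3-haiku-20240307-v1:0",
    region_name="us-east-1",
)

# astream_complete returns an async generator of CompletionResponse
# chunks; each chunk's `delta` holds only the newly generated text.
resp_gen = await llm.astream_complete("Paul Graham is ")
async for r in resp_gen:
    print(r.delta, end="")
```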
llama_index/llms/bedrock_converse/base.py
@@ -15,6 +15,8 @@
ChatMessage,
ChatResponse,
ChatResponseGen,
ChatResponseAsyncGen,
CompletionResponseAsyncGen,
CompletionResponse,
CompletionResponseGen,
LLMMetadata,
@@ -28,6 +30,8 @@
llm_completion_callback,
)
from llama_index.core.base.llms.generic_utils import (
achat_to_completion_decorator,
astream_chat_to_completion_decorator,
chat_to_completion_decorator,
stream_chat_to_completion_decorator,
)
@@ -36,6 +40,7 @@
from llama_index.llms.bedrock_converse.utils import (
bedrock_modelname_to_context_size,
converse_with_retry,
converse_with_retry_async,
force_single_tool_call,
is_bedrock_function_calling_model,
join_two_dicts,
@@ -115,8 +120,9 @@ class BedrockConverse(FunctionCallingLLM):
description="Additional kwargs for the bedrock invokeModel request.",
)

_config: Any = PrivateAttr()
_client: Any = PrivateAttr()
_aclient: Any = PrivateAttr()
_asession: Any = PrivateAttr()

def __init__(
self,
Expand Down Expand Up @@ -152,12 +158,13 @@ def __init__(
"aws_session_token": aws_session_token,
"botocore_session": botocore_session,
}
config = None
self._config = None
try:
import boto3
import aioboto3
from botocore.config import Config

config = (
self._config = (
Config(
retries={"max_attempts": max_retries, "mode": "standard"},
connect_timeout=timeout,
@@ -167,9 +174,11 @@
else botocore_config
)
session = boto3.Session(**session_kwargs)
self._asession = aioboto3.Session(**session_kwargs)
except ImportError:
raise ImportError(
"boto3 package not found, install with" "'pip install boto3'"
"boto3 and/or aioboto3 package not found, install with"
"'pip install boto3 aioboto3"
)

# Prior to general availability, custom boto3 wheel files were
@@ -179,9 +188,9 @@ def __init__(
if client is not None:
self._client = client
elif "bedrock-runtime" in session.get_available_services():
self._client = session.client("bedrock-runtime", config=config)
self._client = session.client("bedrock-runtime", config=self._config)
else:
self._client = session.client("bedrock", config=config)
self._client = session.client("bedrock", config=self._config)

super().__init__(
temperature=temperature,
@@ -386,29 +395,126 @@ def stream_complete(
async def achat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponse:
# TODO convert to async; do synchronous chat for now
return self.chat(messages, **kwargs)
# convert Llama Index messages to AWS Bedrock Converse messages
converse_messages, system_prompt = messages_to_converse_messages(messages)
if len(system_prompt) > 0 or self.system_prompt is None:
self.system_prompt = system_prompt
all_kwargs = self._get_all_kwargs(**kwargs)

# invoke LLM in AWS Bedrock Converse with retry
response = await converse_with_retry_async(
session=self._asession,
config=self._config,
messages=converse_messages,
system_prompt=self.system_prompt,
max_retries=self.max_retries,
stream=False,
**all_kwargs,
)

content, tool_calls, tool_call_ids, status = self._get_content_and_tool_calls(
response
)

return ChatResponse(
message=ChatMessage(
role=MessageRole.ASSISTANT,
content=content,
additional_kwargs={
"tool_calls": tool_calls,
"tool_call_id": tool_call_ids,
"status": status,
},
),
raw=dict(response),
)

@llm_completion_callback()
async def acomplete(
self, prompt: str, formatted: bool = False, **kwargs: Any
) -> CompletionResponse:
# TODO convert to async; do synchronous completion for now
return self.complete(prompt, formatted=formatted, **kwargs)
complete_fn = achat_to_completion_decorator(self.achat)
return await complete_fn(prompt, **kwargs)

@llm_chat_callback()
async def astream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseGen:
# TODO convert to async; do synchronous chat for now
return self.stream_chat(messages, **kwargs)
) -> ChatResponseAsyncGen:
# convert Llama Index messages to AWS Bedrock Converse messages
converse_messages, system_prompt = messages_to_converse_messages(messages)
if len(system_prompt) > 0 or self.system_prompt is None:
self.system_prompt = system_prompt
all_kwargs = self._get_all_kwargs(**kwargs)

# invoke LLM in AWS Bedrock Converse with retry
response = await converse_with_retry_async(
session=self._asession,
config=self._config,
messages=converse_messages,
system_prompt=self.system_prompt,
max_retries=self.max_retries,
stream=True,
**all_kwargs,
)

async def gen() -> ChatResponseAsyncGen:
content = {}
role = MessageRole.ASSISTANT
            async for chunk in response["stream"]:
if content_block_delta := chunk.get("contentBlockDelta"):
content_delta = content_block_delta["delta"]
content = join_two_dicts(content, content_delta)
(
_,
tool_calls,
tool_call_ids,
status,
) = self._get_content_and_tool_calls(content=content)

yield ChatResponse(
message=ChatMessage(
role=role,
content=content.get("text", ""),
additional_kwargs={
"tool_calls": tool_calls,
"tool_call_id": tool_call_ids,
"status": status,
},
),
delta=content_delta.get("text", ""),
raw=response,
)
elif content_block_start := chunk.get("contentBlockStart"):
tool_use = content_block_start["toolUse"]
content = join_two_dicts(content, tool_use)
(
_,
tool_calls,
tool_call_ids,
status,
) = self._get_content_and_tool_calls(content=content)

yield ChatResponse(
message=ChatMessage(
role=role,
content=content.get("text", ""),
additional_kwargs={
"tool_calls": tool_calls,
"tool_call_id": tool_call_ids,
"status": status,
},
),
raw=response,
)

return gen()

@llm_completion_callback()
async def astream_complete(
self, prompt: str, formatted: bool = False, **kwargs: Any
) -> CompletionResponseGen:
# TODO convert to async; do synchronous completion for now
return self.stream_complete(prompt, formatted=formatted, **kwargs)
) -> CompletionResponseAsyncGen:
astream_complete_fn = astream_chat_to_completion_decorator(self.astream_chat)
return await astream_complete_fn(prompt, **kwargs)

def chat_with_tools(
self,
@@ -445,15 +551,21 @@ async def achat_with_tools(
allow_parallel_tool_calls: bool = False,
**kwargs: Any,
) -> ChatResponse:
# TODO convert to async; do synchronous chat for now
return self.chat_with_tools(
tools=tools,
user_msg=user_msg,
chat_history=chat_history,
verbose=verbose,
allow_parallel_tool_calls=allow_parallel_tool_calls,
**kwargs,
)
chat_history = chat_history or []

if isinstance(user_msg, str):
user_msg = ChatMessage(role=MessageRole.USER, content=user_msg)
chat_history.append(user_msg)

# convert Llama Index tools to AWS Bedrock Converse tools
tool_dicts = tools_to_converse_tools(tools)

response = await self.achat(chat_history, tools=tool_dicts, **kwargs)

if not allow_parallel_tool_calls:
force_single_tool_call(response)

return response

def get_tool_calls_from_response(
self,
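
A hedged usage sketch for the now fully asynchronous `achat_with_tools` path; the `multiply` function and tool below are illustrative, not part of this commit:

```python
import asyncio

from llama_index.core.tools import FunctionTool
from llama_index.llms.bedrock_converse import BedrockConverse


def multiply(a: int, b: int) -> int:
    """Multiply two integers."""
    return a * b


async def main() -> None:
    llm = BedrockConverse(model="anthropic.claude-3-haiku-20240307-v1:0")
    tool = FunctionTool.from_defaults(fn=multiply)

    # Before this commit, achat_with_tools fell back to the synchronous
    # chat_with_tools; it now awaits achat end to end.
    response = await llm.achat_with_tools(tools=[tool], user_msg="What is 3 * 7?")

    for tool_call in llm.get_tool_calls_from_response(
        response, error_on_no_tool_call=False
    ):
        print(tool_call.tool_name, tool_call.tool_kwargs)


asyncio.run(main())
```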
llama_index/llms/bedrock_converse/utils.py
@@ -222,6 +222,30 @@ def _create_retry_decorator(client: Any, max_retries: int) -> Callable[[Any], Any]:
)


def _create_retry_decorator_async(max_retries: int) -> Callable[[Any], Any]:
min_seconds = 4
max_seconds = 10
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
try:
import aioboto3 # noqa
except ImportError as e:
raise ImportError(
"You must install the `aioboto3` package to use Bedrock."
"Please `pip install aioboto3`"
) from e

return retry(
reraise=True,
stop=stop_after_attempt(max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type()
), # TODO: Add throttling exception in async version
before_sleep=before_sleep_log(logger, logging.WARNING),
)


def converse_with_retry(
client: Any,
model: str,
@@ -260,6 +284,47 @@ def _conversion_with_retry(**kwargs: Any) -> Any:
return _conversion_with_retry(**converse_kwargs)


async def converse_with_retry_async(
session: Any,
config: Any,
model: str,
messages: Sequence[Dict[str, Any]],
max_retries: int = 3,
system_prompt: Optional[str] = None,
max_tokens: int = 1000,
temperature: float = 0.1,
stream: bool = False,
**kwargs: Any,
) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator_async(max_retries=max_retries)
converse_kwargs = {
"modelId": model,
"messages": messages,
"inferenceConfig": {
"maxTokens": max_tokens,
"temperature": temperature,
},
}
if system_prompt:
converse_kwargs["system"] = [{"text": system_prompt}]
if tool_config := kwargs.get("tools"):
converse_kwargs["toolConfig"] = tool_config
converse_kwargs = join_two_dicts(
converse_kwargs, {k: v for k, v in kwargs.items() if k != "tools"}
)

@retry_decorator
async def _conversion_with_retry(**kwargs: Any) -> Any:
# the async boto3 client needs to be defined inside this async with, otherwise it will raise an error
async with session.client("bedrock-runtime", config=config) as client:
if stream:
return await client.converse_stream(**kwargs)
return await client.converse(**kwargs)

return await _conversion_with_retry(**converse_kwargs)


def join_two_dicts(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
"""
Joins two dictionaries, summing shared keys and adding new keys.
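
The inline comment in `_conversion_with_retry` captures the key constraint behind `converse_with_retry_async` taking a session rather than a client: aioboto3 clients are async context managers, so each call opens the client inside the coroutine that awaits it (the sync path keeps a long-lived boto3 client on the instance instead). A standalone sketch of that pattern, independent of this commit:

```python
import asyncio
from typing import Any, Dict

import aioboto3
from botocore.config import Config


async def converse_once(prompt: str) -> Dict[str, Any]:
    session = aioboto3.Session()
    config = Config(retries={"max_attempts": 3, "mode": "standard"})
    # Unlike boto3, session.client(...) returns an async context manager;
    # the client is only usable inside this block.
    async with session.client("bedrock-runtime", config=config) as client:
        return await client.converse(
            modelId="anthropic.claude-3-haiku-20240307-v1:0",
            messages=[{"role": "user", "content": [{"text": prompt}]}],
            inferenceConfig={"maxTokens": 256, "temperature": 0.1},
        )


# Requires AWS credentials with Bedrock access:
# asyncio.run(converse_once("Paul Graham is "))
```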
pyproject.toml
@@ -27,13 +27,14 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-llms-bedrock-converse"
readme = "README.md"
version = "0.1.4"
version = "0.1.5"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
llama-index-core = "^0.10.1"
llama-index-llms-anthropic = "^0.1.7"
boto3 = "^1.34.122"
aioboto3 = "^13.1.1"

[tool.poetry.group.dev.dependencies]
ipython = "8.10.0"
