Skip to content

Commit

Permalink
Support openai assistant v2 API (microsoft#2466)
Browse files Browse the repository at this point in the history
* adapted to openai assistant v2 api

* fix comments

* format code

* fix ci

* Update autogen/agentchat/contrib/gpt_assistant_agent.py

Co-authored-by: Eric Zhu <[email protected]>

---------

Co-authored-by: Eric Zhu <[email protected]>
  • Loading branch information
IANTHEREAL and ekzhu authored Apr 23, 2024
1 parent 2daae42 commit a41182a
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 53 deletions.
54 changes: 31 additions & 23 deletions autogen/agentchat/contrib/gpt_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from autogen import OpenAIWrapper
from autogen.agentchat.agent import Agent
from autogen.agentchat.assistant_agent import AssistantAgent, ConversableAgent
from autogen.oai.openai_utils import retrieve_assistants_by_name
from autogen.oai.openai_utils import create_gpt_assistant, retrieve_assistants_by_name, update_gpt_assistant

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,7 +50,8 @@ def __init__(
- check_every_ms: check thread run status interval
- tools: Give Assistants access to OpenAI-hosted tools like Code Interpreter and Knowledge Retrieval,
or build your own tools using Function calling. ref https://platform.openai.com/docs/assistants/tools
- file_ids: files used by retrieval in run
- file_ids: (Deprecated) files used by retrieval in run. It is Deprecated, use tool_resources instead. https://platform.openai.com/docs/assistants/migration/what-has-changed.
- tool_resources: A set of resources that are used by the assistant's tools. The resources are specific to the type of tool.
overwrite_instructions (bool): whether to overwrite the instructions of an existing assistant. This parameter is in effect only when assistant_id is specified in llm_config.
overwrite_tools (bool): whether to overwrite the tools of an existing assistant. This parameter is in effect only when assistant_id is specified in llm_config.
kwargs (dict): Additional configuration options for the agent.
Expand Down Expand Up @@ -90,7 +91,6 @@ def __init__(
candidate_assistants,
instructions,
openai_assistant_cfg.get("tools", []),
openai_assistant_cfg.get("file_ids", []),
)

if len(candidate_assistants) == 0:
Expand All @@ -101,12 +101,12 @@ def __init__(
"No instructions were provided for new assistant. Using default instructions from AssistantAgent.DEFAULT_SYSTEM_MESSAGE."
)
instructions = AssistantAgent.DEFAULT_SYSTEM_MESSAGE
self._openai_assistant = self._openai_client.beta.assistants.create(
self._openai_assistant = create_gpt_assistant(
self._openai_client,
name=name,
instructions=instructions,
tools=openai_assistant_cfg.get("tools", []),
model=model_name,
file_ids=openai_assistant_cfg.get("file_ids", []),
assistant_config=openai_assistant_cfg,
)
else:
logger.warning(
Expand All @@ -127,9 +127,12 @@ def __init__(
logger.warning(
"overwrite_instructions is True. Provided instructions will be used and will modify the assistant in the API"
)
self._openai_assistant = self._openai_client.beta.assistants.update(
self._openai_assistant = update_gpt_assistant(
self._openai_client,
assistant_id=openai_assistant_id,
instructions=instructions,
assistant_config={
"instructions": instructions,
},
)
else:
logger.warning(
Expand All @@ -154,9 +157,13 @@ def __init__(
logger.warning(
"overwrite_tools is True. Provided tools will be used and will modify the assistant in the API"
)
self._openai_assistant = self._openai_client.beta.assistants.update(
self._openai_assistant = update_gpt_assistant(
self._openai_client,
assistant_id=openai_assistant_id,
tools=openai_assistant_cfg.get("tools", []),
assistant_config={
"tools": specified_tools,
"tool_resources": openai_assistant_cfg.get("tool_resources", None),
},
)
else:
# Tools are specified but overwrite_tools is False; do not update the assistant's tools
Expand Down Expand Up @@ -198,6 +205,8 @@ def _invoke_assistant(
assistant_thread = self._openai_threads[sender]
# Process each unread message
for message in pending_messages:
if message["content"].strip() == "":
continue
self._openai_client.beta.threads.messages.create(
thread_id=assistant_thread.id,
content=message["content"],
Expand Down Expand Up @@ -426,22 +435,23 @@ def delete_assistant(self):
logger.warning("Permanently deleting assistant...")
self._openai_client.beta.assistants.delete(self.assistant_id)

def find_matching_assistant(self, candidate_assistants, instructions, tools, file_ids):
def find_matching_assistant(self, candidate_assistants, instructions, tools):
"""
Find the matching assistant from a list of candidate assistants.
Filter out candidates with the same name but different instructions, file IDs, and function names.
TODO: implement accurate match based on assistant metadata fields.
Filter out candidates with the same name but different instructions, and function names.
"""
matching_assistants = []

# Preprocess the required tools for faster comparison
required_tool_types = set(tool.get("type") for tool in tools)
required_tool_types = set(
"file_search" if tool.get("type") in ["retrieval", "file_search"] else tool.get("type") for tool in tools
)

required_function_names = set(
tool.get("function", {}).get("name")
for tool in tools
if tool.get("type") not in ["code_interpreter", "retrieval"]
if tool.get("type") not in ["code_interpreter", "retrieval", "file_search"]
)
required_file_ids = set(file_ids) # Convert file_ids to a set for unordered comparison

for assistant in candidate_assistants:
# Check if instructions are similar
Expand All @@ -454,11 +464,12 @@ def find_matching_assistant(self, candidate_assistants, instructions, tools, fil
continue

# Preprocess the assistant's tools
assistant_tool_types = set(tool.type for tool in assistant.tools)
assistant_tool_types = set(
"file_search" if tool.type in ["retrieval", "file_search"] else tool.type for tool in assistant.tools
)
assistant_function_names = set(tool.function.name for tool in assistant.tools if hasattr(tool, "function"))
assistant_file_ids = set(getattr(assistant, "file_ids", [])) # Convert to set for comparison

# Check if the tool types, function names, and file IDs match
# Check if the tool types, function names match
if required_tool_types != assistant_tool_types or required_function_names != assistant_function_names:
logger.warning(
"tools not match, skip assistant(%s): tools %s, functions %s",
Expand All @@ -467,9 +478,6 @@ def find_matching_assistant(self, candidate_assistants, instructions, tools, fil
assistant_function_names,
)
continue
if required_file_ids != assistant_file_ids:
logger.warning("file_ids not match, skip assistant(%s): %s", assistant.id, assistant_file_ids)
continue

# Append assistant to matching list if all conditions are met
matching_assistants.append(assistant)
Expand All @@ -496,7 +504,7 @@ def _process_assistant_config(self, llm_config, assistant_config):

# Move the assistant related configurations to assistant_config
# It's important to keep forward compatibility
assistant_config_items = ["assistant_id", "tools", "file_ids", "check_every_ms"]
assistant_config_items = ["assistant_id", "tools", "file_ids", "tool_resources", "check_every_ms"]
for item in assistant_config_items:
if openai_client_cfg.get(item) is not None and openai_assistant_cfg.get(item) is None:
openai_assistant_cfg[item] = openai_client_cfg[item]
Expand Down
103 changes: 103 additions & 0 deletions autogen/oai/openai_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import importlib.metadata
import json
import logging
import os
import re
import tempfile
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union

from dotenv import find_dotenv, load_dotenv
from openai import OpenAI
from openai.types.beta.assistant import Assistant
from packaging.version import parse

NON_CACHE_KEY = ["api_key", "base_url", "api_type", "api_version"]
DEFAULT_AZURE_API_VERSION = "2024-02-15-preview"
Expand Down Expand Up @@ -675,3 +678,103 @@ def retrieve_assistants_by_name(client: OpenAI, name: str) -> List[Assistant]:
if assistant.name == name:
candidate_assistants.append(assistant)
return candidate_assistants


def detect_gpt_assistant_api_version() -> str:
"""Detect the openai assistant API version"""
oai_version = importlib.metadata.version("openai")
if parse(oai_version) < parse("1.21"):
return "v1"
else:
return "v2"


def create_gpt_vector_store(client: OpenAI, name: str, fild_ids: List[str]) -> Any:
"""Create a openai vector store for gpt assistant"""

vector_store = client.beta.vector_stores.create(name=name)
# poll the status of the file batch for completion.
batch = client.beta.vector_stores.file_batches.create_and_poll(vector_store_id=vector_store.id, file_ids=fild_ids)

if batch.status == "in_progress":
time.sleep(1)
logging.debug(f"file batch status: {batch.file_counts}")
batch = client.beta.vector_stores.file_batches.poll(vector_store_id=vector_store.id, batch_id=batch.id)

if batch.status == "completed":
return vector_store

raise ValueError(f"Failed to upload files to vector store {vector_store.id}:{batch.status}")


def create_gpt_assistant(
client: OpenAI, name: str, instructions: str, model: str, assistant_config: Dict[str, Any]
) -> Assistant:
"""Create a openai gpt assistant"""

assistant_create_kwargs = {}
gpt_assistant_api_version = detect_gpt_assistant_api_version()
tools = assistant_config.get("tools", [])

if gpt_assistant_api_version == "v2":
tool_resources = assistant_config.get("tool_resources", {})
file_ids = assistant_config.get("file_ids")
if tool_resources.get("file_search") is not None and file_ids is not None:
raise ValueError(
"Cannot specify both `tool_resources['file_search']` tool and `file_ids` in the assistant config."
)

# Designed for backwards compatibility for the V1 API
# Instead of V1 AssistantFile, files are attached to Assistants using the tool_resources object.
for tool in tools:
if tool["type"] == "retrieval":
tool["type"] = "file_search"
if file_ids is not None:
# create a vector store for the file search tool
vs = create_gpt_vector_store(client, f"{name}-vectorestore", file_ids)
tool_resources["file_search"] = {
"vector_store_ids": [vs.id],
}
elif tool["type"] == "code_interpreter" and file_ids is not None:
tool_resources["code_interpreter"] = {
"file_ids": file_ids,
}

assistant_create_kwargs["tools"] = tools
if len(tool_resources) > 0:
assistant_create_kwargs["tool_resources"] = tool_resources
else:
# not support forwards compatibility
if "tool_resources" in assistant_config:
raise ValueError("`tool_resources` argument are not supported in the openai assistant V1 API.")
if any(tool["type"] == "file_search" for tool in tools):
raise ValueError(
"`file_search` tool are not supported in the openai assistant V1 API, please use `retrieval`."
)
assistant_create_kwargs["tools"] = tools
assistant_create_kwargs["file_ids"] = assistant_config.get("file_ids", [])

logging.info(f"Creating assistant with config: {assistant_create_kwargs}")
return client.beta.assistants.create(name=name, instructions=instructions, model=model, **assistant_create_kwargs)


def update_gpt_assistant(client: OpenAI, assistant_id: str, assistant_config: Dict[str, Any]) -> Assistant:
"""Update openai gpt assistant"""

gpt_assistant_api_version = detect_gpt_assistant_api_version()
assistant_update_kwargs = {}

if assistant_config.get("tools") is not None:
assistant_update_kwargs["tools"] = assistant_config["tools"]

if assistant_config.get("instructions") is not None:
assistant_update_kwargs["instructions"] = assistant_config["instructions"]

if gpt_assistant_api_version == "v2":
if assistant_config.get("tool_resources") is not None:
assistant_update_kwargs["tool_resources"] = assistant_config["tool_resources"]
else:
if assistant_config.get("file_ids") is not None:
assistant_update_kwargs["file_ids"] = assistant_config["file_ids"]

return client.beta.assistants.update(assistant_id=assistant_id, **assistant_update_kwargs)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
__version__ = version["__version__"]

install_requires = [
"openai>=1.3,<1.21",
"openai>=1.3",
"diskcache",
"termcolor",
"flaml",
Expand Down
44 changes: 15 additions & 29 deletions test/agentchat/contrib/test_gpt_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import autogen
from autogen import OpenAIWrapper, UserProxyAgent
from autogen.agentchat.contrib.gpt_assistant_agent import GPTAssistantAgent
from autogen.oai.openai_utils import retrieve_assistants_by_name
from autogen.oai.openai_utils import detect_gpt_assistant_api_version, retrieve_assistants_by_name

sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
from conftest import reason, skip_openai # noqa: E402
Expand Down Expand Up @@ -264,6 +264,7 @@ def test_get_assistant_files() -> None:
openai_client = OpenAIWrapper(config_list=openai_config_list)._clients[0]._oai_client
file = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants")
name = f"For test_get_assistant_files {uuid.uuid4()}"
gpt_assistant_api_version = detect_gpt_assistant_api_version()

# keep it to test older version of assistant config
assistant = GPTAssistantAgent(
Expand All @@ -277,10 +278,17 @@ def test_get_assistant_files() -> None:
)

try:
files = assistant.openai_client.beta.assistants.files.list(assistant_id=assistant.assistant_id)
retrieved_file_ids = [fild.id for fild in files]
if gpt_assistant_api_version == "v1":
files = assistant.openai_client.beta.assistants.files.list(assistant_id=assistant.assistant_id)
retrieved_file_ids = [fild.id for fild in files]
elif gpt_assistant_api_version == "v2":
oas_assistant = assistant.openai_client.beta.assistants.retrieve(assistant_id=assistant.assistant_id)
vectorstore_ids = oas_assistant.tool_resources.file_search.vector_store_ids
retrieved_file_ids = []
for vectorstore_id in vectorstore_ids:
files = assistant.openai_client.beta.vector_stores.files.list(vector_store_id=vectorstore_id)
retrieved_file_ids.extend([fild.id for fild in files])
expected_file_id = file.id

finally:
assistant.delete_assistant()
openai_client.files.delete(file.id)
Expand Down Expand Up @@ -401,7 +409,7 @@ def test_assistant_mismatch_retrieval() -> None:
"tools": [
{"type": "function", "function": function_1_schema},
{"type": "function", "function": function_2_schema},
{"type": "retrieval"},
{"type": "file_search"},
{"type": "code_interpreter"},
],
"file_ids": [file_1.id, file_2.id],
Expand All @@ -411,7 +419,6 @@ def test_assistant_mismatch_retrieval() -> None:
name = f"For test_assistant_retrieval {uuid.uuid4()}"

assistant_first, assistant_instructions_mistaching = None, None
assistant_file_ids_mismatch, assistant_tools_mistaching = None, None
try:
assistant_first = GPTAssistantAgent(
name,
Expand All @@ -432,30 +439,11 @@ def test_assistant_mismatch_retrieval() -> None:
)
assert len(candidate_instructions_mistaching) == 2

# test mismatch fild ids
file_ids_mismatch_llm_config = {
"tools": [
{"type": "code_interpreter"},
{"type": "retrieval"},
{"type": "function", "function": function_2_schema},
{"type": "function", "function": function_1_schema},
],
"file_ids": [file_2.id],
"config_list": openai_config_list,
}
assistant_file_ids_mismatch = GPTAssistantAgent(
name,
instructions="This is a test",
llm_config=file_ids_mismatch_llm_config,
)
candidate_file_ids_mismatch = retrieve_assistants_by_name(assistant_file_ids_mismatch.openai_client, name)
assert len(candidate_file_ids_mismatch) == 3

# test tools mismatch
tools_mismatch_llm_config = {
"tools": [
{"type": "code_interpreter"},
{"type": "retrieval"},
{"type": "file_search"},
{"type": "function", "function": function_3_schema},
],
"file_ids": [file_2.id, file_1.id],
Expand All @@ -467,15 +455,13 @@ def test_assistant_mismatch_retrieval() -> None:
llm_config=tools_mismatch_llm_config,
)
candidate_tools_mismatch = retrieve_assistants_by_name(assistant_tools_mistaching.openai_client, name)
assert len(candidate_tools_mismatch) == 4
assert len(candidate_tools_mismatch) == 3

finally:
if assistant_first:
assistant_first.delete_assistant()
if assistant_instructions_mistaching:
assistant_instructions_mistaching.delete_assistant()
if assistant_file_ids_mismatch:
assistant_file_ids_mismatch.delete_assistant()
if assistant_tools_mistaching:
assistant_tools_mistaching.delete_assistant()

Expand Down

0 comments on commit a41182a

Please sign in to comment.