Skip to content

Commit

Permalink
[add] dynamic tool matching
Browse files Browse the repository at this point in the history
  • Loading branch information
zhouxh19 committed Dec 7, 2023
1 parent 183f604 commit af59276
Show file tree
Hide file tree
Showing 33 changed files with 150 additions and 245 deletions.
24 changes: 3 additions & 21 deletions doc2knowledge/knowledge_clustering.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from multiagents.llms.sentence_embedding import sentence_embedding

from openai import OpenAI
import numpy as np
import json
Expand Down Expand Up @@ -57,27 +59,7 @@

embeddings = []
for i,text in enumerate(texts):

payload = {
"input": [text["name"]],
"model": "text-embedding-ada-002"
}

timeout=10
ok = 0
while timeout>0:
try:
response = requests.post(url, json=payload, headers=headers)
ok = 1
break
except Exception as e:
time.sleep(.01)
timeout -= 1

if ok == 0:
raise Exception("Failed to get response from API!")

embedding = json.loads(response.text)['data'][0]['embedding']
embedding = sentence_embedding(text["name"])
embeddings.append(embedding)
print(f"embedded {i} text")

Expand Down
17 changes: 12 additions & 5 deletions multiagents/agents/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from multiagents.utils.utils import AgentAction, AgentFinish
from multiagents.reasoning_algorithms import base_env

from multiagents.tools.retriever import api_matcher

class ToolNotExistError(BaseException):

"""Exception raised when parsing output from a command fails."""
Expand Down Expand Up @@ -157,6 +159,7 @@ class Config:
diag_id: str = ""
tools: APICaller = Field(default_factory=APICaller)
tool_memory: BaseMemory = Field(default_factory=ChatHistoryMemory)
tool_matcher: api_matcher = Field(default_factory=api_matcher)
verbose: bool = Field(default=False)
name: str = Field(default="CpuExpert")
max_history: int = 3
Expand All @@ -166,7 +169,6 @@ class Config:
alert_dict: List[dict] = []
messages: List[dict] = []


async def step(
self, former_solution: str, advice: str, task_description: str = "", **kwargs
) -> SolverMessage:
Expand Down Expand Up @@ -290,13 +292,18 @@ def _fill_prompt_template(
- ${tool_names}: the list of tool names
- ${tool_observations}: the observation of the tool in this turn
"""
#retriever = api_retriever()
#relevant_tools = retriever.query(Template(self.prompt_template).safe_substitute({"chat_history": self.memory.to_string(add_sender_prefix=True)}), self.tools)

tools = "\n".join([f"> {tool}: {self.tools.functions[tool]['desc']}" for tool in self.tools.functions])
self.tool_matcher.add_tool(self.tools)

import pdb; pdb.set_trace()
relevant_tools = self.tool_matcher.query(Template(self.prompt_template).safe_substitute({"chat_history": self.memory.to_string(add_sender_prefix=True)}))

tools = "\n".join([f"> {tool}: {relevant_tools[tool]}" for tool in relevant_tools])

tools = tools.replace("{{", "{").replace("}}", "}")
tool_names = ", ".join([tool for tool in self.tools.functions])

tool_names = ", ".join([tool for tool in relevant_tools])

input_arguments = {
"alert_info": self.alert_str,
"agent_name": self.name,
Expand Down
204 changes: 0 additions & 204 deletions multiagents/agents/tool_agent.py

This file was deleted.

5 changes: 1 addition & 4 deletions multiagents/environments/dba.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,6 @@ async def step(
self.reporter.initialize_report()
pbar.update(1)

# import pdb; pdb.set_trace()

# ================== vanilla model ==================
# self.reporter.report["anomaly description"]
# solver = self.agents[AGENT_TYPES.SOLVER][0]
Expand Down Expand Up @@ -298,7 +296,6 @@ def role_assign(self, advice: str = "", alert_info: str = "") -> List[BaseAgent]
# solver_idx = random.randint(0, len(self.agents[AGENT_TYPES.SOLVER]) - 1)
# agents= [self.agents[AGENT_TYPES.SOLVER][0]]


return agents

async def decision_making(
Expand All @@ -316,7 +313,7 @@ async def decision_making(
task_description=self.task_description,
previous_plan=previous_plan,
advice=advice)

print("\n============= Finish the initial diagnosis =============")

for i,diag in enumerate(initial_diags):
Expand Down
5 changes: 2 additions & 3 deletions multiagents/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@ def load_tools(tool_config: List[Dict], max_api_num, agent_name):
for tool in tool_config:

api_module = importlib.import_module(f"""multiagents.tools.{tool["tool_name"]}.api""")
register_functions_from_module(api_module, caller, max_api_num, agent_name) # functions


register_functions_from_module(api_module, caller, max_api_num, agent_name)

return caller


Expand Down
68 changes: 68 additions & 0 deletions multiagents/llms/sentence_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import logging
import os
from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field
from multiagents.llms.base import LLMResult
from . import llm_registry
import requests
import json
import aiohttp
import asyncio
import time
import random
import re
from termcolor import colored
from tqdm import tqdm
from openai import OpenAI


def sentence_embedding(sentence: str, model: str = "text-embedding-ada-002"):

api_key = os.environ.get("OPENAI_API_KEY")

# client = OpenAI( api_key=api_key)
# timeout=10
# ok = 0
# while timeout>0:
# try:
# response = client.embeddings.create(input=[sentence], model=model)
# ok = 1
# break
# except Exception as e:
# time.sleep(.01)
# timeout -= 1

# if ok == 0:
# raise Exception("Failed to get response from API!")

# embedding = response.data[0].embedding
# return embedding


payload = {
"input": [sentence],
"model": model
}
url = "https://api.aiaiapi.com/v1/embeddings"
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + api_key
}

timeout=10
ok = 0
while timeout>0:
try:
response = requests.post(url, json=payload, headers=headers)
ok = 1
break
except Exception as e:
time.sleep(.01)
timeout -= 1

if ok == 0:
raise Exception("Failed to get response from openai API!")

embedding = json.loads(response.text)['data'][0]['embedding']

return embedding
Loading

0 comments on commit af59276

Please sign in to comment.