forked from agiresearch/Cerebrum
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a6806ad
commit 6e65921
Showing
16 changed files
with
973 additions
and
1 deletion.
There are no files selected for viewing
174 changes: 174 additions & 0 deletions
174
cerebrum/example/agents/festival_card_designer/agent.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
from cerebrum.agents.base import BaseAgent | ||
from cerebrum.llm.communication import LLMQuery | ||
import json | ||
|
||
class FestivalCardDesigner(BaseAgent): | ||
def __init__(self, agent_name, task_input, config_): | ||
super().__init__(agent_name, task_input, config_) | ||
|
||
self.plan_max_fail_times = 3 | ||
self.tool_call_max_fail_times = 3 | ||
|
||
self.start_time = None | ||
self.end_time = None | ||
self.request_waiting_times: list = [] | ||
self.request_turnaround_times: list = [] | ||
self.task_input = task_input | ||
self.messages = [] | ||
self.workflow_mode = "manual" # (manual, automatic) | ||
self.rounds = 0 | ||
|
||
|
||
def build_system_instruction(self): | ||
prefix = "".join(["".join(self.config["description"])]) | ||
|
||
plan_instruction = "".join( | ||
[ | ||
f"You are given the available tools from the tool list: {json.dumps(self.tool_info)} to help you solve problems. ", | ||
"Generate a plan with comprehensive yet minimal steps to fulfill the task. ", | ||
"The plan must follow the json format as below: ", | ||
"[", | ||
'{"action_type": "action_type_value", "action": "action_value","tool_use": [tool_name1, tool_name2,...]}', | ||
'{"action_type": "action_type_value", "action": "action_value", "tool_use": [tool_name1, tool_name2,...]}', | ||
"...", | ||
"]", | ||
"In each step of the planned plan, identify tools to use and recognize no tool is necessary. ", | ||
"Followings are some plan examples. ", | ||
"[" "[", | ||
'{"action_type": "tool_use", "action": "gather information from arxiv. ", "tool_use": ["arxiv"]},', | ||
'{"action_type": "chat", "action": "write a summarization based on the gathered information. ", "tool_use": []}', | ||
"];", | ||
"[", | ||
'{"action_type": "tool_use", "action": "gather information from arxiv. ", "tool_use": ["arxiv"]},', | ||
'{"action_type": "chat", "action": "understand the current methods and propose ideas that can improve ", "tool_use": []}', | ||
"]", | ||
"]", | ||
] | ||
) | ||
|
||
if self.workflow_mode == "manual": | ||
self.messages.append({"role": "system", "content": prefix}) | ||
|
||
else: | ||
assert self.workflow_mode == "automatic" | ||
self.messages.append({"role": "system", "content": prefix}) | ||
self.messages.append({"role": "user", "content": plan_instruction}) | ||
|
||
def automatic_workflow(self): | ||
for i in range(self.plan_max_fail_times): | ||
response = self.send_request( | ||
agent_name=self.agent_name, | ||
query=LLMQuery( | ||
messages=self.messages, tools=None, message_return_type="json" | ||
), | ||
)["response"] | ||
|
||
workflow = self.check_workflow(response.response_message) | ||
|
||
self.rounds += 1 | ||
|
||
if workflow: | ||
return workflow | ||
|
||
else: | ||
self.messages.append( | ||
{ | ||
"role": "assistant", | ||
"content": f"Fail {i+1} times to generate a valid plan. I need to regenerate a plan", | ||
} | ||
) | ||
return None | ||
|
||
def manual_workflow(self): | ||
workflow = [ | ||
{ | ||
"action_type": "chat", | ||
"action": "Gather user information (festival theme, target audience, card size) and identify card design elements (colors, fonts, imagery)", | ||
"tool_use": [] | ||
}, | ||
{ | ||
"message": "Generate card layout options", | ||
"tool_use": ["text_to_image"] | ||
}, | ||
{ | ||
"message": "Summarize the card Add textual elements to the festival card ", | ||
"tool_use": [] | ||
} | ||
] | ||
return workflow | ||
|
||
def run(self): | ||
self.build_system_instruction() | ||
|
||
task_input = self.task_input | ||
|
||
self.messages.append({"role": "user", "content": task_input}) | ||
|
||
workflow = None | ||
|
||
if self.workflow_mode == "automatic": | ||
workflow = self.automatic_workflow() | ||
self.messages = self.messages[:1] # clear long context | ||
|
||
else: | ||
assert self.workflow_mode == "manual" | ||
workflow = self.manual_workflow() | ||
|
||
self.messages.append( | ||
{ | ||
"role": "user", | ||
"content": f"[Thinking]: The workflow generated for the problem is {json.dumps(workflow)}. Follow the workflow to solve the problem step by step. ", | ||
} | ||
) | ||
|
||
try: | ||
if workflow: | ||
final_result = "" | ||
|
||
for i, step in enumerate(workflow): | ||
action_type = step["action_type"] | ||
action = step["action"] | ||
tool_use = step["tool_use"] | ||
|
||
prompt = f"At step {i + 1}, you need to: {action}. " | ||
self.messages.append({"role": "user", "content": prompt}) | ||
|
||
if tool_use: | ||
selected_tools = self.pre_select_tools(tool_use) | ||
|
||
else: | ||
selected_tools = None | ||
|
||
response = self.send_request( | ||
agent_name=self.agent_name, | ||
query=LLMQuery( | ||
messages=self.messages, | ||
tools=selected_tools, | ||
action_type=action_type, | ||
), | ||
)["response"] | ||
|
||
self.messages.append({"role": "assistant", "content": response.response_message}) | ||
|
||
self.rounds += 1 | ||
|
||
|
||
final_result = self.messages[-1]["content"] | ||
|
||
return { | ||
"agent_name": self.agent_name, | ||
"result": final_result, | ||
"rounds": self.rounds, | ||
} | ||
|
||
else: | ||
return { | ||
"agent_name": self.agent_name, | ||
"result": "Failed to generate a valid workflow in the given times.", | ||
"rounds": self.rounds, | ||
|
||
} | ||
|
||
except Exception as e: | ||
|
||
return {} |
15 changes: 15 additions & 0 deletions
15
cerebrum/example/agents/festival_card_designer/config.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
{ | ||
"name": "festival_card_designer", | ||
"description": [ | ||
"You are a festival card designer. ", | ||
"Create unique and eye-catching festival cards based on user preferences and festival themes." | ||
], | ||
"tools": [ | ||
"stability-ai/text_to_image" | ||
], | ||
"meta": { | ||
"author": "example", | ||
"version": "0.0.1", | ||
"license": "CC0" | ||
} | ||
} |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
from cerebrum.agents.base import BaseAgent | ||
from cerebrum.llm.communication import LLMQuery | ||
import json | ||
|
||
class LanguageTutor(BaseAgent): | ||
def __init__(self, agent_name, task_input, config_): | ||
super().__init__(agent_name, task_input, config_) | ||
|
||
self.plan_max_fail_times = 3 | ||
self.tool_call_max_fail_times = 3 | ||
|
||
self.start_time = None | ||
self.end_time = None | ||
self.request_waiting_times: list = [] | ||
self.request_turnaround_times: list = [] | ||
self.task_input = task_input | ||
self.messages = [] | ||
self.workflow_mode = "manual" # (manual, automatic) | ||
# self.workflow_mode = "automatic" | ||
self.rounds = 0 | ||
|
||
def build_system_instruction(self): | ||
prefix = "".join(["".join(self.config["description"])]) | ||
|
||
plan_instruction = "".join( | ||
[ | ||
f"You are given the available tools from the tool list: {json.dumps(self.tool_info)} to help you solve problems. ", | ||
"Generate a plan with comprehensive yet minimal steps to fulfill the task. ", | ||
"The plan must follow the json format as below: ", | ||
"[", | ||
'{"action_type": "action_type_value", "action": "action_value","tool_use": [tool_name1, tool_name2,...]}', | ||
'{"action_type": "action_type_value", "action": "action_value", "tool_use": [tool_name1, tool_name2,...]}', | ||
"...", | ||
"]", | ||
"In each step of the planned plan, identify tools to use and recognize no tool is necessary. ", | ||
"Followings are some plan examples. ", | ||
"[" "[", | ||
'{"action_type": "tool_use", "action": "gather information from arxiv. ", "tool_use": ["arxiv"]},', | ||
'{"action_type": "chat", "action": "write a summarization based on the gathered information. ", "tool_use": []}', | ||
"];", | ||
"[", | ||
'{"action_type": "tool_use", "action": "gather information from arxiv. ", "tool_use": ["arxiv"]},', | ||
'{"action_type": "chat", "action": "understand the current methods and propose ideas that can improve ", "tool_use": []}', | ||
"]", | ||
"]", | ||
] | ||
) | ||
|
||
if self.workflow_mode == "manual": | ||
self.messages.append({"role": "system", "content": prefix}) | ||
|
||
else: | ||
assert self.workflow_mode == "automatic" | ||
self.messages.append({"role": "system", "content": prefix}) | ||
self.messages.append({"role": "user", "content": plan_instruction}) | ||
|
||
def automatic_workflow(self): | ||
for i in range(self.plan_max_fail_times): | ||
response = self.send_request( | ||
agent_name=self.agent_name, | ||
query=LLMQuery( | ||
messages=self.messages, tools=None, message_return_type="json" | ||
), | ||
)["response"] | ||
|
||
workflow = self.check_workflow(response.response_message) | ||
|
||
self.rounds += 1 | ||
|
||
if workflow: | ||
return workflow | ||
|
||
else: | ||
self.messages.append( | ||
{ | ||
"role": "assistant", | ||
"content": f"Fail {i+1} times to generate a valid plan. I need to regenerate a plan", | ||
} | ||
) | ||
return None | ||
|
||
def manual_workflow(self): | ||
workflow = [ | ||
{ | ||
"action_type": "chat", | ||
"action": "Identify user's target language and learning goals and create grammar explanations and practice sentences.", | ||
"tool_use": [] | ||
}, | ||
{ | ||
"action_type": "chat", | ||
"action": "Provide audio examples of pronunciation.", | ||
"tool_use": [] | ||
}, | ||
{ | ||
"action_type": "chat", | ||
"action": "Engage in conversation practice with the user.", | ||
"tool_use": [] | ||
} | ||
] | ||
return workflow | ||
|
||
def run(self): | ||
self.build_system_instruction() | ||
|
||
task_input = self.task_input | ||
|
||
self.messages.append({"role": "user", "content": task_input}) | ||
|
||
workflow = None | ||
|
||
if self.workflow_mode == "automatic": | ||
workflow = self.automatic_workflow() | ||
self.messages = self.messages[:1] # clear long context | ||
|
||
else: | ||
assert self.workflow_mode == "manual" | ||
workflow = self.manual_workflow() | ||
|
||
self.messages.append( | ||
{ | ||
"role": "user", | ||
"content": f"[Thinking]: The workflow generated for the problem is {json.dumps(workflow)}. Follow the workflow to solve the problem step by step. ", | ||
} | ||
) | ||
|
||
try: | ||
if workflow: | ||
final_result = "" | ||
|
||
for i, step in enumerate(workflow): | ||
action_type = step["action_type"] | ||
action = step["action"] | ||
tool_use = step["tool_use"] | ||
|
||
prompt = f"At step {i + 1}, you need to: {action}. " | ||
self.messages.append({"role": "user", "content": prompt}) | ||
|
||
if tool_use: | ||
selected_tools = self.pre_select_tools(tool_use) | ||
|
||
else: | ||
selected_tools = None | ||
|
||
response = self.send_request( | ||
agent_name=self.agent_name, | ||
query=LLMQuery( | ||
messages=self.messages, | ||
tools=selected_tools, | ||
action_type=action_type, | ||
), | ||
)["response"] | ||
|
||
self.messages.append({"role": "assistant", "content": response.response_message}) | ||
|
||
self.rounds += 1 | ||
|
||
|
||
final_result = self.messages[-1]["content"] | ||
|
||
return { | ||
"agent_name": self.agent_name, | ||
"result": final_result, | ||
"rounds": self.rounds, | ||
} | ||
|
||
else: | ||
return { | ||
"agent_name": self.agent_name, | ||
"result": "Failed to generate a valid workflow in the given times.", | ||
"rounds": self.rounds, | ||
|
||
} | ||
|
||
except Exception as e: | ||
|
||
return {} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
{ | ||
"name": "language_tutor", | ||
"description": [ | ||
"You are a language tutor. You can provide vocabulary exercises, grammar explanations, and conversation practice. ", | ||
"You can also offer pronunciation guidance and cultural insights. " | ||
], | ||
"tools": [ | ||
], | ||
"meta": { | ||
"author": "example", | ||
"version": "0.0.1", | ||
"license": "CC0" | ||
}, | ||
"build": { | ||
"entry": "agent.py", | ||
"module": "LanguageTutor" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
Oops, something went wrong.