# conversation_agent.py

from __future__ import annotations

import logging
import bdb
from string import Template
from typing import TYPE_CHECKING, List

from multiagents.message import Message

from .base import BaseAgent
from . import agent_registry


@agent_registry.register("conversation")
class ConversationAgent(BaseAgent):
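    """LLM-backed conversation agent.

    Each step fills the prompt template, queries the LLM, parses the reply,
    and returns a Message whose content holds the "diagnose", "solution",
    and "knowledge" fields.
    """
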
    def step(self, env_description: str = "") -> Message:
        prompt = self._fill_prompt_template(env_description)

        parsed_response = None
        # Query the LLM and parse its reply, retrying up to max_retry times
        # when generation or parsing fails.
        for i in range(self.max_retry):
            try:
                response = self.llm.generate_response(prompt)
                parsed_response = self.output_parser.parse(response)
                break
            except KeyboardInterrupt:
                raise
            except Exception as e:
                logging.error(e)
                logging.warning("Retrying...")
                continue

        if parsed_response is None:
            logging.error(f"{self.name} failed to generate valid response.")

        # Fall back to empty content when no retry produced a parsable reply.
        message = Message(
            content={"diagnose": "", "solution": [], "knowledge": ""}
            if parsed_response is None
            else {
                "diagnose": parsed_response.return_values["diagnose"],
                "solution": parsed_response.return_values["solution"],
                "knowledge": parsed_response.return_values["knowledge"],
            },
            sender=self.name,
            receiver=self.get_receiver(),
        )
        return message
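
    # Note: Message.content always carries the keys "diagnose", "solution",
    # and "knowledge" (with empty defaults when parsing never succeeded),
    # so receivers can rely on that shape.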
    async def astep(self, env_description: str = "") -> Message:
        """Asynchronous version of step"""
        prompt = self._fill_prompt_template(env_description)

        parsed_response = None
        for i in range(self.max_retry):
            try:
                response = await self.llm.agenerate_response(prompt)
                parsed_response = self.output_parser.parse(response)
                break
            except (KeyboardInterrupt, bdb.BdbQuit):
                raise
            except Exception as e:
                logging.error(e)
                logging.warning("Retrying...")
                continue

        if parsed_response is None:
            logging.error(f"{self.name} failed to generate valid response.")

        message = Message(
            content={"diagnose": "", "solution": [], "knowledge": ""}
            if parsed_response is None
            else {
                "diagnose": parsed_response.return_values["diagnose"],
                "solution": parsed_response.return_values["solution"],
                "knowledge": parsed_response.return_values["knowledge"],
            },
            sender=self.name,
            receiver=self.get_receiver(),
        )
        return message
    def _fill_prompt_template(self, env_description: str = "") -> str:
        """Fill the placeholders in the prompt template.

        In the conversation agent, four placeholders are supported
        (an illustrative template is sketched after this method):
        - ${agent_name}: the name of the agent
        - ${env_description}: the description of the environment
        - ${role_description}: the description of the role of the agent
        - ${chat_history}: the chat history of the agent
        """
        input_arguments = {
            "agent_name": self.name,
            "env_description": env_description,
            "role_description": self.role_description,
            "chat_history": self.memory.to_string(add_sender_prefix=True),
        }
        return Template(self.prompt_template).safe_substitute(input_arguments)
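
    # Illustrative only (not from the original file): a prompt template that
    # uses the placeholders above might look like
    #
    #   "You are ${agent_name}. ${role_description}\n"
    #   "${env_description}\n"
    #   "Conversation so far:\n${chat_history}"
    #
    # Template.safe_substitute leaves any placeholder it cannot resolve
    # untouched rather than raising KeyError, so a template containing extra
    # placeholders still renders.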
    def add_message_to_memory(self, messages: List[Message]) -> None:
        # pdb.set_trace()
        self.memory.add_message(messages)

    def reset(self) -> None:
        """Reset the agent"""
        self.memory.reset()
        # TODO: reset receiver
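

# Illustrative sketch (not part of the original file): the object returned by
# `self.output_parser.parse(...)` is assumed to expose a `return_values` dict
# containing the three keys read in step/astep, e.g. a hypothetical stand-in:
#
#   class ParsedDiagnosis:
#       def __init__(self, diagnose: str, solution: list, knowledge: str):
#           self.return_values = {
#               "diagnose": diagnose,
#               "solution": solution,
#               "knowledge": knowledge,
#           }
#
# Any parser assigned to `self.output_parser` must satisfy this contract, or
# step/astep will fall back to the empty-content Message.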