forked from TsinghuaDatabaseGroup/DB-GPT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
base.py
91 lines (76 loc) · 2.89 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import logging
from abc import abstractmethod
from typing import List, NamedTuple, Set, Union
from pydantic import BaseModel, Field
from multiagents.llms import BaseLLM
from multiagents.memory import BaseMemory, ChatHistoryMemory
from multiagents.message import Message
from multiagents.custom_parser import OutputParser
from string import Template
class BaseAgent(BaseModel):
    """Abstract base class for agents.

    Holds the configuration shared by all agents (LLM backend, output parser,
    prompt template, memory, receiver set, ...) and provides helpers for
    rendering the prompt template and managing the receiver set. Concrete
    agents must implement ``step``, ``astep``, ``reset`` and
    ``add_message_to_memory``.
    """

    name: str
    llm: BaseLLM
    output_parser: OutputParser
    # `string.Template` source ($-placeholders) rendered by `get_all_prompts`.
    prompt_template: str = Field(default="")
    role_description: str = Field(default="")
    # A fresh ChatHistoryMemory is built for every instance.
    memory: BaseMemory = Field(default_factory=ChatHistoryMemory)
    # Not used in this class; presumably a retry budget for subclasses — confirm.
    max_retry: int = Field(default=3)
    # Receiver names, defaulting to {"all"}. A factory (not a set literal) is
    # used so no two instances can share the set that `add_receiver` mutates
    # in place.
    receiver: Set[str] = Field(default_factory=lambda: {"all"})
    async_mode: bool = Field(default=True)
    language: str = Field(default="en")
    # Factory instead of a shared `[]` literal so each instance owns its list.
    knowledge_list: list = Field(default_factory=list)

    @abstractmethod
    def step(self, env_description: str = "") -> Message:
        """Get one step response."""

    @abstractmethod
    async def astep(self, env_description: str = "") -> Message:
        """Asynchronous version of `step`."""

    @abstractmethod
    def reset(self) -> None:
        """Reset the agent."""

    @abstractmethod
    def add_message_to_memory(self, messages: List[Message]) -> None:
        """Add messages to the memory."""

    def get_all_prompts(self, **kwargs):
        """Render ``prompt_template`` with the given keyword values.

        Uses ``string.Template.safe_substitute``: placeholders with no matching
        keyword are left in the output unchanged instead of raising.

        Returns:
            The rendered prompt string.
        """
        return Template(self.prompt_template).safe_substitute(**kwargs)

    def get_receiver(self) -> Set[str]:
        """Return the receiver set (the live set object, not a copy)."""
        return self.receiver

    def set_receiver(self, receiver: Union[Set[str], str]) -> None:
        """Replace the receiver set.

        Args:
            receiver: A single receiver name, or a set/frozenset of names.

        Raises:
            ValueError: If ``receiver`` is neither a string nor a set.
        """
        if isinstance(receiver, str):
            self.receiver = {receiver}
        elif isinstance(receiver, (set, frozenset)):
            # Store a private copy: `add_receiver` mutates the set in place,
            # which would otherwise also change the caller's object.
            self.receiver = set(receiver)
        else:
            raise ValueError(
                "input argument `receiver` must be a string or a set of string"
            )

    def add_receiver(self, receiver: Union[Set[str], str]) -> None:
        """Add one receiver name, or several, to the receiver set.

        Args:
            receiver: A single receiver name, or a set/frozenset of names.

        Raises:
            ValueError: If ``receiver`` is neither a string nor a set.
        """
        if isinstance(receiver, str):
            self.receiver.add(receiver)
        elif isinstance(receiver, (set, frozenset)):
            self.receiver = self.receiver.union(receiver)
        else:
            raise ValueError(
                "input argument `receiver` must be a string or a set of string"
            )

    def remove_receiver(self, receiver: Union[Set[str], str]) -> None:
        """Remove one receiver name, or several, from the receiver set.

        A single missing name is logged as a warning rather than raised.
        Names in a set that are not present are ignored silently.

        Args:
            receiver: A single receiver name, or a set/frozenset of names.

        Raises:
            ValueError: If ``receiver`` is neither a string nor a set.
        """
        if isinstance(receiver, str):
            try:
                self.receiver.remove(receiver)
            except KeyError:
                logging.warning("Receiver %s not found.", receiver)
        elif isinstance(receiver, (set, frozenset)):
            self.receiver = self.receiver.difference(receiver)
        else:
            raise ValueError(
                "input argument `receiver` must be a string or a set of string"
            )

    def enable_feedback(self):
        """Return the LLM's ``enable_feedback`` setting.

        Returns False when ``llm`` is None or has no ``enable_feedback``
        attribute. Otherwise it returns that attribute's value unchanged.
        """
        return self.llm is not None and getattr(self.llm, "enable_feedback", False)