forked from sinaptik-ai/pandas-ai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmemory.py
110 lines (90 loc) · 3.33 KB
/
memory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
""" Memory class to store the conversations """
from typing import Union
class Memory:
    """Memory class to store the conversations.

    Messages are kept in insertion order as dicts of the form
    ``{"message": <content>, "is_user": <bool>}``. ``memory_size`` is the
    default number of most recent messages considered by the windowed
    getters (``get_messages``, ``get_conversation`` and
    ``get_previous_conversation``).
    """

    _messages: list
    _memory_size: int
    _agent_info: Union[str, None]

    def __init__(self, memory_size: int = 1, agent_info: Union[str, None] = None):
        """Create an empty memory.

        Args:
            memory_size: Default number of most recent messages returned by
                the windowed getters when no explicit limit is given.
            agent_info: Optional text returned by ``get_system_prompt`` and
                emitted as the system message by ``to_openai_messages``.
        """
        self._messages = []
        self._memory_size = memory_size
        self._agent_info = agent_info

    def add(self, message: str, is_user: bool):
        """Append a message, flagged as coming from the user or not."""
        self._messages.append({"message": message, "is_user": is_user})

    def count(self) -> int:
        """Return the number of stored messages."""
        return len(self._messages)

    def all(self) -> list:
        """Return the stored message dicts (the live list, not a copy)."""
        return self._messages

    def last(self) -> dict:
        """Return the most recently added message dict.

        Raises:
            IndexError: If no message has been added yet.
        """
        return self._messages[-1]

    def _truncate(self, message: Union[str, int], max_length: int = 100) -> str:
        """
        Truncates the message if it is longer than max_length
        """
        # Convert once so that non-string content (e.g. an int answer) can be
        # both measured and sliced; slicing the raw value raised TypeError.
        text = str(message)
        return f"{text[:max_length]} ..." if len(text) > max_length else text

    def _format_message(self, message: dict) -> str:
        """Render one stored message as a '### QUERY' / '### ANSWER' section."""
        if message["is_user"]:
            return f"### QUERY\n {message['message']}"
        # Only non-user messages are truncated; user queries are kept in full.
        return f"### ANSWER\n {self._truncate(message['message'])}"

    def get_messages(self, limit: Union[int, None] = None) -> list:
        """
        Returns the conversation messages based on limit parameter
        or default memory size

        A non-positive limit yields an empty list.
        """
        limit = self._memory_size if limit is None else limit
        # Guard explicitly: ``self._messages[-0:]`` is ``self._messages[0:]``,
        # which would return the whole history instead of nothing.
        if limit <= 0:
            return []
        return [self._format_message(entry) for entry in self._messages[-limit:]]

    def get_conversation(self, limit: Union[int, None] = None) -> str:
        """
        Returns the conversation messages based on limit parameter
        or default memory size, joined with newlines
        """
        return "\n".join(self.get_messages(limit))

    def get_previous_conversation(self) -> str:
        """
        Returns the previous conversation but the last message
        """
        messages = self.get_messages(self._memory_size)
        return "" if len(messages) <= 1 else "\n".join(messages[:-1])

    def get_last_message(self) -> str:
        """
        Returns the last message in the conversation, or an empty string
        when nothing has been stored
        """
        # Format the newest entry directly so the result does not depend on
        # the window size (a window of 0 must not hide the latest message).
        if not self._messages:
            return ""
        return self._format_message(self._messages[-1])

    def get_system_prompt(self) -> Union[str, None]:
        """Return the agent info given at construction (may be None)."""
        return self._agent_info

    def to_json(self) -> list:
        """Return all messages as ``{"role": ..., "message": ...}`` dicts.

        The role is "user" for user messages and "assistant" otherwise.
        """
        return [
            {
                "role": "user" if entry["is_user"] else "assistant",
                "message": entry["message"],
            }
            for entry in self.all()
        ]

    def to_openai_messages(self) -> list:
        """
        Returns the conversation messages in the format expected by the OpenAI API
        """
        messages = []
        # The system message is only emitted when agent info is set and non-empty.
        if self.agent_info:
            messages.append(
                {
                    "role": "system",
                    "content": self.get_system_prompt(),
                }
            )
        messages.extend(
            {
                "role": "user" if entry["is_user"] else "assistant",
                "content": entry["message"],
            }
            for entry in self.all()
        )
        return messages

    def clear(self):
        """Drop all stored messages."""
        # Rebind rather than clear in place so that a list previously obtained
        # through ``all()`` is left untouched.
        self._messages = []

    @property
    def size(self) -> int:
        """Configured default window size (``memory_size``)."""
        return self._memory_size

    @property
    def agent_info(self) -> Union[str, None]:
        """Agent info given at construction (may be None)."""
        return self._agent_info