forked from Significant-Gravitas/AutoGPT
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request Significant-Gravitas#1658 from wangxuqi/milvus_memory
Fix Milvus as a long-term memory backend.
- Loading branch information
Showing
4 changed files
with
129 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import random | ||
import string | ||
import unittest | ||
|
||
from autogpt.config import Config | ||
from autogpt.memory.milvus import MilvusMemory | ||
|
||
|
||
class TestMilvusMemory(unittest.TestCase): | ||
def random_string(self, length): | ||
return "".join(random.choice(string.ascii_letters) for _ in range(length)) | ||
|
||
def setUp(self): | ||
cfg = Config() | ||
cfg.milvus_addr = "localhost:19530" | ||
self.memory = MilvusMemory(cfg) | ||
self.memory.clear() | ||
|
||
# Add example texts to the cache | ||
self.example_texts = [ | ||
"The quick brown fox jumps over the lazy dog", | ||
"I love machine learning and natural language processing", | ||
"The cake is a lie, but the pie is always true", | ||
"ChatGPT is an advanced AI model for conversation", | ||
] | ||
|
||
for text in self.example_texts: | ||
self.memory.add(text) | ||
|
||
# Add some random strings to test noise | ||
for _ in range(5): | ||
self.memory.add(self.random_string(10)) | ||
|
||
def test_get_relevant(self): | ||
query = "I'm interested in artificial intelligence and NLP" | ||
k = 3 | ||
relevant_texts = self.memory.get_relevant(query, k) | ||
|
||
print(f"Top {k} relevant texts for the query '{query}':") | ||
for i, text in enumerate(relevant_texts, start=1): | ||
print(f"{i}. {text}") | ||
|
||
self.assertEqual(len(relevant_texts), k) | ||
self.assertIn(self.example_texts[1], relevant_texts) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import os | ||
import sys | ||
import unittest | ||
|
||
from autogpt.memory.milvus import MilvusMemory | ||
|
||
|
||
def MockConfig(): | ||
return type( | ||
"MockConfig", | ||
(object,), | ||
{ | ||
"debug_mode": False, | ||
"continuous_mode": False, | ||
"speak_mode": False, | ||
"milvus_collection": "autogpt", | ||
"milvus_addr": "localhost:19530", | ||
|
||
}, | ||
) | ||
|
||
|
||
class TestMilvusMemory(unittest.TestCase): | ||
def setUp(self): | ||
self.cfg = MockConfig() | ||
self.memory = MilvusMemory(self.cfg) | ||
|
||
def test_add(self): | ||
text = "Sample text" | ||
self.memory.clear() | ||
self.memory.add(text) | ||
result = self.memory.get(text) | ||
self.assertEqual([text], result) | ||
|
||
def test_clear(self): | ||
self.memory.clear() | ||
self.assertEqual(self.memory.collection.num_entities, 0) | ||
|
||
def test_get(self): | ||
text = "Sample text" | ||
self.memory.clear() | ||
self.memory.add(text) | ||
result = self.memory.get(text) | ||
self.assertEqual(result, [text]) | ||
|
||
def test_get_relevant(self): | ||
text1 = "Sample text 1" | ||
text2 = "Sample text 2" | ||
self.memory.clear() | ||
self.memory.add(text1) | ||
self.memory.add(text2) | ||
result = self.memory.get_relevant(text1, 1) | ||
self.assertEqual(result, [text1]) | ||
|
||
def test_get_stats(self): | ||
text = "Sample text" | ||
self.memory.clear() | ||
self.memory.add(text) | ||
stats = self.memory.get_stats() | ||
self.assertEqual(15, len(stats)) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |