forked from Deeptechia/geppetto
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test_gemini.py
97 lines (77 loc) · 3.26 KB
/
test_gemini.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import sys
import unittest
from unittest.mock import Mock, patch

# Make the parent of this file's directory (presumably the repository root)
# importable when the file is executed directly, e.g. as a script.
# This must run *before* the project imports below; placed after them it
# would have no effect on whether they succeed. Inserting at the front lets
# the local checkout take precedence over any installed copy of the package.
script_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(script_dir)
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

from geppetto.gemini_handler import GeminiHandler  # noqa: E402
from geppetto.exceptions import InvalidThreadFormatError  # noqa: E402
from tests import TestBase  # noqa: E402
class TestGemini(TestBase):
    """Unit tests for GeminiHandler with the ``genai`` module patched out."""

    # Personality passed to the shared handler; reused by the assertions.
    PERSONALITY = "Your AI personality"

    @classmethod
    def setUpClass(cls):
        # Let TestBase / unittest.TestCase run their own class fixtures.
        super().setUpClass()
        # Replace the ``genai`` name inside geppetto.gemini_handler so that
        # constructing the handler uses a mock instead of the real SDK.
        cls.patcher = patch("geppetto.gemini_handler.genai")
        cls.mock_genai = cls.patcher.start()
        try:
            cls.gemini_handler = GeminiHandler(personality=cls.PERSONALITY)
        except Exception:
            # tearDownClass is not run when setUpClass raises, so undo the
            # patch here instead of leaking it into later test classes.
            cls.patcher.stop()
            raise

    @classmethod
    def tearDownClass(cls):
        cls.patcher.stop()
        super().tearDownClass()

    def test_personality(self):
        """The handler exposes the personality it was constructed with."""
        self.assertEqual(self.gemini_handler.personality, self.PERSONALITY)

    @patch("geppetto.gemini_handler.convert_gemini_to_slack")
    def test_llm_generate_content(self, mock_convert_to_slack):
        """Gemini's reply text is passed to the Slack converter and returned."""
        user_prompt = [
            {"role": "user", "parts": ["Hello"]},
            {"role": "user", "parts": ["How are you?"]},
        ]
        mock_response = Mock()
        mock_response.text = "Mocked Gemini response"
        mock_convert_to_slack.return_value = "Mocked Markdown data"

        # Patch per test rather than configuring the shared class-level
        # client mock, so return values and call history stay isolated.
        with patch.object(
            self.gemini_handler.client,
            "generate_content",
            return_value=mock_response,
        ) as mock_generate_content:
            response = self.gemini_handler.llm_generate_content(user_prompt)

        mock_generate_content.assert_called_once()
        self.assertEqual(response, "Mocked Markdown data")
        mock_convert_to_slack.assert_called_once_with("Mocked Gemini response")

    def test_get_prompt_from_thread(self):
        """Thread messages become dicts with a ``role`` and non-empty ``parts``."""
        thread = [
            {"role": "slack_user", "content": "Message 1"},
            {"role": "geppetto", "content": "Message 2"},
        ]
        role_field = "role"
        msg_field = "parts"

        prompt = self.gemini_handler.get_prompt_from_thread(
            thread, assistant_tag="geppetto", user_tag="slack_user"
        )

        self.assertIsInstance(prompt, list)
        # Without this, an empty prompt would make the loop below pass
        # without checking anything.
        self.assertTrue(prompt)
        for msg in prompt:
            with self.subTest(msg=msg):
                self.assertIsInstance(msg, dict)
                self.assertIn(role_field, msg)
                self.assertIn(msg_field, msg)
                self.assertIsInstance(msg[msg_field], list)
                self.assertTrue(msg[msg_field])

        # A message lacking "content" must be rejected.
        incomplete_thread = [{"role": "geppetto"}]
        with self.assertRaises(InvalidThreadFormatError):
            self.gemini_handler.get_prompt_from_thread(
                incomplete_thread, assistant_tag="geppetto", user_tag="slack_user"
            )

    def test_llm_generate_content_user_repetition(self):
        """Consecutive user messages are merged into one turn before sending."""
        user_prompt = [
            {"role": "user", "parts": ["Hello"]},
            {"role": "user", "parts": ["How are you?"]},
            {"role": "geppetto", "parts": ["I'm fine."]},
        ]
        with patch.object(
            self.gemini_handler.client, "generate_content"
        ) as mock_generate_content:
            mock_response = Mock()
            mock_response.text = "Mocked Gemini response"
            mock_generate_content.return_value = mock_response
            self.gemini_handler.llm_generate_content(user_prompt)
            mock_generate_content.assert_called_once_with(
                [
                    {"role": "user", "parts": ["Hello", "How are you?"]},
                    {"role": "geppetto", "parts": ["I'm fine."]},
                ]
            )
if __name__ == "__main__":
    # Script entry point: discover and run the tests defined in this module.
    unittest.main(module="__main__")