-
Notifications
You must be signed in to change notification settings - Fork 17
/
test_kg_gemini.py
108 lines (89 loc) · 2.99 KB
/
test_kg_gemini.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import re
import os
import logging
import unittest
import vertexai
from falkordb import FalkorDB
from dotenv import load_dotenv
from graphrag_sdk.entity import Entity
from graphrag_sdk.source import Source
from graphrag_sdk.ontology import Ontology
from graphrag_sdk.relation import Relation
from graphrag_sdk.attribute import Attribute, AttributeType
from graphrag_sdk.models.gemini import GeminiGenerativeModel
from graphrag_sdk import KnowledgeGraph, KnowledgeGraphModelConfig
# Populate os.environ from a local .env file first, so the PROJECT_ID and
# REGION lookups below can see values defined there.
load_dotenv()

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Configure Vertex AI with the GCP project/region taken from the environment;
# each argument is None when its variable is not set.
vertexai.init(
    project=os.environ.get("PROJECT_ID"),
    location=os.environ.get("REGION"),
)
class TestKGGemini(unittest.TestCase):
    """
    Test Knowledge Graph construction and deletion backed by a Gemini model.

    One ``KnowledgeGraph`` named ``IMDB_gemini`` is built once for the whole
    class, using a small movie ontology (Actor -[ACTED_IN]-> Movie).

    NOTE(review): ``test_kg_delete`` removes the graph that
    ``test_kg_creation`` populates, so it relies on unittest's default
    alphabetical method ordering to run second.
    """

    @classmethod
    def setUpClass(cls):
        """Build the shared ontology and KnowledgeGraph used by every test."""
        cls.ontology = Ontology([], [])
        cls.ontology.add_entity(
            Entity(
                label="Actor",
                attributes=[
                    Attribute(
                        name="name",
                        attr_type=AttributeType.STRING,
                        unique=True,
                        required=True,
                    ),
                ],
            )
        )
        cls.ontology.add_entity(
            Entity(
                label="Movie",
                attributes=[
                    Attribute(
                        name="title",
                        attr_type=AttributeType.STRING,
                        unique=True,
                        required=True,
                    ),
                ],
            )
        )
        cls.ontology.add_relation(
            Relation(
                label="ACTED_IN",
                source="Actor",
                target="Movie",
                attributes=[
                    Attribute(
                        name="role",
                        attr_type=AttributeType.STRING,
                        unique=False,
                        required=False,
                    ),
                ],
            )
        )
        cls.graph_name = "IMDB_gemini"
        model = GeminiGenerativeModel(model_name="gemini-1.5-flash-001")
        cls.kg = KnowledgeGraph(
            name=cls.graph_name,
            ontology=cls.ontology,
            model_config=KnowledgeGraphModelConfig.with_model(model),
        )

    def test_kg_creation(self):
        """Ingest the Madoff source file and check the graph finds > 10 actors.

        The chat session is asked how many actors acted in a movie, and the
        first integer in its free-text reply is taken as the count.
        """
        file_path = "tests/data/madoff.txt"
        sources = [Source(file_path)]
        self.kg.process_sources(sources)

        chat = self.kg.chat_session()
        answer = chat.send_message("How many actors acted in a movie?")
        answer = answer['response']
        logger.info("Answer: %s", answer)

        # Use the first number mentioned in the reply; a reply with no number
        # counts as zero so the assertion below fails with a clear message.
        actors_count = re.findall(r'\d+', answer)
        num_actors = int(actors_count[0]) if actors_count else 0
        # assertGreater (unlike a bare assert) is not stripped under
        # ``python -O`` and reports both compared values on failure.
        self.assertGreater(
            num_actors,
            10,
            "The number of actors found should be greater than 10",
        )

    def test_kg_delete(self):
        """Delete the graph and verify FalkorDB no longer lists it."""
        self.kg.delete()
        # FalkorDB() is constructed with no arguments, i.e. its default
        # connection settings.
        db = FalkorDB()
        graphs = db.list_graphs()
        self.assertNotIn(self.graph_name, graphs)