-
Notifications
You must be signed in to change notification settings - Fork 43
/
utils.py
241 lines (201 loc) · 8.29 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
from openai import OpenAI
import os
from neo4j import GraphDatabase
import numpy as np
from camel.storages import Neo4jGraph
import uuid
from summerize import process_chunks
import openai
# System prompt for the first LLM pass in get_response: answer the question
# from the graph-derived context (relationship triples of the current graph).
sys_prompt_one = """
Please answer the question using insights supported by provided graph-based data relevant to medical information.
"""
# System prompt for the second LLM pass in get_response: revise the first
# answer against the linked reference context and cite references by their
# index number, e.g. [1][3].
sys_prompt_two = """
Modify the response to the question using the provided references. Include precise citations relevant to your answer. You may use multiple citations simultaneously, denoting each with the reference index number. For example, cite the first and third documents as [1][3]. If the references do not pertain to the response, simply provide a concise answer to the original question.
"""
# OpenAI API key, read once at import time from the OPENAI_API_KEY environment
# variable -- export that variable before running rather than editing code.
# NOTE(review): nothing in the visible part of this file reads
# `openai_api_key` (get_embedding re-reads the environment itself), so this
# may be unused; confirm before relying on it.
openai_api_key = os.getenv("OPENAI_API_KEY")
# OpenAI clients keyed by API key. Reusing one client lets repeated embedding
# calls (e.g. one per node in add_nodes_emb) share a connection pool instead
# of constructing a fresh client for every call.
_EMBEDDING_CLIENTS = {}


def get_embedding(text, mod="text-embedding-3-small", client=None):
    """Return the embedding vector for ``text``.

    Args:
        text: Input passed to the embeddings endpoint.
        mod: Embedding model name.
        client: Optional pre-built OpenAI client. When omitted, a client for
            the current ``OPENAI_API_KEY`` environment value is created on
            first use and reused by later calls.

    Returns:
        The ``embedding`` of the first item in the response's ``data``.
    """
    if client is None:
        # Read the key at call time (as before) so a changed environment
        # still takes effect; only the client object is cached.
        api_key = os.getenv("OPENAI_API_KEY")
        client = _EMBEDDING_CLIENTS.get(api_key)
        if client is None:
            client = OpenAI(api_key=api_key)
            _EMBEDDING_CLIENTS[api_key] = client
    response = client.embeddings.create(input=text, model=mod)
    return response.data[0].embedding
def fetch_texts(n4j):
    """Return one row per node in the database carrying that node's ``id``.

    Despite the name, no separate text property is read: the query projects
    ``n.id AS id``, and callers (see add_nodes_emb) treat the id itself as
    the text to embed.
    """
    return n4j.query("MATCH (n) RETURN n.id AS id")
def add_embeddings(n4j, node_id, embedding):
    """Write ``embedding`` to the ``embedding`` property of the node(s) whose id is ``node_id``."""
    bindings = {"node_id": node_id, "embedding": embedding}
    n4j.query(
        "MATCH (n) WHERE n.id = $node_id SET n.embedding = $embedding",
        params=bindings,
    )
def add_nodes_emb(n4j):
    """Embed every node's id and store the vector back on that node.

    Rows whose ``id`` is falsy (e.g. missing/None or empty) are skipped.
    """
    for row in fetch_texts(n4j):
        node_id = row['id']
        if not node_id:
            continue  # nothing to embed for this node
        vector = get_embedding(node_id)
        add_embeddings(n4j, node_id, vector)
def add_ge_emb(graph_element):
    """Attach an ``embedding`` property, computed from each node's ``id``, to every node.

    ``graph_element`` is mutated in place and also returned. It is presumably
    a camel graph element (anything with ``.nodes`` whose items expose ``.id``
    and a ``.properties`` dict) -- confirm against callers.
    """
    for element_node in graph_element.nodes:
        element_node.properties['embedding'] = get_embedding(element_node.id)
    return graph_element
def add_gid(graph_element, gid):
    """Tag every node and relationship of ``graph_element`` with ``gid``.

    The ``gid`` property is written into each item's ``properties`` dict; the
    same (mutated) ``graph_element`` is returned.
    """
    tagged_items = (*graph_element.nodes, *graph_element.relationships)
    for item in tagged_items:
        item.properties['gid'] = gid
    return graph_element
def add_sum(n4j, content, gid):
    """Summarize ``content`` and attach the summary to graph ``gid``.

    Steps:
      1. ``process_chunks(content)`` (from ``summerize``) produces the summary.
      2. A ``Summary`` node ``{content: <summary>, gid: gid}`` is created.
      3. Every ``Summary`` node with this gid gets a ``SUMMARIZES`` edge to
         every non-Summary node with the same gid.

    NOTE(review): step 3 matches *all* Summary nodes for ``gid``, so calling
    this twice for the same gid creates duplicate SUMMARIZES edges.

    Returns:
        The result of the node-creation query (it ``RETURN s``).
    """
    # Renamed from ``sum`` so the builtin is not shadowed; the Cypher
    # parameter is still named ``$sum``.
    summary = process_chunks(content)
    create_summary_query = """
CREATE (s:Summary {content: $sum, gid: $gid})
RETURN s
"""
    created = n4j.query(create_summary_query, {'sum': summary, 'gid': gid})
    link_summary_query = """
MATCH (s:Summary {gid: $gid}), (n)
WHERE n.gid = s.gid AND NOT n:Summary
CREATE (s)-[:SUMMARIZES]->(n)
RETURN s, n
"""
    n4j.query(link_summary_query, {'gid': gid})
    return created
def call_llm(sys, user, *, model="gpt-4-1106-preview", max_tokens=500, temperature=0.5):
    """Send one system/user message pair to the OpenAI chat completions API.

    Args:
        sys: System prompt. (The name shadows the ``sys`` module inside this
            function; it is kept so existing keyword callers keep working.)
        user: User message; sent with a single leading space, as before.
        model: Chat model name (keyword-only; defaults to the previous
            hard-coded model).
        max_tokens: Maximum completion length (keyword-only).
        temperature: Sampling temperature (keyword-only).

    Returns:
        The message content of the first choice (exactly one, ``n=1``, is requested).
    """
    response = openai.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": sys},
            {"role": "user", "content": f" {user}"},
        ],
        max_tokens=max_tokens,
        n=1,
        stop=None,
        temperature=temperature,
    )
    return response.choices[0].message.content
def find_index_of_largest(nums):
    """Return the index of the largest value in ``nums``.

    Ties resolve to the *last* occurrence of the maximum, because
    ``(value, index)`` pairs are compared lexicographically. This matches the
    previous sort-based implementation.

    Args:
        nums: Any iterable of mutually comparable values.

    Raises:
        IndexError: If ``nums`` is empty.
    """
    # A single O(n) max() pass picks the same pair that sorting and taking
    # the last element did, without the O(n log n) sort.
    best = max(((num, index) for index, num in enumerate(nums)), default=None)
    if best is None:
        raise IndexError("find_index_of_largest() arg is an empty sequence")
    return best[1]
def get_response(n4j, gid, query):
    """Answer ``query`` with a two-pass LLM call grounded in graph ``gid``.

    Pass 1 answers from the graph's own relationship triples (ret_context,
    ``sys_prompt_one``). Pass 2 revises that answer against the linked
    reference context (link_context, ``sys_prompt_two``).

    Returns:
        The text of the second LLM response.
    """
    selfcont = ret_context(n4j, gid)
    linkcont = link_context(n4j, gid)
    # Separate the question, the header and each context line. Plain
    # concatenation ran them together ("...?the provided information is:"
    # followed by triples glued end to end).
    context = "\n".join(selfcont)
    user_one = f"the question is: {query}\nthe provided information is:\n{context}"
    res = call_llm(sys_prompt_one, user_one)
    references = "\n".join(linkcont)
    user_two = (
        f"the question is: {query}\n"
        f"the last response of it is: {res}\n"
        f"the references are:\n{references}"
    )
    return call_llm(sys_prompt_two, user_two)
def link_context(n4j, gid):
    """Describe cross-graph references from graph ``gid`` as numbered sentences.

    For each non-Summary node ``n`` in ``gid`` with an outgoing ``REFERENCE``
    edge to a non-Summary node ``m``, every non-REFERENCE relationship of
    ``m`` to a non-Summary neighbour ``o`` yields one entry:

        "Reference <k>: <n.id> has the reference that <m.id> <REL_TYPE> <o.id>"

    ``k`` is a single running number starting at 1. The second LLM prompt asks
    for citations such as ``[1][3]``, so the labels must be unique and
    1-based; the old per-record 0-based index repeated "Reference 0".
    """
    retrieve_query = """
// Match all 'n' nodes with a specific gid but not of the "Summary" type
MATCH (n)
WHERE n.gid = $gid AND NOT n:Summary
// Find all 'm' nodes where 'm' is a reference of 'n' via a 'REFERENCES' relationship
MATCH (n)-[r:REFERENCE]->(m)
WHERE NOT m:Summary
// Find all 'o' nodes connected to each 'm', and include the relationship type,
// while excluding 'Summary' type nodes and 'REFERENCE' relationship
MATCH (m)-[s]-(o)
WHERE NOT o:Summary AND TYPE(s) <> 'REFERENCE'
// Collect and return details in a structured format
RETURN n.id AS NodeId1,
m.id AS Mid,
TYPE(r) AS ReferenceType,
collect(DISTINCT {RelationType: type(s), Oid: o.id}) AS Connections
"""
    res = n4j.query(retrieve_query, {'gid': gid})
    cont = []
    for r in res:
        # Expand each record's connections into one entry apiece, separating
        # the parts with spaces; plain concatenation glued ids and types together.
        for connection in r["Connections"]:
            cont.append(
                f"Reference {len(cont) + 1}: {r['NodeId1']} has the reference that "
                f"{r['Mid']} {connection['RelationType']} {connection['Oid']}"
            )
    return cont
def ret_context(n4j, gid):
    """Return the relationships inside graph ``gid`` as readable triples.

    Considers only non-Summary nodes whose ``gid`` equals ``gid``. Each
    relationship between two such nodes is reported once (the ``id(n) < id(m)``
    guard), formatted as ``"<n.id> <REL_TYPE> <m.id>"``.
    """
    ret_query = """
// Match all nodes with a specific gid but not of type "Summary" and collect them
MATCH (n)
WHERE n.gid = $gid AND NOT n:Summary
WITH collect(n) AS nodes
// Unwind the nodes to a pairs and match relationships between them
UNWIND nodes AS n
UNWIND nodes AS m
MATCH (n)-[r]-(m)
WHERE n.gid = m.gid AND id(n) < id(m) AND NOT n:Summary AND NOT m:Summary // Ensure each pair is processed once and exclude "Summary" nodes in relationships
WITH n, m, TYPE(r) AS relType
// Return node IDs and relationship types in structured format
RETURN n.id AS NodeId1, relType, m.id AS NodeId2
"""
    res = n4j.query(ret_query, {'gid': gid})
    # Separate the parts with spaces; plain concatenation produced unreadable
    # strings such as "AspirinTREATSHeadache".
    return [f"{r['NodeId1']} {r['relType']} {r['NodeId2']}" for r in res]
def merge_similar_nodes(n4j, gid, threshold=0.5):
    """Merge one pair of near-duplicate nodes via APOC.

    Candidates are distinct non-Summary nodes with identical (sorted) label
    sets whose ``embedding`` cosine similarity (GDS) exceeds ``threshold``.
    When ``gid`` is truthy, both nodes must belong to that graph; a falsy
    ``gid`` (None, "") searches the whole database.

    Only the first candidate pair is merged per call
    (``head(collect([n, m]))``), using ``apoc.refactor.mergeNodes`` with
    ``properties: 'overwrite'`` and ``mergeRels: true``. NOTE(review): this
    presumably avoids merging a node already consumed by an earlier merge in
    the same query; callers wanting every duplicate merged must call
    repeatedly -- confirm.

    Args:
        n4j: Graph client exposing ``query(query, params)``.
        gid: Graph id to restrict the search to, or falsy for all graphs.
        threshold: Minimum (exclusive) cosine similarity; defaults to the
            previous hard-coded 0.5.

    Returns:
        The raw result of ``n4j.query``.
    """
    # The gid-scoped and unscoped variants used to be two copies of the same
    # query; only this filter differs. It is a fixed fragment, never user text.
    scope = "AND n.gid = m.gid AND n.gid = $gid " if gid else ""
    merge_query = f"""
// Define a threshold for cosine similarity
WITH $threshold AS threshold
MATCH (n), (m)
WHERE NOT n:Summary AND NOT m:Summary {scope}AND n<>m AND apoc.coll.sort(labels(n)) = apoc.coll.sort(labels(m))
WITH n, m,
gds.similarity.cosine(n.embedding, m.embedding) AS similarity
WHERE similarity > threshold
WITH head(collect([n,m])) as nodes
CALL apoc.refactor.mergeNodes(nodes, {{properties: 'overwrite', mergeRels: true}})
YIELD node
RETURN count(*)
"""
    params = {'threshold': threshold}
    if gid:
        params['gid'] = gid
    return n4j.query(merge_query, params)
def ref_link(n4j, gid1, gid2, threshold=0.6):
    """Link similar nodes across two graphs with ``REFERENCE`` edges.

    For every non-Summary node ``n`` in graph ``gid1`` and non-Summary node
    ``m`` in graph ``gid2`` that have identical (sorted) label sets and an
    ``embedding`` cosine similarity (GDS) above ``threshold``, the query
    MERGEs ``(m)-[:REFERENCE]->(n)``. The edge points from the ``gid2`` node
    to the ``gid1`` node, and MERGE avoids creating the same edge twice.

    Args:
        n4j: Graph client exposing ``query(query, params)``.
        gid1: Graph id of the referenced nodes (edge targets).
        gid2: Graph id of the referencing nodes (edge sources).
        threshold: Minimum (exclusive) cosine similarity; defaults to the
            previous hard-coded 0.6.

    Returns:
        The raw result of ``n4j.query`` (the query returns ``n, m``).
    """
    trinity_query = """
// Match nodes from Graph A
MATCH (a)
WHERE a.gid = $gid1 AND NOT a:Summary
WITH collect(a) AS GraphA
// Match nodes from Graph B
MATCH (b)
WHERE b.gid = $gid2 AND NOT b:Summary
WITH GraphA, collect(b) AS GraphB
// Unwind the nodes to compare each against each
UNWIND GraphA AS n
UNWIND GraphB AS m
// Set the threshold for cosine similarity
WITH n, m, $threshold AS threshold
// Compute cosine similarity and apply the threshold
WHERE apoc.coll.sort(labels(n)) = apoc.coll.sort(labels(m)) AND n <> m
WITH n, m, threshold,
gds.similarity.cosine(n.embedding, m.embedding) AS similarity
WHERE similarity > threshold
// Create a relationship based on the condition
MERGE (m)-[:REFERENCE]->(n)
// Return results
RETURN n, m
"""
    # The threshold is passed as a query parameter instead of being hard-coded.
    params = {'gid1': gid1, 'gid2': gid2, 'threshold': threshold}
    return n4j.query(trinity_query, params)
def str_uuid():
    """Return a new random (version 4) UUID in its canonical 36-character string form."""
    return str(uuid.uuid4())