forked from openai/chatgpt-retrieval-plugin
-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathredis_datastore.py
401 lines (347 loc) · 13.9 KB
/
redis_datastore.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
import asyncio
import os
import re
import json
import redis.asyncio as redis
import numpy as np
from redis.commands.search.query import Query as RediSearchQuery
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
from redis.commands.search.field import (
TagField,
TextField,
NumericField,
VectorField,
)
from loguru import logger
from typing import Dict, List, Optional
from datastore.datastore import DataStore
from models.models import (
DocumentChunk,
DocumentMetadataFilter,
DocumentChunkWithScore,
DocumentMetadataFilter,
QueryResult,
QueryWithEmbedding,
)
from services.date import to_unix_timestamp
# Read environment variables for Redis
REDIS_HOST = os.environ.get("REDIS_HOST", "localhost")
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD")
REDIS_INDEX_NAME = os.environ.get("REDIS_INDEX_NAME", "index")
REDIS_DOC_PREFIX = os.environ.get("REDIS_DOC_PREFIX", "doc")
REDIS_DISTANCE_METRIC = os.environ.get("REDIS_DISTANCE_METRIC", "COSINE")
REDIS_INDEX_TYPE = os.environ.get("REDIS_INDEX_TYPE", "FLAT")
assert REDIS_INDEX_TYPE in ("FLAT", "HNSW")
# OpenAI Embeddings Dimension
VECTOR_DIMENSION = int(os.environ.get("EMBEDDING_DIMENSION", 256))
# RediSearch constants
REDIS_REQUIRED_MODULES = [
{"name": "search", "ver": 20600},
{"name": "ReJSON", "ver": 20404},
]
REDIS_DEFAULT_ESCAPED_CHARS = re.compile(r"[,.<>{}\[\]\\\"\':;!@#$%^&()\-+=~\/ ]")
# Helper functions
def unpack_schema(d: dict):
for v in d.values():
if isinstance(v, dict):
yield from unpack_schema(v)
else:
yield v
async def _check_redis_module_exist(client: redis.Redis, modules: List[dict]):
installed_modules = (await client.info()).get("modules", [])
installed_modules = {module["name"]: module for module in installed_modules}
for module in modules:
if module["name"] not in installed_modules or int(
installed_modules[module["name"]]["ver"]
) < int(module["ver"]):
error_message = (
"You must add the RediSearch (>= 2.6) and ReJSON (>= 2.4) modules from Redis Stack. "
"Please refer to Redis Stack docs: https://redis.io/docs/stack/"
)
logger.error(error_message)
raise AttributeError(error_message)
class RedisDataStore(DataStore):
def __init__(self, client: redis.Redis, redisearch_schema: dict):
self.client = client
self._schema = redisearch_schema
# Init default metadata with sentinel values in case the document written has no metadata
self._default_metadata = {
field: (0 if field == "created_at" else "_null_")
for field in redisearch_schema["metadata"]
}
### Redis Helper Methods ###
@classmethod
async def init(cls, **kwargs):
"""
Setup the index if it does not exist.
"""
try:
# Connect to the Redis Client
logger.info("Connecting to Redis")
client = redis.Redis(
host=REDIS_HOST, port=REDIS_PORT, password=REDIS_PASSWORD
)
except Exception as e:
logger.error(f"Error setting up Redis: {e}")
raise e
await _check_redis_module_exist(client, modules=REDIS_REQUIRED_MODULES)
dim = kwargs.get("dim", VECTOR_DIMENSION)
redisearch_schema = {
"metadata": {
"document_id": TagField(
"$.metadata.document_id", as_name="document_id"
),
"source_id": TagField("$.metadata.source_id", as_name="source_id"),
"source": TagField("$.metadata.source", as_name="source"),
"author": TextField("$.metadata.author", as_name="author"),
"created_at": NumericField(
"$.metadata.created_at", as_name="created_at"
),
},
"embedding": VectorField(
"$.embedding",
REDIS_INDEX_TYPE,
{
"TYPE": "FLOAT64",
"DIM": dim,
"DISTANCE_METRIC": REDIS_DISTANCE_METRIC,
},
as_name="embedding",
),
}
try:
# Check for existence of RediSearch Index
await client.ft(REDIS_INDEX_NAME).info()
logger.info(f"RediSearch index {REDIS_INDEX_NAME} already exists")
except:
# Create the RediSearch Index
logger.info(f"Creating new RediSearch index {REDIS_INDEX_NAME}")
definition = IndexDefinition(
prefix=[REDIS_DOC_PREFIX], index_type=IndexType.JSON
)
fields = list(unpack_schema(redisearch_schema))
logger.info(f"Creating index with fields: {fields}")
await client.ft(REDIS_INDEX_NAME).create_index(
fields=fields, definition=definition
)
return cls(client, redisearch_schema)
@staticmethod
def _redis_key(document_id: str, chunk_id: str) -> str:
"""
Create the JSON key for document chunks in Redis.
Args:
document_id (str): Document Identifier
chunk_id (str): Chunk Identifier
Returns:
str: JSON key string.
"""
return f"doc:{document_id}:chunk:{chunk_id}"
@staticmethod
def _escape(value: str) -> str:
"""
Escape filter value.
Args:
value (str): Value to escape.
Returns:
str: Escaped filter value for RediSearch.
"""
def escape_symbol(match) -> str:
value = match.group(0)
return f"\\{value}"
return REDIS_DEFAULT_ESCAPED_CHARS.sub(escape_symbol, value)
def _get_redis_chunk(self, chunk: DocumentChunk) -> dict:
"""
Convert DocumentChunk into a JSON object for storage
in Redis.
Args:
chunk (DocumentChunk): Chunk of a Document.
Returns:
dict: JSON object for storage in Redis.
"""
# Convert chunk -> dict
data = chunk.__dict__
metadata = chunk.metadata.__dict__
data["chunk_id"] = data.pop("id")
# Prep Redis Metadata
redis_metadata = dict(self._default_metadata)
if metadata:
for field, value in metadata.items():
if value:
if field == "created_at":
redis_metadata[field] = to_unix_timestamp(value) # type: ignore
else:
redis_metadata[field] = value
data["metadata"] = redis_metadata
return data
def _get_redis_query(self, query: QueryWithEmbedding) -> RediSearchQuery:
"""
Convert a QueryWithEmbedding into a RediSearchQuery.
Args:
query (QueryWithEmbedding): Search query.
Returns:
RediSearchQuery: Query for RediSearch.
"""
filter_str: str = ""
# RediSearch field type to query string
def _typ_to_str(typ, field, value) -> str: # type: ignore
if isinstance(typ, TagField):
return f"@{field}:{{{self._escape(value)}}} "
elif isinstance(typ, TextField):
return f"@{field}:{value} "
elif isinstance(typ, NumericField):
num = to_unix_timestamp(value)
match field:
case "start_date":
return f"@{field}:[{num} +inf] "
case "end_date":
return f"@{field}:[-inf {num}] "
# Build filter
if query.filter:
redisearch_schema = self._schema
for field, value in query.filter.__dict__.items():
if not value:
continue
if field in redisearch_schema:
filter_str += _typ_to_str(redisearch_schema[field], field, value)
elif field in redisearch_schema["metadata"]:
if field == "source": # handle the enum
value = value.value
filter_str += _typ_to_str(
redisearch_schema["metadata"][field], field, value
)
elif field in ["start_date", "end_date"]:
filter_str += _typ_to_str(
redisearch_schema["metadata"]["created_at"], field, value
)
# Postprocess filter string
filter_str = filter_str.strip()
filter_str = filter_str if filter_str else "*"
# Prepare query string
query_str = (
f"({filter_str})=>[KNN {query.top_k} @embedding $embedding as score]"
)
return (
RediSearchQuery(query_str)
.sort_by("score")
.paging(0, query.top_k)
.dialect(2)
)
async def _redis_delete(self, keys: List[str]):
"""
Delete a list of keys from Redis.
Args:
keys (List[str]): List of keys to delete.
"""
# Delete the keys
await asyncio.gather(*[self.client.delete(key) for key in keys])
#######
async def _upsert(self, chunks: Dict[str, List[DocumentChunk]]) -> List[str]:
"""
Takes in a list of list of document chunks and inserts them into the database.
Return a list of document ids.
"""
# Initialize a list of ids to return
doc_ids: List[str] = []
# Loop through the dict items
for doc_id, chunk_list in chunks.items():
# Append the id to the ids list
doc_ids.append(doc_id)
# Write chunks in a pipelines
async with self.client.pipeline(transaction=False) as pipe:
for chunk in chunk_list:
key = self._redis_key(doc_id, chunk.id)
data = self._get_redis_chunk(chunk)
await pipe.json().set(key, "$", data)
await pipe.execute()
return doc_ids
async def _query(
self,
queries: List[QueryWithEmbedding],
) -> List[QueryResult]:
"""
Takes in a list of queries with embeddings and filters and
returns a list of query results with matching document chunks and scores.
"""
# Prepare query responses and results object
results: List[QueryResult] = []
# Gather query results in a pipeline
logger.info(f"Gathering {len(queries)} query results")
for query in queries:
logger.debug(f"Query: {query.query}")
query_results: List[DocumentChunkWithScore] = []
# Extract Redis query
redis_query: RediSearchQuery = self._get_redis_query(query)
embedding = np.array(query.embedding, dtype=np.float64).tobytes()
# Perform vector search
query_response = await self.client.ft(REDIS_INDEX_NAME).search(
redis_query, {"embedding": embedding}
)
# Iterate through the most similar documents
for doc in query_response.docs:
# Load JSON data
doc_json = json.loads(doc.json)
# Create document chunk object with score
result = DocumentChunkWithScore(
id=doc_json["metadata"]["document_id"],
score=doc.score,
text=doc_json["text"],
metadata=doc_json["metadata"],
)
query_results.append(result)
# Add to overall results
results.append(QueryResult(query=query.query, results=query_results))
return results
async def _find_keys(self, pattern: str) -> List[str]:
return [key async for key in self.client.scan_iter(pattern)]
async def delete(
self,
ids: Optional[List[str]] = None,
filter: Optional[DocumentMetadataFilter] = None,
delete_all: Optional[bool] = None,
) -> bool:
"""
Removes vectors by ids, filter, or everything in the datastore.
Returns whether the operation was successful.
"""
# Delete all vectors from the index if delete_all is True
if delete_all:
try:
logger.info(f"Deleting all documents from index")
await self.client.ft(REDIS_INDEX_NAME).dropindex(True)
logger.info(f"Deleted all documents successfully")
return True
except Exception as e:
logger.error(f"Error deleting all documents: {e}")
raise e
# Delete by filter
if filter:
# TODO - extend this to work with other metadata filters?
if filter.document_id:
try:
keys = await self._find_keys(
f"{REDIS_DOC_PREFIX}:{filter.document_id}:*"
)
await self._redis_delete(keys)
logger.info(f"Deleted document {filter.document_id} successfully")
except Exception as e:
logger.error(f"Error deleting document {filter.document_id}: {e}")
raise e
# Delete by explicit ids (Redis keys)
if ids:
try:
logger.info(f"Deleting document ids {ids}")
keys = []
# find all keys associated with the document ids
for document_id in ids:
doc_keys = await self._find_keys(
pattern=f"{REDIS_DOC_PREFIX}:{document_id}:*"
)
keys.extend(doc_keys)
# delete all keys
logger.info(f"Deleting {len(keys)} keys from Redis")
await self._redis_delete(keys)
except Exception as e:
logger.error(f"Error deleting ids: {e}")
raise e
return True