Skip to content

Commit

Permalink
FEAT: support milvus to full text search (langgenius#11430)
Browse files Browse the repository at this point in the history
Signed-off-by: YoungLH <[email protected]>
  • Loading branch information
kgpp34 authored Jan 8, 2025
1 parent d649037 commit 040a3b7
Show file tree
Hide file tree
Showing 8 changed files with 392 additions and 240 deletions.
6 changes: 6 additions & 0 deletions api/configs/middleware/vdb/milvus_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,9 @@ class MilvusConfig(BaseSettings):
description="Name of the Milvus database to connect to (default is 'default')",
default="default",
)

MILVUS_ENABLE_HYBRID_SEARCH: bool = Field(
description="Enable hybrid search features (requires Milvus >= 2.5.0). Set to false for compatibility with "
"older versions",
default=True,
)
2 changes: 2 additions & 0 deletions api/core/rag/datasource/vdb/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ class Field(Enum):
METADATA_KEY = "metadata"
GROUP_KEY = "group_id"
VECTOR = "vector"
# Sparse Vector aims to support full text search
SPARSE_VECTOR = "sparse_vector"
TEXT_KEY = "text"
PRIMARY_KEY = "id"
DOC_ID = "metadata.doc_id"
200 changes: 171 additions & 29 deletions api/core/rag/datasource/vdb/milvus/milvus_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from typing import Any, Optional

from packaging import version
from pydantic import BaseModel, model_validator
from pymilvus import MilvusClient, MilvusException # type: ignore
from pymilvus.milvus_client import IndexParams # type: ignore
Expand All @@ -20,16 +21,25 @@


class MilvusConfig(BaseModel):
uri: str
token: Optional[str] = None
user: str
password: str
batch_size: int = 100
database: str = "default"
"""
Configuration class for Milvus connection.
"""

uri: str # Milvus server URI
token: Optional[str] = None # Optional token for authentication
user: str # Username for authentication
password: str # Password for authentication
batch_size: int = 100 # Batch size for operations
database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search

@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
"""
Validate the configuration values.
Raises ValueError if required fields are missing.
"""
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")
if not values.get("user"):
Expand All @@ -39,6 +49,9 @@ def validate_config(cls, values: dict) -> dict:
return values

def to_milvus_params(self):
"""
Convert the configuration to a dictionary of Milvus connection parameters.
"""
return {
"uri": self.uri,
"token": self.token,
Expand All @@ -49,39 +62,69 @@ def to_milvus_params(self):


class MilvusVector(BaseVector):
"""
Milvus vector storage implementation.
"""

def __init__(self, collection_name: str, config: MilvusConfig):
super().__init__(collection_name)
self._client_config = config
self._client = self._init_client(config)
self._consistency_level = "Session"
self._fields: list[str] = []
self._consistency_level = "Session" # Consistency level for Milvus operations
self._fields: list[str] = [] # List of fields in the collection
self._hybrid_search_enabled = self._check_hybrid_search_support() # Check if hybrid search is supported

def _check_hybrid_search_support(self) -> bool:
"""
Check if the current Milvus version supports hybrid search.
Returns True if the version is >= 2.5.0, otherwise False.
"""
if not self._client_config.enable_hybrid_search:
return False

try:
milvus_version = self._client.get_server_version()
return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version
except Exception as e:
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.")
return False

def get_type(self) -> str:
"""
Get the type of vector storage (Milvus).
"""
return VectorType.MILVUS

def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
"""
Create a collection and add texts with embeddings.
"""
index_params = {"metric_type": "IP", "index_type": "HNSW", "params": {"M": 8, "efConstruction": 64}}
metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
self.create_collection(embeddings, metadatas, index_params)
self.add_texts(texts, embeddings)

def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
"""
Add texts and their embeddings to the collection.
"""
insert_dict_list = []
for i in range(len(documents)):
insert_dict = {
# Do not need to insert the sparse_vector field separately, as the text_bm25_emb
# function will automatically convert the native text into a sparse vector for us.
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i],
Field.METADATA_KEY.value: documents[i].metadata,
}
insert_dict_list.append(insert_dict)
# Total insert count
total_count = len(insert_dict_list)

pks: list[str] = []

for i in range(0, total_count, 1000):
batch_insert_list = insert_dict_list[i : i + 1000]
# Insert into the collection.
batch_insert_list = insert_dict_list[i : i + 1000]
try:
ids = self._client.insert(collection_name=self._collection_name, data=batch_insert_list)
pks.extend(ids)
Expand All @@ -91,6 +134,9 @@ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **
return pks

def get_ids_by_metadata_field(self, key: str, value: str):
"""
Get document IDs by metadata field key and value.
"""
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["{key}"] == "{value}"', output_fields=["id"]
)
Expand All @@ -100,12 +146,18 @@ def get_ids_by_metadata_field(self, key: str, value: str):
return None

def delete_by_metadata_field(self, key: str, value: str):
"""
Delete documents by metadata field key and value.
"""
if self._client.has_collection(self._collection_name):
ids = self.get_ids_by_metadata_field(key, value)
if ids:
self._client.delete(collection_name=self._collection_name, pks=ids)

def delete_by_ids(self, ids: list[str]) -> None:
"""
Delete documents by their IDs.
"""
if self._client.has_collection(self._collection_name):
result = self._client.query(
collection_name=self._collection_name, filter=f'metadata["doc_id"] in {ids}', output_fields=["id"]
Expand All @@ -115,10 +167,16 @@ def delete_by_ids(self, ids: list[str]) -> None:
self._client.delete(collection_name=self._collection_name, pks=ids)

def delete(self) -> None:
"""
Delete the entire collection.
"""
if self._client.has_collection(self._collection_name):
self._client.drop_collection(self._collection_name, None)

def text_exists(self, id: str) -> bool:
"""
Check if a text with the given ID exists in the collection.
"""
if not self._client.has_collection(self._collection_name):
return False

Expand All @@ -128,40 +186,88 @@ def text_exists(self, id: str) -> bool:

return len(result) > 0

def field_exists(self, field: str) -> bool:
"""
Check if a field exists in the collection.
"""
return field in self._fields

def _process_search_results(
self, results: list[Any], output_fields: list[str], score_threshold: float = 0.0
) -> list[Document]:
"""
Common method to process search results
:param results: Search results
:param output_fields: Fields to be output
:param score_threshold: Score threshold for filtering
:return: List of documents
"""
docs = []
for result in results[0]:
metadata = result["entity"].get(output_fields[1], {})
metadata["score"] = result["distance"]

if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(output_fields[0], ""), metadata=metadata)
docs.append(doc)

return docs

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
# Set search parameters.
"""
Search for documents by vector similarity.
"""
results = self._client.search(
collection_name=self._collection_name,
data=[query_vector],
anns_field=Field.VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)
# Organize results.
docs = []
for result in results[0]:
metadata = result["entity"].get(Field.METADATA_KEY.value)
metadata["score"] = result["distance"]
score_threshold = float(kwargs.get("score_threshold") or 0.0)
if result["distance"] > score_threshold:
doc = Document(page_content=result["entity"].get(Field.CONTENT_KEY.value), metadata=metadata)
docs.append(doc)
return docs

return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# milvus/zilliz doesn't support bm25 search
return []
"""
Search for documents by full-text search (if hybrid search is enabled).
"""
if not self._hybrid_search_enabled or not self.field_exists(Field.SPARSE_VECTOR.value):
logger.warning("Full-text search is not supported in current Milvus version (requires >= 2.5.0)")
return []

results = self._client.search(
collection_name=self._collection_name,
data=[query],
anns_field=Field.SPARSE_VECTOR.value,
limit=kwargs.get("top_k", 4),
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
)

return self._process_search_results(
results,
output_fields=[Field.CONTENT_KEY.value, Field.METADATA_KEY.value],
score_threshold=float(kwargs.get("score_threshold") or 0.0),
)

def create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None, index_params: Optional[dict] = None
):
"""
Create a new collection in Milvus with the specified schema and index parameters.
"""
lock_name = "vector_indexing_lock_{}".format(self._collection_name)
with redis_client.lock(lock_name, timeout=20):
collection_exist_cache_key = "vector_indexing_{}".format(self._collection_name)
if redis_client.get(collection_exist_cache_key):
return
# Grab the existing collection if it exists
if not self._client.has_collection(self._collection_name):
from pymilvus import CollectionSchema, DataType, FieldSchema # type: ignore
from pymilvus import CollectionSchema, DataType, FieldSchema, Function, FunctionType # type: ignore
from pymilvus.orm.types import infer_dtype_bydata # type: ignore

# Determine embedding dim
Expand All @@ -170,16 +276,36 @@ def create_collection(
if metadatas:
fields.append(FieldSchema(Field.METADATA_KEY.value, DataType.JSON, max_length=65_535))

# Create the text field
fields.append(FieldSchema(Field.CONTENT_KEY.value, DataType.VARCHAR, max_length=65_535))
# Create the text field, enable_analyzer will be set True to support milvus automatically
# transfer text to sparse_vector, reference: https://milvus.io/docs/full-text-search.md
fields.append(
FieldSchema(
Field.CONTENT_KEY.value,
DataType.VARCHAR,
max_length=65_535,
enable_analyzer=self._hybrid_search_enabled,
)
)
# Create the primary key field
fields.append(FieldSchema(Field.PRIMARY_KEY.value, DataType.INT64, is_primary=True, auto_id=True))
# Create the vector field, supports binary or float vectors
fields.append(FieldSchema(Field.VECTOR.value, infer_dtype_bydata(embeddings[0]), dim=dim))
# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
fields.append(FieldSchema(Field.SPARSE_VECTOR.value, DataType.SPARSE_FLOAT_VECTOR))

# Create the schema for the collection
schema = CollectionSchema(fields)

# Create custom function to support text to sparse vector by BM25
if self._hybrid_search_enabled:
bm25_function = Function(
name="text_bm25_emb",
input_field_names=[Field.CONTENT_KEY.value],
output_field_names=[Field.SPARSE_VECTOR.value],
function_type=FunctionType.BM25,
)
schema.add_function(bm25_function)

for x in schema.fields:
self._fields.append(x.name)
# Since primary field is auto-id, no need to track it
Expand All @@ -189,23 +315,38 @@ def create_collection(
index_params_obj = IndexParams()
index_params_obj.add_index(field_name=Field.VECTOR.value, **index_params)

# Create Sparse Vector Index for the collection
if self._hybrid_search_enabled:
index_params_obj.add_index(
field_name=Field.SPARSE_VECTOR.value, index_type="AUTOINDEX", metric_type="BM25"
)

# Create the collection
collection_name = self._collection_name
self._client.create_collection(
collection_name=collection_name,
collection_name=self._collection_name,
schema=schema,
index_params=index_params_obj,
consistency_level=self._consistency_level,
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)

def _init_client(self, config) -> MilvusClient:
"""
Initialize and return a Milvus client.
"""
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
return client


class MilvusVectorFactory(AbstractVectorFactory):
"""
Factory class for creating MilvusVector instances.
"""

def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MilvusVector:
"""
Initialize a MilvusVector instance for the given dataset.
"""
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
Expand All @@ -222,5 +363,6 @@ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings
user=dify_config.MILVUS_USER or "",
password=dify_config.MILVUS_PASSWORD or "",
database=dify_config.MILVUS_DATABASE or "",
enable_hybrid_search=dify_config.MILVUS_ENABLE_HYBRID_SEARCH or False,
),
)
Loading

0 comments on commit 040a3b7

Please sign in to comment.