Skip to content

Commit

Permalink
feat: add functionality to search embeddings in Milvus
Browse files Browse the repository at this point in the history
Signed-off-by: Jianuo Kuang <[email protected]>
  • Loading branch information
gitveg committed Jan 31, 2025
1 parent c601195 commit 8442890
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 22 deletions.
Binary file added rfcs/assets/search_docs_false.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added rfcs/assets/search_docs_true.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
102 changes: 102 additions & 0 deletions rfcs/notes/kjn-notes-2025.1.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# 1月工作记录

本月主要工作集中在编写、测试向量数据库的检索逻辑代码

## milvus检索代码

`milvus` 向量数据库提供了两种检索的方法,分别是 query 和 search 。

其中,search 方法主要用于执行近似最近邻搜索(Approximate Nearest Neighbors, ANN),即根据给定的查询向量找到与之最相似的向量。它的核心功能是基于**向量相似性**进行检索。

query 方法用于执行更广泛的基于条件的查询,主要用于基于条件的过滤,根据指定的条件表达式检索数据

在代码编写上选择使用更适合 RAG 系统的 search 方法。

检索逻辑代码放在了 milvus_client.py 下:

```python
@get_time
def search_docs(self, query_embedding: List[float] = None, filter_expr: str = None, doc_limit: int = 10):
"""
从 Milvus 集合中检索文档。
Args:
query_embedding (List[float]): 查询向量,用于基于向量相似性检索。
filter_expr (str): 过滤条件表达式,用于基于字段值的过滤。如"user_id == 'abc1234'"
limit (int): 返回的文档数量上限,默认为 10。
Returns:
List[dict]: 检索到的文档列表,每个文档是一个字典,包含字段值和向量。
"""
try:
if not self.sess:
raise MilvusFailed("Milvus collection is not loaded. Call load_collection_() first.")

# 构造查询参数
search_params = {
"metric_type": self.search_params["metric_type"],
"params": self.search_params["params"]
}

# 构造查询表达式
expr = ""
if filter_expr:
expr = filter_expr

# 构造检索参数
search_params.update({
"data": [query_embedding] if query_embedding else None,
"anns_field": "embedding", # 指定集合中存储向量的字段名称。Milvus 会在该字段上进行向量相似性检索。
"param": {"metric_type": "L2", "params": {"nprobe": 128}}, # 检索的精度和性能
"limit": doc_limit, # 指定返回的最相似文档的数量上限
"expr": expr,
"output_fields": self.output_fields
})

# 执行检索
results = self.sess.search(**search_params)

# 处理检索结果
retrieved_docs = []
for hits in results:
for hit in hits:
doc = {
# "id": hit.id,
# "distance": hit.distance,
"user_id": hit.entity.get("user_id"),
"kb_id": hit.entity.get("kb_id"),
"file_id": hit.entity.get("file_id"),
"headers": json.loads(hit.entity.get("headers")),
"doc_id": hit.entity.get("doc_id"),
"content": hit.entity.get("content"),
"embedding": hit.entity.get("embedding")
}
retrieved_docs.append(doc)

return retrieved_docs

except Exception as e:
print(f'[{cur_func_name()}] [search_docs] Failed to search documents: {traceback.format_exc()}')
raise MilvusFailed(f"Failed to search documents: {str(e)}")
```

## 测试milvus检索逻辑

利用已有的 embedding 文件夹下的 embedding_client.py(原名为 client.py )中的embedding处理代码,同时编写了 embed_user_input 方便测试。

同时在 milvus_client.py 的 main 函数中调用 search_docs 函数进行测试,测试结果如下。

不设置过滤条件正常检索:

![search_true](/rfcs/assets/search_docs_true.png)

设置过滤条件,检索结果为空:

![search_false](/rfcs/assets/search_docs_false.png)


## 未来工作

后续继续实现 server 与 client 的交互处理,方便更好地测试用户的输入经过 embedding 后到 milvus 中进行检索的过程。

RAG 系统的 UI 界面逐步完善。
98 changes: 88 additions & 10 deletions src/client/database/milvus/milvus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from src.utils.general_utils import get_time, cur_func_name
from src.configs.configs import MILVUS_HOST_LOCAL, MILVUS_PORT, VECTOR_SEARCH_TOP_K

from src.client.embedding.embedding_client import SBIEmbeddings, _process_query, embed_user_input


class MilvusFailed(Exception):
"""异常基类"""
Expand Down Expand Up @@ -107,6 +109,70 @@ def store_doc(self, doc: Document, embedding: List[float]):
print(f'[{cur_func_name()}] [store_doc] Failed to store document: {traceback.format_exc()}')
raise MilvusFailed(f"Failed to store document: {str(e)}")

@get_time
def search_docs(self, query_embedding: List[float] = None, filter_expr: str = None, doc_limit: int = 10):
"""
从 Milvus 集合中检索文档。
Args:
query_embedding (List[float]): 查询向量,用于基于向量相似性检索。
filter_expr (str): 过滤条件表达式,用于基于字段值的过滤。如"user_id == 'abc1234'"
limit (int): 返回的文档数量上限,默认为 10。
Returns:
List[dict]: 检索到的文档列表,每个文档是一个字典,包含字段值和向量。
"""
try:
if not self.sess:
raise MilvusFailed("Milvus collection is not loaded. Call load_collection_() first.")

# 构造查询参数
search_params = {
"metric_type": self.search_params["metric_type"],
"params": self.search_params["params"]
}

# 构造查询表达式
expr = ""
if filter_expr:
expr = filter_expr

# 构造检索参数
search_params.update({
"data": [query_embedding] if query_embedding else None,
"anns_field": "embedding", # 指定集合中存储向量的字段名称。Milvus 会在该字段上进行向量相似性检索。
"param": {"metric_type": "L2", "params": {"nprobe": 128}}, # 检索的精度和性能
"limit": doc_limit, # 指定返回的最相似文档的数量上限
"expr": expr,
"output_fields": self.output_fields
})

# 执行检索
results = self.sess.search(**search_params)

# 处理检索结果
retrieved_docs = []
for hits in results:
for hit in hits:
doc = {
# "id": hit.id,
# "distance": hit.distance,
"user_id": hit.entity.get("user_id"),
"kb_id": hit.entity.get("kb_id"),
"file_id": hit.entity.get("file_id"),
"headers": json.loads(hit.entity.get("headers")),
"doc_id": hit.entity.get("doc_id"),
"content": hit.entity.get("content"),
"embedding": hit.entity.get("embedding")
}
retrieved_docs.append(doc)

return retrieved_docs

except Exception as e:
print(f'[{cur_func_name()}] [search_docs] Failed to search documents: {traceback.format_exc()}')
raise MilvusFailed(f"Failed to search documents: {str(e)}")

@property
def fields(self):
fields = [
Expand Down Expand Up @@ -144,16 +210,17 @@ def main():

# 检索所有文档
try:
# 构造查询表达式(检索所有文档)
query_expr = "" # 不设置过滤条件,检索所有文档

# 执行查询
results = client.sess.query(
expr=query_expr,
output_fields=client.output_fields, # 指定返回的字段
limit=1000
)
# # 构造查询表达式
filter_expr = "123" # 设置过滤条件

# # 执行查询
# results = client.sess.query(
# expr=query_expr,
# output_fields=client.output_fields, # 指定返回的字段
# limit=1000
# )
query_expr = embed_user_input("荷塘月色")
results = client.search_docs(query_expr, filter_expr, 1000)
# 打印检索结果
if not results:
print(f"No documents found in collection {user_id}.")
Expand All @@ -165,7 +232,18 @@ def main():
print(f" user_id: {result['user_id']}")
print(f" kb_id: {result['kb_id']}")
print(f" file_id: {result['file_id']}")
print(f" headers: {json.loads(result['headers'])}") # 将 headers 从 JSON 字符串解析为字典
# 检查 headers 的类型
headers = result.get('headers')
if isinstance(headers, dict):
print(f" headers: {headers}")
elif isinstance(headers, str):
try:
headers = json.loads(headers)
print(f" headers: {headers}")
except json.JSONDecodeError as e:
print(f" headers: {headers} (无法解析为 JSON)")
else:
print(f" headers: {headers} (未知类型)")
print(f" doc_id: {result['doc_id']}")
print(f" content: {result['content']}")
print(f" embedding: {result['embedding'][:5]}... (truncated)") # 只打印前 5 维向量
Expand Down
48 changes: 36 additions & 12 deletions src/client/embedding/client.py → src/client/embedding/embedding_client.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -180,23 +180,47 @@ async def performance_test():
async_time = time.time() - start_time
debug_logger.info(f"异步处理 {size} 个文本耗时: {async_time:.2f}秒")

def embed_user_input(user_input: str):
"""测试用户输入的文本嵌入"""
embedder = SBIEmbeddings()

# 对用户输入的文本进行预处理
processed_input = _process_query(user_input)

debug_logger.info("\n测试用户输入的嵌入:")
debug_logger.info(f"用户输入: {user_input}")
debug_logger.info(f"预处理后的输入: {processed_input}")

try:
# 使用同步方法获取嵌入向量
embedding = embedder.embed_query(processed_input)
debug_logger.info(f"嵌入向量维度: {len(embedding)}")
debug_logger.info(f"嵌入向量: {embedding}")
except Exception as e:
debug_logger.error(f"嵌入过程中发生错误: {str(e)}")

return embedding


async def main():
"""主测试函数"""
debug_logger.info(f"开始embedding客户端测试...")

# 测试异步方法
await test_async_methods()

# # 测试同步方法
# test_sync_methods()

# # 测试错误处理
# test_error_handling()

# # 执行性能测试
# await performance_test()

try:
# 测试异步方法
await test_async_methods()

# 测试同步方法
test_sync_methods()

# 测试错误处理
test_error_handling()

# 执行性能测试
await performance_test()
except Exception as e:
debug_logger.error(f"测试过程中发生错误: {str(e)}")

debug_logger.info("embedding客户端测试完成")


Expand Down

0 comments on commit 8442890

Please sign in to comment.