Skip to content

Commit

Permalink
refactor: update the default values of top-k parameter in vdb to be consistent (langgenius#9367)
Browse files Browse the repository at this point in the history
  • Loading branch information
hwzhuhao authored Oct 16, 2024
1 parent a83cccc commit 8659485
Show file tree
Hide file tree
Showing 9 changed files with 9 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def delete(self) -> None:
self._client.indices.delete(index=self._collection_name)

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 10)
top_k = kwargs.get("top_k", 4)
num_candidates = math.ceil(top_k * 1.5)
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}

Expand Down
2 changes: 1 addition & 1 deletion api/core/rag/datasource/vdb/myscale/myscale_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._search(f"TextSearch('enable_nlq=false')(text, '{query}')", SortOrder.DESC, **kwargs)

def _search(self, dist: str, order: SortOrder, **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
top_k = kwargs.get("top_k", 4)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
where_str = (
f"WHERE dist < {1 - score_threshold}"
Expand Down
10 changes: 1 addition & 9 deletions api/core/rag/datasource/vdb/oracle/oraclevector.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,6 @@ def get_by_ids(self, ids: list[str]) -> list[Document]:
docs.append(Document(page_content=record[1], metadata=record[0]))
return docs

# def get_ids_by_metadata_field(self, key: str, value: str):
# with self._get_cursor() as cur:
# cur.execute(f"SELECT id FROM {self.table_name} d WHERE d.meta.{key}='{value}'" )
# idss = []
# for record in cur:
# idss.append(record[0])
# return idss

def delete_by_ids(self, ids: list[str]) -> None:
with self._get_cursor() as cur:
cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
Expand All @@ -192,7 +184,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
:param top_k: The number of nearest neighbors to return, default is 4.
:return: List of Documents that are nearest to the query vector.
"""
top_k = kwargs.get("top_k", 5)
top_k = kwargs.get("top_k", 4)
with self._get_cursor() as cur:
cur.execute(
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
Expand Down
14 changes: 1 addition & 13 deletions api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
query_vector,
).label("distance"),
)
.limit(kwargs.get("top_k", 2))
.limit(kwargs.get("top_k", 4))
.order_by("distance")
)
res = session.execute(stmt)
Expand All @@ -205,18 +205,6 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
return docs

def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
# with Session(self._client) as session:
# select_statement = sql_text(
# f"SELECT text, meta FROM {self._collection_name} WHERE to_tsvector(text) @@ '{query}'::tsquery"
# )
# results = session.execute(select_statement).fetchall()
# if results:
# docs = []
# for result in results:
# doc = Document(page_content=result[0],
# metadata=result[1])
# docs.append(doc)
# return docs
return []


Expand Down
2 changes: 1 addition & 1 deletion api/core/rag/datasource/vdb/pgvector/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Doc
:param top_k: The number of nearest neighbors to return, default is 4.
:return: List of Documents that are nearest to the query vector.
"""
top_k = kwargs.get("top_k", 5)
top_k = kwargs.get("top_k", 4)

with self._get_cursor() as cur:
cur.execute(
Expand Down
2 changes: 1 addition & 1 deletion api/core/rag/datasource/vdb/relyt/relyt_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def text_exists(self, id: str) -> bool:

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
results = self.similarity_search_with_score_by_vector(
k=int(kwargs.get("top_k")), embedding=query_vector, filter=kwargs.get("filter")
k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=kwargs.get("filter")
)

# Organize results.
Expand Down
2 changes: 1 addition & 1 deletion api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def delete_by_metadata_field(self, key: str, value: str) -> None:
self._delete_by_ids(ids)

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 5)
top_k = kwargs.get("top_k", 4)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
filter = kwargs.get("filter")
distance = 1 - score_threshold
Expand Down
2 changes: 1 addition & 1 deletion api/core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def delete_by_metadata_field(self, key: str, value: str) -> None:

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
results = self._client.get_index(self._collection_name, self._index_name).search_by_vector(
query_vector, limit=kwargs.get("top_k", 50)
query_vector, limit=kwargs.get("top_k", 4)
)
score_threshold = float(kwargs.get("score_threshold") or 0.0)
return self._get_search_res(results, score_threshold)
Expand Down
2 changes: 1 addition & 1 deletion api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
query_obj = query_obj.with_where(kwargs.get("where_filter"))
query_obj = query_obj.with_additional(["vector"])
properties = ["text"]
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 2)).do()
result = query_obj.with_bm25(query=query, properties=properties).with_limit(kwargs.get("top_k", 4)).do()
if "errors" in result:
raise ValueError(f"Error during query: {result['errors']}")
docs = []
Expand Down

0 comments on commit 8659485

Please sign in to comment.