Skip to content

Commit

Permalink
Implements delete_nodes() and clear() for Weaviate, Opensearch, Mi…
Browse files Browse the repository at this point in the history
…lvus, Postgres, and Pinecone Vector Stores (run-llama#14800)
  • Loading branch information
jonathanhliu21 authored Jul 19, 2024
1 parent d53b7f6 commit 59f2202
Show file tree
Hide file tree
Showing 12 changed files with 633 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
)
from llama_index.core.vector_stores.types import (
BasePydanticVectorStore,
FilterOperator,
MetadataFilter,
MetadataFilters,
VectorStoreQuery,
VectorStoreQueryMode,
Expand Down Expand Up @@ -365,6 +367,43 @@ def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
self._milvusclient.delete(collection_name=self.collection_name, pks=ids)
logger.debug(f"Successfully deleted embedding with doc_id: {doc_ids}")

def delete_nodes(
    self,
    node_ids: Optional[List[str]] = None,
    filters: Optional[MetadataFilters] = None,
    **delete_kwargs: Any,
) -> None:
    """Delete nodes from the Milvus collection by ID and/or metadata filters.

    Args:
        node_ids (Optional[List[str]], optional): IDs of nodes to delete. Defaults to None.
        filters (Optional[MetadataFilters], optional): Metadata filters. Defaults to None.
    """
    from copy import deepcopy

    # Deep-copy so the caller's filter object is not mutated when the
    # node-id constraint is appended below.
    filters_cpy = deepcopy(filters) or MetadataFilters(filters=[])

    if node_ids:
        filters_cpy.filters.append(
            MetadataFilter(key="id", value=node_ids, operator=FilterOperator.IN)
        )

    # `filters_cpy` can never be None here (the `or` above guarantees a
    # MetadataFilters instance), so the original `is not None` branch was
    # dead code. `filter_expr` also avoids shadowing the `filter` builtin.
    filter_expr = _to_milvus_filter(filters_cpy)

    self._milvusclient.delete(
        collection_name=self.collection_name,
        filter=filter_expr,
        **delete_kwargs,
    )
    logger.debug(f"Successfully deleted node_ids: {node_ids}")

def clear(self) -> None:
    """Remove all data by dropping the underlying Milvus collection."""
    collection = self.collection_name
    self._milvusclient.drop_collection(collection)

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes.
Expand Down Expand Up @@ -393,9 +432,11 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
expr.append(
_to_milvus_filter(
query.filters,
kwargs["milvus_scalar_filters"]
if "milvus_scalar_filters" in kwargs
else None,
(
kwargs["milvus_scalar_filters"]
if "milvus_scalar_filters" in kwargs
else None
),
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-vector-stores-milvus"
readme = "README.md"
version = "0.1.20"
version = "0.1.21"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,10 @@ def _parse_filters(self, filters: Optional[MetadataFilters]) -> Any:
pre_filter = []
if filters is not None:
for f in filters.legacy_filters():
pre_filter.append({f.key: json.loads(str(f.value))})
if isinstance(f.value, str):
pre_filter.append({f.key: f.value})
else:
pre_filter.append({f.key: json.loads(str(f.value))})

return pre_filter

Expand Down Expand Up @@ -389,6 +392,41 @@ async def delete_by_doc_id(self, doc_id: str) -> None:
}
await self._os_client.delete_by_query(index=self._index, body=search_query)

async def delete_nodes(
    self,
    node_ids: Optional[List[str]] = None,
    filters: Optional[MetadataFilters] = None,
    **delete_kwargs: Any,
) -> None:
    """Delete documents from the index by ID and/or metadata filters.

    Args:
        node_ids (Optional[List[str]], optional): IDs of nodes to delete. Defaults to None.
        filters (Optional[MetadataFilters], optional): Metadata filters. Defaults to None.
    """
    # With neither IDs nor filters there is nothing to match; bail out
    # rather than issue an unconstrained delete_by_query.
    if not node_ids and not filters:
        return

    query = {"query": {"bool": {"filter": []}}}
    if node_ids:
        # node_ids is guaranteed non-empty here, so the original
        # `node_ids or []` fallback was redundant.
        query["query"]["bool"]["filter"].append({"terms": {"_id": node_ids}})

    if filters:
        for parsed in self._parse_filters(filters):
            # Match on the .keyword sub-field for exact (non-analyzed)
            # comparison of metadata values. `parsed` avoids shadowing the
            # `filter` builtin.
            term_filter = {
                f"metadata.{key}.keyword": value for key, value in parsed.items()
            }
            query["query"]["bool"]["filter"].append({"term": term_filter})

    await self._os_client.delete_by_query(index=self._index, body=query)

async def clear(self) -> None:
    """Delete every document in the index."""
    # An empty bool filter matches all documents, so this removes them all.
    match_all = {"query": {"bool": {"filter": []}}}
    await self._os_client.delete_by_query(index=self._index, body=match_all)

async def aquery(
self,
query_mode: VectorStoreQueryMode,
Expand Down Expand Up @@ -574,6 +612,44 @@ async def adelete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
"""
await self._client.delete_by_doc_id(ref_doc_id)

async def adelete_nodes(
    self,
    node_ids: Optional[List[str]] = None,
    filters: Optional[MetadataFilters] = None,
    **delete_kwargs: Any,
) -> None:
    """Deletes nodes async.

    Args:
        node_ids (Optional[List[str]], optional): IDs of nodes to delete. Defaults to None.
        filters (Optional[MetadataFilters], optional): Metadata filters. Defaults to None.
    """
    # Delegate the actual deletion to the underlying Opensearch client.
    client = self._client
    await client.delete_nodes(node_ids, filters, **delete_kwargs)

def delete_nodes(
    self,
    node_ids: Optional[List[str]] = None,
    filters: Optional[MetadataFilters] = None,
    **delete_kwargs: Any,
) -> None:
    """Deletes nodes.

    Args:
        node_ids (Optional[List[str]], optional): IDs of nodes to delete. Defaults to None.
        filters (Optional[MetadataFilters], optional): Metadata filters. Defaults to None.
    """
    # Synchronous wrapper: drive the async implementation on the event loop.
    coro = self.adelete_nodes(node_ids, filters, **delete_kwargs)
    asyncio.get_event_loop().run_until_complete(coro)

async def aclear(self) -> None:
    """Remove every document from the index asynchronously."""
    await self._client.clear()

def clear(self) -> None:
    """Synchronous wrapper around aclear(): remove every document."""
    loop = asyncio.get_event_loop()
    loop.run_until_complete(self.aclear())

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
"""
Query index for top k most similar nodes.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-vector-stores-opensearch"
readme = "README.md"
version = "0.1.12"
version = "0.1.13"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,19 @@
import pytest
import uuid
from typing import List, Generator
import time

from llama_index.core.schema import NodeRelationship, RelatedNodeInfo, TextNode
from llama_index.vector_stores.opensearch import (
OpensearchVectorClient,
OpensearchVectorStore,
)
from llama_index.core.vector_stores.types import (
FilterOperator,
MetadataFilter,
MetadataFilters,
VectorStoreQuery,
)
from llama_index.core.vector_stores.types import VectorStoreQuery

##
Expand All @@ -33,6 +40,16 @@
finally:
evt_loop.run_until_complete(os_client.close())

TEST_EMBED_DIM = 3


def _get_sample_vector(num: float) -> List[float]:
"""
Get sample embedding vector of the form [num, 1, 1, ..., 1]
where the length of the vector is TEST_EMBED_DIM.
"""
return [num] + [1.0] * (TEST_EMBED_DIM - 1)


@pytest.mark.skipif(opensearch_not_available, reason="opensearch is not available")
def test_connection() -> None:
Expand Down Expand Up @@ -130,6 +147,40 @@ def node_embeddings() -> List[TextNode]:
]


@pytest.fixture(scope="session")
def node_embeddings_2() -> List[TextNode]:
    """Four sample nodes with distinct ids, metadata, and embeddings."""
    # (id, text, source node id, extra_info, first embedding component)
    specs = [
        ("aaa", "lorem ipsum", "aaa", {"test_num": "1"}, 1.0),
        ("bbb", "dolor sit amet", "bbb", {"test_key": "test_value"}, 0.1),
        (
            "ccc",
            "consectetur adipiscing elit",
            "ccc",
            {"test_key_list": ["test_value"]},
            0.1,
        ),
        # NOTE(review): "ddd" points at source node "ccc" in the original
        # fixture — preserved here, but it looks like a copy-paste; confirm.
        ("ddd", "sed do eiusmod tempor", "ccc", {"test_key_2": "test_val_2"}, 0.1),
    ]
    return [
        TextNode(
            text=text,
            id_=node_id,
            relationships={
                NodeRelationship.SOURCE: RelatedNodeInfo(node_id=source_id)
            },
            extra_info=extra,
            embedding=_get_sample_vector(first),
        )
        for node_id, text, source_id, extra, first in specs
    ]


def count_docs_in_index(os_store: OpensearchVectorStore) -> int:
"""Refresh indices and return the count of documents in the index."""
evt_loop.run_until_complete(
Expand All @@ -156,3 +207,123 @@ def test_functionality(
# delete one node using its associated doc_id
os_store.delete("test-1")
assert count_docs_in_index(os_store) == len(node_embeddings) - 1


@pytest.mark.skipif(opensearch_not_available, reason="opensearch is not available")
def test_delete_nodes(
    os_store: OpensearchVectorStore, node_embeddings_2: List[TextNode]
):
    """delete_nodes() removes exactly the requested ids and ignores unknown ones."""
    os_store.add(node_embeddings_2)

    query = VectorStoreQuery(
        query_embedding=_get_sample_vector(0.5), similarity_top_k=10
    )

    # Deleting with no arguments must be a no-op.
    os_store.delete_nodes()
    time.sleep(1)
    ids = os_store.query(query).ids
    for node_id in ("aaa", "bbb", "ccc"):
        assert node_id in ids

    # Deleting an id that does not exist must also be a no-op.
    os_store.delete_nodes(["asdf"])
    time.sleep(1)
    ids = os_store.query(query).ids
    for node_id in ("aaa", "bbb", "ccc"):
        assert node_id in ids

    # Deleting a list of ids removes exactly those nodes.
    os_store.delete_nodes(["aaa", "bbb"])
    time.sleep(1)
    ids = os_store.query(query).ids
    assert "aaa" not in ids
    assert "bbb" not in ids
    assert "ccc" in ids


@pytest.mark.skipif(opensearch_not_available, reason="opensearch is not available")
def test_delete_nodes_metadata(
    os_store: OpensearchVectorStore, node_embeddings_2: List[TextNode]
) -> None:
    """delete_nodes() honors metadata filters, alone and combined with ids."""
    os_store.add(node_embeddings_2)

    query = VectorStoreQuery(
        query_embedding=_get_sample_vector(0.5), similarity_top_k=10
    )

    def _eq_filter(key, value):
        # Single-clause equality filter used by every case below.
        return MetadataFilters(
            filters=[
                MetadataFilter(key=key, value=value, operator=FilterOperator.EQ)
            ]
        )

    # Multiple ids requested, but only "bbb" satisfies the filter.
    os_store.delete_nodes(["aaa", "bbb"], filters=_eq_filter("test_key", "test_value"))
    time.sleep(1)
    ids = os_store.query(query).ids
    for node_id in ("aaa", "ccc", "ddd"):
        assert node_id in ids
    assert "bbb" not in ids

    # One id which satisfies the filter: "aaa" is deleted.
    os_store.delete_nodes(["aaa"], filters=_eq_filter("test_num", 1))
    time.sleep(1)
    ids = os_store.query(query).ids
    for node_id in ("bbb", "aaa"):
        assert node_id not in ids
    for node_id in ("ccc", "ddd"):
        assert node_id in ids

    # One id which does NOT satisfy the filter: nothing changes.
    os_store.delete_nodes(["ccc"], filters=_eq_filter("test_num", "1"))
    time.sleep(1)
    ids = os_store.query(query).ids
    for node_id in ("bbb", "aaa"):
        assert node_id not in ids
    for node_id in ("ccc", "ddd"):
        assert node_id in ids

    # Deleting purely based on filters removes "ddd" only.
    os_store.delete_nodes(filters=_eq_filter("test_key_2", "test_val_2"))
    time.sleep(1)
    ids = os_store.query(query).ids
    for node_id in ("bbb", "aaa", "ddd"):
        assert node_id not in ids
    assert "ccc" in ids


@pytest.mark.skipif(opensearch_not_available, reason="opensearch is not available")
def test_clear(
    os_store: OpensearchVectorStore, node_embeddings_2: List[TextNode]
) -> None:
    """clear() wipes the index so queries return no results."""
    os_store.add(node_embeddings_2)

    query = VectorStoreQuery(
        query_embedding=_get_sample_vector(0.5), similarity_top_k=10
    )
    ids_before = os_store.query(query).ids
    for node_id in ("bbb", "aaa", "ddd", "ccc"):
        assert node_id in ids_before

    os_store.clear()

    time.sleep(1)

    ids_after = os_store.query(query).ids
    assert len(ids_after) == 0
    for node_id in ("bbb", "aaa", "ddd", "ccc"):
        assert node_id not in ids_after
Loading

0 comments on commit 59f2202

Please sign in to comment.