Skip to content

Commit

Permalink
Do not auto-connect on ctor for postgres (run-llama#7793)
Browse files Browse the repository at this point in the history
  • Loading branch information
Javtor authored Sep 25, 2023
1 parent c8f5d84 commit b869946
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 5 deletions.
5 changes: 4 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
"python.formatting.provider": "none",
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": true,
"source.organizeImports": true
},
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.testing.pytestArgs": ["tests"],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
### Bug Fixes / Nits
- Normalize scores returned from ElasticSearch vector store (#7792)
- Fixed `refresh_ref_docs()` bug with order of operations (#7664)
- Delay postgresql connection for `PGVectorStore` until actually needed (#7793)

## [0.8.33] - 2023-09-25

Expand Down
22 changes: 18 additions & 4 deletions llama_index/vector_stores/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class PGVectorStore(BasePydanticVectorStore):
_session: Any = PrivateAttr()
_async_engine: Any = PrivateAttr()
_async_session: Any = PrivateAttr()
_is_initialized: bool = PrivateAttr(default=False)

def __init__(
self,
Expand Down Expand Up @@ -157,11 +158,10 @@ def __init__(
debug=debug,
)

self._connect()
self._create_extension()
self._create_tables_if_not_exists()

async def close(self) -> None:
if not self._is_initialized:
return None

self._session.close_all()
self._engine.dispose()

Expand Down Expand Up @@ -207,6 +207,8 @@ def from_params(

@property
def client(self) -> Any:
if not self._is_initialized:
return None
return self._engine

def _connect(self) -> Any:
Expand Down Expand Up @@ -235,6 +237,13 @@ def _create_extension(self) -> None:
session.execute(statement)
session.commit()

def _initialize(self) -> None:
if not self._is_initialized:
self._connect()
self._create_extension()
self._create_tables_if_not_exists()
self._is_initialized = True

def _node_to_table_row(self, node: BaseNode) -> Any:
return self._table_class(
node_id=node.node_id,
Expand All @@ -248,6 +257,7 @@ def _node_to_table_row(self, node: BaseNode) -> Any:
)

def add(self, nodes: List[BaseNode]) -> List[str]:
self._initialize()
ids = []
with self._session() as session:
with session.begin():
Expand All @@ -259,6 +269,7 @@ def add(self, nodes: List[BaseNode]) -> List[str]:
return ids

async def async_add(self, nodes: List[BaseNode]) -> List[str]:
self._initialize()
ids = []
async with self._async_session() as session:
async with session.begin():
Expand Down Expand Up @@ -480,6 +491,7 @@ def _db_rows_to_query_result(
async def aquery(
self, query: VectorStoreQuery, **kwargs: Any
) -> VectorStoreQueryResult:
self._initialize()
if query.mode == VectorStoreQueryMode.HYBRID:
results = await self._async_hybrid_query(query)
elif query.mode in [
Expand All @@ -500,6 +512,7 @@ async def aquery(
return self._db_rows_to_query_result(results)

def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
self._initialize()
if query.mode == VectorStoreQueryMode.HYBRID:
results = self._hybrid_query(query)
elif query.mode in [
Expand All @@ -522,6 +535,7 @@ def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResul
def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
import sqlalchemy

self._initialize()
with self._session() as session:
with session.begin():
stmt = sqlalchemy.text(
Expand Down
11 changes: 11 additions & 0 deletions tests/vector_stores/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ async def test_instance_creation(db: None) -> None:
table_name=TEST_TABLE_NAME,
)
assert isinstance(pg, PGVectorStore)
assert not hasattr(pg, "_engine")
assert pg.client is None
await pg.close()


Expand All @@ -173,6 +175,7 @@ async def test_add_to_db_and_query(
else:
pg.add(node_embeddings)
assert isinstance(pg, PGVectorStore)
assert hasattr(pg, "_engine")
q = VectorStoreQuery(query_embedding=_get_sample_vector(1.0), similarity_top_k=1)
if use_async:
res = await pg.aquery(q)
Expand All @@ -194,6 +197,7 @@ async def test_add_to_db_and_query_with_metadata_filters(
else:
pg.add(node_embeddings)
assert isinstance(pg, PGVectorStore)
assert hasattr(pg, "_engine")
filters = MetadataFilters(
filters=[ExactMatchFilter(key="test_key", value="test_value")]
)
Expand All @@ -220,6 +224,7 @@ async def test_add_to_db_query_and_delete(
else:
pg.add(node_embeddings)
assert isinstance(pg, PGVectorStore)
assert hasattr(pg, "_engine")

q = VectorStoreQuery(query_embedding=_get_sample_vector(0.1), similarity_top_k=1)

Expand All @@ -243,6 +248,7 @@ async def test_save_load(
else:
pg.add(node_embeddings)
assert isinstance(pg, PGVectorStore)
assert hasattr(pg, "_engine")

q = VectorStoreQuery(query_embedding=_get_sample_vector(0.1), similarity_top_k=1)

Expand All @@ -258,6 +264,7 @@ async def test_save_load(
await pg.close()

loaded_pg = cast(PGVectorStore, load_vector_store(pg_dict))
assert not hasattr(loaded_pg, "_engine")
loaded_pg_dict = loaded_pg.to_dict()
for key, val in pg.to_dict().items():
assert loaded_pg_dict[key] == val
Expand All @@ -266,6 +273,7 @@ async def test_save_load(
res = await loaded_pg.aquery(q)
else:
res = loaded_pg.query(q)
assert hasattr(loaded_pg, "_engine")
assert res.nodes
assert len(res.nodes) == 1
assert res.nodes[0].node_id == "bbb"
Expand All @@ -286,6 +294,7 @@ async def test_sparse_query(
else:
pg_hybrid.add(hybrid_node_embeddings)
assert isinstance(pg_hybrid, PGVectorStore)
assert hasattr(pg_hybrid, "_engine")

# text search should work when query is a sentence and not just a single word
q = VectorStoreQuery(
Expand Down Expand Up @@ -318,6 +327,7 @@ async def test_hybrid_query(
else:
pg_hybrid.add(hybrid_node_embeddings)
assert isinstance(pg_hybrid, PGVectorStore)
assert hasattr(pg_hybrid, "_engine")

q = VectorStoreQuery(
query_embedding=_get_sample_vector(0.1),
Expand Down Expand Up @@ -389,6 +399,7 @@ async def test_add_to_db_and_hybrid_query_with_metadata_filters(
else:
pg_hybrid.add(hybrid_node_embeddings)
assert isinstance(pg_hybrid, PGVectorStore)
assert hasattr(pg_hybrid, "_engine")
filters = MetadataFilters(
filters=[ExactMatchFilter(key="test_key", value="test_value")]
)
Expand Down

0 comments on commit b869946

Please sign in to comment.