Fix resource issues with embeddings indexing components backed by databases, closes #831
davidmezzetti committed Dec 6, 2024
1 parent 89410d3 commit e7aa46a
Showing 6 changed files with 110 additions and 42 deletions.
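
Taken together, these changes make the database-backed components (pgvector ANN, SQLite ANN, client content database, PGText scoring) release their engines, connections and sessions on close() and before re-initialization. A minimal sketch of the lifecycle affected, assuming a pgvector-backed index; the model path, config keys and connection URL below are illustrative, not taken from this commit:

from txtai import Embeddings

# Build an index whose ANN component is backed by PostgreSQL/pgvector
embeddings = Embeddings(
    path="sentence-transformers/all-MiniLM-L6-v2",                        # assumed example model
    backend="pgvector",
    pgvector={"url": "postgresql+psycopg2://user:pass@localhost/txtai"},  # assumed URL
)

embeddings.index(["first document", "second document"])
embeddings.save("/tmp/index")

# With this commit, close() also disposes of the backing engine and session
# instead of leaving database connections open
embeddings.close()
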
56 changes: 36 additions & 20 deletions src/python/txtai/ann/pgvector.py
@@ -8,7 +8,7 @@
try:
from pgvector.sqlalchemy import Vector

from sqlalchemy import create_engine, delete, text, Column, Index, Integer, MetaData, StaticPool, Table
from sqlalchemy import create_engine, delete, func, text, Column, Index, Integer, MetaData, StaticPool, Table
from sqlalchemy.orm import Session
from sqlalchemy.schema import CreateSchema

@@ -30,21 +30,10 @@ def __init__(self, config):
if not PGVECTOR:
raise ImportError('PGVector is not available - install "ann" extra to enable')

# Create engine
self.engine = create_engine(self.setting("url", os.environ.get("ANN_URL")), poolclass=StaticPool, echo=False)

# Initialize pgvector extension
self.database = Session(self.engine)
self.sqldialect(text("CREATE EXTENSION IF NOT EXISTS vector"))
self.database.commit()

# Table instance
self.table = None
# Database connection
self.engine, self.database, self.connection, self.table = None, None, None, None

def load(self, path):
# Reset database to original checkpoint
self.database.rollback()

# Initialize tables
self.initialize()

@@ -84,17 +73,22 @@ def search(self, queries, limit):
return results

def count(self):
return self.database.query(self.table.c["indexid"]).count()
# pylint: disable=E1102
return self.database.query(func.count(self.table.c["indexid"])).scalar()

def save(self, path):
# Commit session and connection
self.database.commit()
self.connection.commit()

def close(self):
# Parent logic
super().close()

# Close database connection
self.database.close()
if self.database:
self.database.close()
self.engine.dispose()

def initialize(self, recreate=False):
"""
@@ -104,6 +98,9 @@ def initialize(self, recreate=False):
recreate: Recreates the database tables if True
"""

# Connect to database
self.connect()

# Set default schema, if necessary
schema = self.setting("schema")
if schema:
@@ -132,12 +129,31 @@

# Drop and recreate table
if recreate:
self.table.drop(self.engine, checkfirst=True)
index.drop(self.engine, checkfirst=True)
self.table.drop(self.connection, checkfirst=True)
index.drop(self.connection, checkfirst=True)

# Create table and index
self.table.create(self.engine, checkfirst=True)
index.create(self.engine, checkfirst=True)
self.table.create(self.connection, checkfirst=True)
index.create(self.connection, checkfirst=True)

def connect(self):
"""
Establishes a database connection. Cleans up any existing database connection first.
"""

# Close existing connection
if self.database:
self.close()

# Create engine
self.engine = create_engine(self.setting("url", os.environ.get("ANN_URL")), poolclass=StaticPool, echo=False)
self.connection = self.engine.connect()

# Start database session
self.database = Session(self.connection)

# Initialize pgvector extension
self.sqldialect(text("CREATE EXTENSION IF NOT EXISTS vector"))

def settings(self):
"""
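
The pgvector changes above defer engine creation to a connect() helper, bind table DDL to the connection, commit both the session and the connection on save, and dispose of the engine on close. Below is a standalone sketch of that pattern in plain SQLAlchemy; the table name, column and URL are illustrative only:

from sqlalchemy import Column, Integer, MetaData, StaticPool, Table, create_engine
from sqlalchemy.orm import Session

class Store:
    """
    Illustrative lazy-connect store using the same engine/connection/session split.
    """

    def __init__(self, url):
        self.url = url
        self.engine, self.database, self.connection, self.table = None, None, None, None

    def connect(self):
        # Clean up any existing connection before creating a new one
        if self.database:
            self.close()

        self.engine = create_engine(self.url, poolclass=StaticPool, echo=False)
        self.connection = self.engine.connect()
        self.database = Session(self.connection)

        # Bind DDL to the connection so it participates in the transaction save() commits
        self.table = Table("store", MetaData(), Column("indexid", Integer, primary_key=True))
        self.table.create(self.connection, checkfirst=True)

    def save(self):
        # Commit the ORM session and the underlying connection
        self.database.commit()
        self.connection.commit()

    def close(self):
        # Close the session and release all connections held by the engine
        if self.database:
            self.database.close()
            self.engine.dispose()

        self.engine, self.database, self.connection, self.table = None, None, None, None

# Usage
store = Store("sqlite://")
store.connect()
store.save()
store.close()
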
9 changes: 9 additions & 0 deletions src/python/txtai/ann/sqlite.py
@@ -97,6 +97,15 @@ def save(self, path):
else:
self.copy(path).close()

def close(self):
# Parent logic
super().close()

# Close database connection
if self.connection:
self.connection.close()
self.connection = None

def initialize(self, recreate=False):
"""
Initializes a new database session.
41 changes: 29 additions & 12 deletions src/python/txtai/database/client.py
@@ -39,13 +39,30 @@ def __init__(self, config):
if not ORM:
raise ImportError('SQLAlchemy is not available - install "database" extra to enable')

# SQLAlchemy parameters
self.engine, self.dbconnection = None, None

def save(self, path):
# Commit session and database connection
super().save(path)

if self.dbconnection:
self.dbconnection.commit()

def close(self):
super().close()

# Dispose of engine, which also closes dbconnection
if self.engine:
self.engine.dispose()

def reindexstart(self):
# Working table name
name = f"rebuild{round(time.time() * 1000)}"

# Create working table metadata
type("Rebuild", (SectionBase,), {"__tablename__": name})
Base.metadata.tables[name].create(self.connection.bind)
Base.metadata.tables[name].create(self.dbconnection)

return name

@@ -62,11 +79,11 @@ def jsoncolumn(self, name):
d = aliased(Document, name="d")

# Build JSON column expression for column
return str(cast(d.data[name].as_string(), Text).compile(dialect=self.connection.bind.dialect, compile_kwargs={"literal_binds": True}))
return str(cast(d.data[name].as_string(), Text).compile(dialect=self.engine.dialect, compile_kwargs={"literal_binds": True}))

def createtables(self):
# Create tables
Base.metadata.create_all(self.connection.bind, checkfirst=True)
Base.metadata.create_all(self.dbconnection, checkfirst=True)

# Clear existing data - table schema is created upon connecting to database
for table in ["sections", "documents", "objects"]:
@@ -88,7 +105,7 @@ def insertsection(self, index, uid, text, tags, entry):

def createbatch(self):
# Create temporary batch table, if necessary
Base.metadata.tables["batch"].create(self.connection.bind, checkfirst=True)
Base.metadata.tables["batch"].create(self.dbconnection, checkfirst=True)

def insertbatch(self, indexids, ids, batch):
if indexids:
@@ -98,7 +115,7 @@ def insertbatch(self, indexids, ids, batch):

def createscores(self):
# Create temporary scores table, if necessary
Base.metadata.tables["scores"].create(self.connection.bind, checkfirst=True)
Base.metadata.tables["scores"].create(self.dbconnection, checkfirst=True)

def insertscores(self, scores):
# Average scores by id
Expand All @@ -113,16 +130,17 @@ def connect(self, path=None):
content = os.environ.get("CLIENT_URL") if content == "client" else content

# Create engine using database URL
engine = create_engine(content, poolclass=StaticPool, echo=False, json_serializer=lambda x: x)
self.engine = create_engine(content, poolclass=StaticPool, echo=False, json_serializer=lambda x: x)
self.dbconnection = self.engine.connect()

# Create database session
database = Session(engine)
database = Session(self.dbconnection)

# Set default schema, if necessary
schema = self.config.get("schema")
if schema:
self.sqldialect(database, engine, CreateSchema(schema, if_not_exists=True))
self.sqldialect(database, engine, textsql("SET search_path TO :schema"), {"schema": schema})
self.sqldialect(database, CreateSchema(schema, if_not_exists=True))
self.sqldialect(database, textsql("SET search_path TO :schema"), {"schema": schema})

return database

@@ -135,18 +153,17 @@ def rows(self):
def addfunctions(self):
return

def sqldialect(self, database, engine, sql, parameters=None):
def sqldialect(self, database, sql, parameters=None):
"""
Executes a SQL statement based on the current SQL dialect.
Args:
database: current database
engine: database engine
sql: SQL to execute
parameters: optional bind parameters
"""

args = (sql, parameters) if engine.dialect.name == "postgresql" else (textsql("SELECT 1"),)
args = (sql, parameters) if self.engine.dialect.name == "postgresql" else (textsql("SELECT 1"),)
database.execute(*args)


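
The client.py changes keep the engine and raw connection on the instance so save() can commit them and close() can dispose of them, and connect() resolves a "client" content setting through the CLIENT_URL environment variable. A hedged usage sketch follows; the URL and the reliance on default vector model settings are assumptions for illustration:

import os

from txtai import Embeddings

# Resolve the content database from CLIENT_URL via the "client" alias,
# mirroring the connect() logic above; the URL is illustrative only
os.environ["CLIENT_URL"] = "postgresql+psycopg2://user:pass@localhost/txtai"

embeddings = Embeddings(content="client")
embeddings.index(["content rows stored in an external database"])
embeddings.save("/tmp/index")

# close() disposes of the engine, which also closes the database connection
embeddings.close()
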
18 changes: 17 additions & 1 deletion src/python/txtai/embeddings/base.py
@@ -771,6 +771,10 @@ def initindex(self, reindex):
# Reset archive since this is a new index
self.archive = None

# Close existing ANN, if necessary
if self.ann:
self.ann.close()

# Initialize ANN, will be created after index transformations complete
self.ann = None

@@ -890,6 +894,10 @@ def createann(self):
new ANN, if enabled in config
"""

# Free existing resources
if self.ann:
self.ann.close()

return ANNFactory.create(self.config) if self.config.get("path") or self.defaultallowed() else None

def createdatabase(self):
@@ -900,7 +908,7 @@ def createdatabase(self):
new database, if enabled in config
"""

# Free existing database resources
# Free existing resources
if self.database:
self.database.close()

@@ -922,6 +930,10 @@ def creategraph(self):
new graph, if enabled in config
"""

# Free existing resources
if self.graph:
self.graph.close()

if "graph" in self.config:
# Get or create graph configuration
config = self.config["graph"] if "graph" in self.config else {}
@@ -954,6 +966,10 @@ def createindexes(self):
list of subindexes
"""

# Free existing resources
if self.indexes:
self.indexes.close()

# Load subindexes
if "indexes" in self.config:
indexes = {}
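
In embeddings/base.py, initindex() and the create* helpers now close any existing ANN, graph and subindex components before building replacements, so rebuilding an index on a live instance no longer leaks database resources. A hedged sketch of that path; config values below are illustrative:

from txtai import Embeddings

embeddings = Embeddings(
    path="sentence-transformers/all-MiniLM-L6-v2",                        # assumed example model
    backend="pgvector",
    pgvector={"url": "postgresql+psycopg2://user:pass@localhost/txtai"},  # assumed URL
)

# First build
embeddings.index(["first build"])

# Rebuild on the same instance; the previous ANN is now closed before a new
# one is created, rather than leaving its engine and session open
embeddings.index(["second build"])

embeddings.save("/tmp/index")
embeddings.close()
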
24 changes: 15 additions & 9 deletions src/python/txtai/scoring/pgtext.py
@@ -6,7 +6,7 @@

# Conditional import
try:
from sqlalchemy import create_engine, desc, delete, text
from sqlalchemy import create_engine, desc, delete, func, text
from sqlalchemy import Column, Computed, Index, Integer, MetaData, StaticPool, Table, Text
from sqlalchemy.dialects.postgresql import TSVECTOR
from sqlalchemy.orm import Session
@@ -31,7 +31,7 @@ def __init__(self, config=None):
raise ImportError('PGText is not available - install "scoring" extra to enable')

# Database connection
self.engine, self.database, self.table = None, None, None
self.engine, self.database, self.connection, self.table = None, None, None, None

# Language
self.language = self.config.get("language", "english")
@@ -85,23 +85,28 @@ def batchsearch(self, queries, limit=3, threads=True):
return [self.search(query, limit) for query in queries]

def count(self):
return self.database.query(self.table.c["indexid"]).count()
# pylint: disable=E1102
return self.database.query(func.count(self.table.c["indexid"])).scalar()

def load(self, path):
# Reset database to original checkpoint
if self.database:
self.database.rollback()
self.connection.rollback()

# Initialize tables
self.initialize()

def save(self, path):
# Commit session and connection
if self.database:
self.database.commit()
self.connection.commit()

def close(self):
if self.database:
self.database.close()
self.engine.dispose()

def hasterms(self):
return True
@@ -118,9 +123,10 @@ def initialize(self, recreate=False):
"""

if not self.database:
# Create engine and session
# Create engine, connection and session
self.engine = create_engine(self.config.get("url", os.environ.get("SCORING_URL")), poolclass=StaticPool, echo=False)
self.database = Session(self.engine)
self.connection = self.engine.connect()
self.database = Session(self.connection)

# Set default schema, if necessary
schema = self.config.get("schema")
@@ -153,12 +159,12 @@

# Drop and recreate table
if recreate:
self.table.drop(self.engine, checkfirst=True)
index.drop(self.engine, checkfirst=True)
self.table.drop(self.connection, checkfirst=True)
index.drop(self.connection, checkfirst=True)

# Create table and index
self.table.create(self.engine, checkfirst=True)
index.create(self.engine, checkfirst=True)
self.table.create(self.connection, checkfirst=True)
index.create(self.connection, checkfirst=True)

def sqldialect(self, sql, parameters=None):
"""
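
The PGText scoring component gets the same treatment: a dedicated engine, connection and session, commits on both in save(), and a close() that disposes of the engine. A hedged configuration sketch; the keyword/scoring keys and SCORING_URL value are assumptions for illustration:

import os

from txtai import Embeddings

# Keyword index scored inside PostgreSQL; URL is illustrative only
os.environ["SCORING_URL"] = "postgresql+psycopg2://user:pass@localhost/txtai"

embeddings = Embeddings(keyword=True, scoring={"method": "pgtext"})
embeddings.index(["full-text search scored inside PostgreSQL"])
embeddings.save("/tmp/index")

# close() is expected to release the scoring engine and session via the
# close() method added above
embeddings.close()
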
4 changes: 4 additions & 0 deletions test/python/testann.py
@@ -350,7 +350,11 @@ def save(self, name, params=None):
# Generate temp file path
index = os.path.join(tempfile.gettempdir(), "ann")

# Save and close index
model.save(index)
model.close()

# Reload index
model.load(index)

return model
