Skip to content

Commit

Permalink
Change schema statements to use bind parameters #829
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmezzetti committed Dec 5, 2024
1 parent 26bd35d commit 89410d3
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 17 deletions.
18 changes: 10 additions & 8 deletions src/python/txtai/ann/pgvector.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,6 @@ def __init__(self, config):
self.sqldialect(text("CREATE EXTENSION IF NOT EXISTS vector"))
self.database.commit()

# Set default schema, if necessary
schema = self.setting("schema")
if schema:
self.sqldialect(CreateSchema(schema, if_not_exists=True))
self.sqldialect(text(f"SET search_path TO {schema},public"))

# Table instance
self.table = None

Expand Down Expand Up @@ -110,6 +104,12 @@ def initialize(self, recreate=False):
recreate: Recreates the database tables if True
"""

# Set default schema, if necessary
schema = self.setting("schema")
if schema:
self.sqldialect(CreateSchema(schema, if_not_exists=True))
self.sqldialect(text("SET search_path TO :schema,public"), {"schema": schema})

# Table name
table = self.setting("table", "vectors")

Expand Down Expand Up @@ -149,12 +149,14 @@ def settings(self):

return {"m": self.setting("m", 16), "ef_construction": self.setting("efconstruction", 200)}

def sqldialect(self, sql):
def sqldialect(self, sql, parameters=None):
"""
Executes a SQL statement based on the current SQL dialect.
Args:
sql: SQL to execute
parameters: optional bind parameters
"""

self.database.execute(sql if self.engine.dialect.name == "postgresql" else text("SELECT 1"))
args = (sql, parameters) if self.engine.dialect.name == "postgresql" else (text("SELECT 1"),)
self.database.execute(*args)
8 changes: 5 additions & 3 deletions src/python/txtai/database/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def connect(self, path=None):
schema = self.config.get("schema")
if schema:
self.sqldialect(database, engine, CreateSchema(schema, if_not_exists=True))
self.sqldialect(database, engine, textsql(f"SET search_path TO {schema}"))
self.sqldialect(database, engine, textsql("SET search_path TO :schema"), {"schema": schema})

return database

Expand All @@ -135,17 +135,19 @@ def rows(self):
def addfunctions(self):
return

def sqldialect(self, database, engine, sql):
def sqldialect(self, database, engine, sql, parameters=None):
"""
Executes a SQL statement based on the current SQL dialect.
Args:
database: current database
engine: database engine
sql: SQL to execute
parameters: optional bind parameters
"""

database.execute(sql if engine.dialect.name == "postgresql" else textsql("SELECT 1"))
args = (sql, parameters) if engine.dialect.name == "postgresql" else (textsql("SELECT 1"),)
database.execute(*args)


class Cursor:
Expand Down
17 changes: 11 additions & 6 deletions src/python/txtai/scoring/pgtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,19 @@ def count(self):

def load(self, path):
# Reset database to original checkpoint
self.database.rollback()
if self.database:
self.database.rollback()

# Initialize tables
self.initialize()

def save(self, path):
self.database.commit()
if self.database:
self.database.commit()

def close(self):
self.database.close()
if self.database:
self.database.close()

def hasterms(self):
return True
Expand All @@ -123,7 +126,7 @@ def initialize(self, recreate=False):
schema = self.config.get("schema")
if schema:
self.sqldialect(CreateSchema(schema, if_not_exists=True))
self.sqldialect(text(f"SET search_path TO {schema}"))
self.sqldialect(text("SET search_path TO :schema"), {"schema": schema})

# Table name
table = self.config.get("table", "scoring")
Expand Down Expand Up @@ -157,12 +160,14 @@ def initialize(self, recreate=False):
self.table.create(self.engine, checkfirst=True)
index.create(self.engine, checkfirst=True)

def sqldialect(self, sql):
def sqldialect(self, sql, parameters=None):
"""
Executes a SQL statement based on the current SQL dialect.
Args:
sql: SQL to execute
parameters: optional bind parameters
"""

self.database.execute(sql if self.engine.dialect.name == "postgresql" else text("SELECT 1"))
args = (sql, parameters) if self.engine.dialect.name == "postgresql" else (text("SELECT 1"),)
self.database.execute(*args)

0 comments on commit 89410d3

Please sign in to comment.