Skip to content

Commit

Permalink
Fix import statements in llm_query.py and test_e_llm_query.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dermatologist committed Jan 5, 2024
1 parent 44932df commit be2b805
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 3 deletions.
237 changes: 237 additions & 0 deletions src/pyomop/llm_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
"""SQL wrapper around SQLDatabase in langchain."""
from typing import Any, Dict, Iterable, List, Optional, Tuple

from sqlalchemy import MetaData, create_engine, insert, inspect, text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import OperationalError, ProgrammingError


class SQLDatabase:
    """SQL Database.

    This class provides a wrapper around the SQLAlchemy engine to interact with
    a SQL database. It provides methods to execute SQL commands, insert data
    into tables, and retrieve information about the database schema. It also
    supports optional features such as including or excluding specific tables,
    sampling rows for table info, including indexes in table info, and
    supporting views.

    Based on langchain SQLDatabase.
    https://github.com/langchain-ai/langchain/blob/e355606b1100097665207ca259de6dc548d44c78/libs/langchain/langchain/utilities/sql_database.py#L39

    Engine flavours:
        * With a synchronous ``Engine`` the schema is inspected and reflected
          inside ``__init__``; the instance is ready as soon as it is built.
        * With an ``AsyncEngine`` a constructor cannot await, so the schema
          work is deferred. Finish construction with
          ``db = await SQLDatabase(engine, ...)``.
        * ``insert_into_table`` and ``run_sql`` use the synchronous
          ``with engine.begin()`` protocol and therefore need a synchronous
          ``Engine``.

    Args:
        engine (Engine): The SQLAlchemy engine instance to use for database
            operations.
        schema (Optional[str]): The name of the schema to use, if any.
        metadata (Optional[MetaData]): The metadata instance to use, if any.
        ignore_tables (Optional[List[str]]): List of table names to ignore.
            If set, include_tables must be None.
        include_tables (Optional[List[str]]): List of table names to include.
            If set, ignore_tables must be None.
        sample_rows_in_table_info (int): The number of sample rows to include
            in table info.
        indexes_in_table_info (bool): Whether to include indexes in table info.
        custom_table_info (Optional[dict]): Custom table info to use.
        view_support (bool): Whether to support views.
        max_string_length (int): The maximum string length to use.
    """

    def __init__(
        self,
        engine: "Engine",
        schema: Optional[str] = None,
        metadata: Optional["MetaData"] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        indexes_in_table_info: bool = False,
        custom_table_info: Optional[dict] = None,
        view_support: bool = False,
        max_string_length: int = 300,
    ):
        """Wrap ``engine`` and load the schema information.

        Raises:
            ValueError: If both ``include_tables`` and ``ignore_tables`` are
                given, or (sync engines) a listed table is not in the database.
            TypeError: (sync engines) if ``sample_rows_in_table_info`` is not an
                int or ``custom_table_info`` is not a dict.
        """
        self._engine = engine
        self._schema = schema
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")

        options = (
            schema,
            metadata,
            ignore_tables,
            include_tables,
            sample_rows_in_table_info,
            indexes_in_table_info,
            custom_table_info,
            view_support,
            max_string_length,
        )
        self._inspector = None
        if self._is_async_engine(engine):
            # ``setup`` is a coroutine and cannot be awaited from a
            # constructor; remember the options until the instance is awaited.
            self._pending_setup = options
        else:
            self._pending_setup = None
            self._inspector = inspect(engine)
            self._configure(self._inspector, engine, *options)

    def __await__(self):
        """Allow ``db = await SQLDatabase(async_engine, ...)``."""
        return self._complete_setup().__await__()

    async def _complete_setup(self) -> "SQLDatabase":
        """Run the setup deferred by ``__init__`` (if any) and return self."""
        if self._pending_setup is not None:
            await self.setup(self._engine, *self._pending_setup)
            # Cleared only after success so a failed setup can be retried.
            self._pending_setup = None
        return self

    @staticmethod
    def _is_async_engine(engine: Any) -> bool:
        """Return True when ``engine`` looks like a SQLAlchemy ``AsyncEngine``."""
        # An AsyncEngine exposes the Engine it proxies as ``sync_engine``.
        return hasattr(engine, "sync_engine")

    @property
    def engine(self) -> "Engine":
        """Return SQL Alchemy engine."""
        return self._engine

    @property
    def metadata_obj(self) -> "MetaData":
        """Return SQL Alchemy metadata (available once setup has run)."""
        return self._metadata

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> "SQLDatabase":
        """Construct a SQLAlchemy engine from URI.

        ``engine_args`` are forwarded to ``create_engine``; ``kwargs`` are
        forwarded to the class constructor.
        """
        _engine_args = engine_args or {}
        return cls(create_engine(database_uri, **_engine_args), **kwargs)

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return self._engine.dialect.name

    async def setup(
        self,
        engine,
        schema: Optional[str] = None,
        metadata: Optional["MetaData"] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        indexes_in_table_info: bool = False,
        custom_table_info: Optional[dict] = None,
        view_support: bool = False,
        max_string_length: int = 300,
    ) -> Any:
        """Inspect ``engine`` and populate the table/metadata state.

        Returns:
            The inspector that is kept on the instance: a reusable
            ``Inspector`` for a synchronous engine, ``None`` for an async
            engine (see the comment below).

        Raises:
            ValueError: If a table in ``include_tables`` / ``ignore_tables``
                does not exist in the database.
            TypeError: If ``sample_rows_in_table_info`` is not an int or
                ``custom_table_info`` is not a dict.
        """
        options = (
            schema,
            metadata,
            ignore_tables,
            include_tables,
            sample_rows_in_table_info,
            indexes_in_table_info,
            custom_table_info,
            view_support,
            max_string_length,
        )
        if self._is_async_engine(engine):
            # All inspection and reflection must happen inside ``run_sync``
            # while the connection is still open.
            async with engine.connect() as conn:
                await conn.run_sync(
                    lambda sync_conn: self._configure(
                        inspect(sync_conn), sync_conn, *options
                    )
                )
            # That inspector was bound to the connection closed above, so it
            # is not kept; later lookups open a fresh one per call.
            self._inspector = None
        else:
            self._inspector = inspect(engine)
            self._configure(self._inspector, engine, *options)
        return self._inspector

    def _configure(
        self,
        inspector: Any,
        bind: Any,
        schema: Optional[str],
        metadata: Optional["MetaData"],
        ignore_tables: Optional[List[str]],
        include_tables: Optional[List[str]],
        sample_rows_in_table_info: int,
        indexes_in_table_info: bool,
        custom_table_info: Optional[dict],
        view_support: bool,
        max_string_length: int,
    ) -> None:
        """Synchronous core of ``setup``: validate options and reflect tables.

        ``inspector`` must be live for the duration of the call and ``bind`` is
        the engine or connection handed to ``MetaData.reflect``.
        """
        # Views are added to the table list only when view_support is True.
        table_names = list(inspector.get_table_names(schema=schema))
        if view_support:
            table_names.extend(inspector.get_view_names(schema=schema))
        self._all_tables = set(table_names)

        self._include_tables = set(include_tables) if include_tables else set()
        missing_tables = self._include_tables - self._all_tables
        if missing_tables:
            raise ValueError(
                f"include_tables {missing_tables} not found in database"
            )
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        missing_tables = self._ignore_tables - self._all_tables
        if missing_tables:
            raise ValueError(
                f"ignore_tables {missing_tables} not found in database"
            )
        usable_tables = self.get_usable_table_names()
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")

        self._sample_rows_in_table_info = sample_rows_in_table_info
        self._indexes_in_table_info = indexes_in_table_info

        self._custom_table_info = custom_table_info
        if self._custom_table_info:
            if not isinstance(self._custom_table_info, dict):
                raise TypeError(
                    "table_info must be a dictionary with table names as keys and the "
                    "desired table info as values"
                )
            # Only keep the tables that are also present in the database.
            self._custom_table_info = {
                table: info
                for table, info in self._custom_table_info.items()
                if table in self._all_tables
            }

        self._max_string_length = max_string_length

        self._metadata = metadata or MetaData()
        # Reflect with the same bind and schema that were just inspected.
        self._metadata.reflect(
            views=view_support,
            bind=bind,
            only=list(self._usable_tables),
            schema=schema,
        )

    async def get_inspector(self, engine: "Engine") -> Any:
        """Return an inspector for ``engine``.

        For an async engine the inspector is created on a connection that is
        closed again before this method returns, so prefer the methods of this
        class for actual schema lookups.
        """
        if self._is_async_engine(engine):
            async with engine.connect() as conn:
                return await conn.run_sync(inspect)
        return inspect(engine)

    async def _inspect_with(self, fn: Any) -> Any:
        """Call ``fn(inspector)`` with a live inspector and return its result."""
        if self._is_async_engine(self._engine):
            async with self._engine.connect() as conn:
                return await conn.run_sync(lambda sync_conn: fn(inspect(sync_conn)))
        return fn(self._inspector)

    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available.

        Returns the sorted ``include_tables`` when any were given, otherwise
        every known table minus ``ignore_tables``, sorted.
        """
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    async def get_table_columns(self, table_name: str) -> List[Any]:
        """Get table columns as reported by the inspector's ``get_columns``."""
        return await self._inspect_with(
            lambda inspector: inspector.get_columns(table_name)
        )

    async def get_single_table_info(self, table_name: str) -> str:
        """Get table info for a single table.

        Returns a one-sentence description listing each column as
        ``name (type)`` (followed by its comment when it has one) and each
        foreign key as ``constrained_columns -> referred_table.referred_columns``.
        """
        template = (
            "Table '{table_name}' has columns: {columns}, "
            "and foreign keys: {foreign_keys}."
        )

        def describe(inspector: Any) -> str:
            columns = []
            for column in inspector.get_columns(table_name):
                if column.get("comment"):
                    columns.append(
                        f"{column['name']} ({column['type']!s}): "
                        f"'{column.get('comment')}'"
                    )
                else:
                    columns.append(f"{column['name']} ({column['type']!s})")

            foreign_keys = [
                f"{foreign_key['constrained_columns']} -> "
                f"{foreign_key['referred_table']}.{foreign_key['referred_columns']}"
                for foreign_key in inspector.get_foreign_keys(table_name)
            ]
            return template.format(
                table_name=table_name,
                columns=", ".join(columns),
                foreign_keys=", ".join(foreign_keys),
            )

        return await self._inspect_with(describe)

    def insert_into_table(self, table_name: str, data: dict) -> None:
        """Insert one row, given as a column-name -> value dict, into a table.

        The table is looked up in the reflected metadata, so ``table_name``
        must be one of the usable tables (``KeyError`` otherwise).
        """
        table = self._metadata.tables[table_name]
        stmt = insert(table).values(**data)
        with self._engine.begin() as connection:
            connection.execute(stmt)

    def run_sql(self, command: str) -> Tuple[str, Dict]:
        """Execute a SQL statement and return a string representing the results.

        If the statement returns rows, the result is
        ``(str(rows), {"result": rows, "col_keys": [...]})``.
        If the statement returns no rows, the result is ``("", {})``.

        Raises:
            NotImplementedError: If the database rejects the statement with a
                ``ProgrammingError`` or ``OperationalError``.
        """
        with self._engine.begin() as connection:
            try:
                cursor = connection.execute(text(command))
            except (ProgrammingError, OperationalError) as exc:
                raise NotImplementedError(
                    f"Statement {command!r} is invalid SQL."
                ) from exc
            if cursor.returns_rows:
                result = cursor.fetchall()
                return str(result), {"result": result, "col_keys": list(cursor.keys())}
        return "", {}
3 changes: 2 additions & 1 deletion src/pyomop/llm_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
SQLTableSchema,
)
from llama_index import VectorStoreIndex
from llama_index import SQLDatabase, ServiceContext
from llama_index import ServiceContext
from typing import Any, Optional
from llama_index.prompts import BasePromptTemplate
from llama_index.objects.base import ObjectRetriever

from .llm_engine import SQLDatabase

class CdmLLMQuery(SQLTableRetrieverQueryEngine):
def __init__(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_e_llm_query.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import pytest
from llama_index.llms import Vertex
from llama_index import SQLDatabase
from src.pyomop.llm_engine import SQLDatabase

@staticmethod
def test_create_cohort(pyomop_fixture, metadata_fixture, capsys):
Expand Down Expand Up @@ -37,7 +37,7 @@ async def create_llm_query(pyomop_fixture,engine):
assert cohort.subject_id == 100

llm = Vertex()
sql_database = await SQLDatabase(engine, include_tables=[
sql_database = SQLDatabase(engine, include_tables=[
"cohort",
])
query_engine = CdmLLMQuery(sql_database, llm=llm)
Expand Down

0 comments on commit be2b805

Please sign in to comment.