-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix import statements in llm_query.py and test_e_llm_query.py
- Loading branch information
1 parent
44932df
commit be2b805
Showing
3 changed files
with
241 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,237 @@ | ||
"""SQL wrapper around SQLDatabase in langchain.""" | ||
from typing import Any, Dict, Iterable, List, Optional, Tuple | ||
|
||
from sqlalchemy import MetaData, create_engine, insert, inspect, text | ||
from sqlalchemy.engine import Engine | ||
from sqlalchemy.exc import OperationalError, ProgrammingError | ||
|
||
|
||
class SQLDatabase:
    """SQL Database.

    This class provides a wrapper around the SQLAlchemy engine to interact with a SQL
    database.
    It provides methods to execute SQL commands, insert data into tables, and retrieve
    information about the database schema.
    It also supports optional features such as including or excluding specific tables,
    sampling rows for table info,
    including indexes in table info, and supporting views.

    Based on langchain SQLDatabase.
    https://github.com/langchain-ai/langchain/blob/e355606b1100097665207ca259de6dc548d44c78/libs/langchain/langchain/utilities/sql_database.py#L39

    Args:
        engine (Engine): The SQLAlchemy engine instance to use for database operations.
        schema (Optional[str]): The name of the schema to use, if any.
        metadata (Optional[MetaData]): The metadata instance to use, if any.
        ignore_tables (Optional[List[str]]): List of table names to ignore. If set,
            include_tables must be None.
        include_tables (Optional[List[str]]): List of table names to include. If set,
            ignore_tables must be None.
        sample_rows_in_table_info (int): The number of sample rows to include in table
            info.
        indexes_in_table_info (bool): Whether to include indexes in table info.
        custom_table_info (Optional[dict]): Custom table info to use.
        view_support (bool): Whether to support views.
        max_string_length (int): The maximum string length to use.
    """

    def __init__(
        self,
        engine: Engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        indexes_in_table_info: bool = False,
        custom_table_info: Optional[dict] = None,
        view_support: bool = False,
        max_string_length: int = 300,
    ):
        """Bind the engine and load the table bookkeeping for it.

        Raises:
            ValueError: If both include_tables and ignore_tables are given, or if
                either names a table that is not in the database.
            TypeError: If sample_rows_in_table_info is not an int, or a truthy
                custom_table_info is not a dict.
        """
        self._engine = engine
        self._schema = schema
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")
        # ``setup`` is a coroutine function and cannot be awaited from
        # ``__init__``; calling it here would only store an un-awaited coroutine
        # and leave the instance unconfigured. Use the synchronous helper.
        self._inspector = self._configure(
            engine,
            schema,
            metadata,
            ignore_tables,
            include_tables,
            sample_rows_in_table_info,
            indexes_in_table_info,
            custom_table_info,
            view_support,
            max_string_length,
        )

    @property
    def engine(self) -> Engine:
        """Return SQL Alchemy engine."""
        return self._engine

    @property
    def metadata_obj(self) -> MetaData:
        """Return SQL Alchemy metadata."""
        return self._metadata

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> "SQLDatabase":
        """Construct a SQLDatabase from a database URI.

        ``engine_args`` are forwarded to ``create_engine``; ``kwargs`` are
        forwarded to the class constructor.
        """
        _engine_args = engine_args or {}
        return cls(create_engine(database_uri, **_engine_args), **kwargs)

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return self._engine.dialect.name

    async def setup(
        self,
        engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        indexes_in_table_info: bool = False,
        custom_table_info: Optional[dict] = None,
        view_support: bool = False,
        max_string_length: int = 300,
    ) -> Any:
        """Reload the table bookkeeping and return the inspector that was used.

        Kept as a coroutine function so existing ``await db.setup(...)`` call
        sites keep working; the work itself is synchronous. The returned
        inspector is also stored on the instance.

        Raises:
            ValueError: If include_tables or ignore_tables names a table that is
                not in the database.
            TypeError: If sample_rows_in_table_info is not an int, or a truthy
                custom_table_info is not a dict.
        """
        inspector = self._configure(
            engine,
            schema,
            metadata,
            ignore_tables,
            include_tables,
            sample_rows_in_table_info,
            indexes_in_table_info,
            custom_table_info,
            view_support,
            max_string_length,
        )
        self._inspector = inspector
        return inspector

    def _configure(
        self,
        engine,
        schema,
        metadata,
        ignore_tables,
        include_tables,
        sample_rows_in_table_info,
        indexes_in_table_info,
        custom_table_info,
        view_support,
        max_string_length,
    ) -> Any:
        """Inspect the database, validate the options and reflect the metadata."""
        inspector = inspect(engine)

        # including view support by adding the views as well as tables to the all
        # tables list if view_support is True
        table_names = list(inspector.get_table_names(schema=schema))
        if view_support:
            table_names.extend(inspector.get_view_names(schema=schema))
        self._all_tables = set(table_names)

        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )
        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )
        usable_tables = self.get_usable_table_names()
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")

        self._sample_rows_in_table_info = sample_rows_in_table_info
        self._indexes_in_table_info = indexes_in_table_info

        self._custom_table_info = custom_table_info
        if self._custom_table_info:
            if not isinstance(self._custom_table_info, dict):
                raise TypeError(
                    "table_info must be a dictionary with table names as keys and the "
                    "desired table info as values"
                )
            # only keep the tables that are also present in the database
            intersection = set(self._custom_table_info).intersection(self._all_tables)
            self._custom_table_info = {
                table: info
                for table, info in self._custom_table_info.items()
                if table in intersection
            }

        self._max_string_length = max_string_length

        self._metadata = metadata or MetaData()
        # including view support if view_support = true
        # NOTE: reflection goes through the engine/schema bound in __init__,
        # which are the same objects as the arguments on the constructor path.
        self._metadata.reflect(
            views=view_support,
            bind=self._engine,
            only=list(self._usable_tables),
            schema=self._schema,
        )
        return inspector

    async def get_inspector(self, engine: Engine) -> Any:
        """Return a SQLAlchemy inspector for ``engine``.

        The inspector is built directly from the synchronous engine. It must not
        be created inside a connection block and returned afterwards, because it
        would then be bound to a connection that has already been closed.
        """
        return inspect(engine)

    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available.

        Returns the sorted include list when one was given, otherwise every
        known table minus the ignored ones, sorted.
        """
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    async def get_table_columns(self, table_name: str) -> List[Any]:
        """Get table columns."""
        # Inspector.get_columns is a plain synchronous call; there is nothing
        # to await on its result.
        return self._inspector.get_columns(table_name)

    async def get_single_table_info(self, table_name: str) -> str:
        """Get table info for a single table.

        Returns a one-sentence description listing each column as
        ``name (type)`` (plus its comment when present) and each foreign key as
        ``constrained_columns -> referred_table.referred_columns``.
        """
        # same logic as table_info, but with specific table names
        template = (
            "Table '{table_name}' has columns: {columns}, "
            "and foreign keys: {foreign_keys}."
        )
        columns = []
        for column in self._inspector.get_columns(table_name):
            if column.get("comment"):
                columns.append(
                    f"{column['name']} ({column['type']!s}): "
                    f"'{column.get('comment')}'"
                )
            else:
                columns.append(f"{column['name']} ({column['type']!s})")

        column_str = ", ".join(columns)
        foreign_keys = [
            f"{foreign_key['constrained_columns']} -> "
            f"{foreign_key['referred_table']}.{foreign_key['referred_columns']}"
            for foreign_key in self._inspector.get_foreign_keys(table_name)
        ]
        foreign_key_str = ", ".join(foreign_keys)
        return template.format(
            table_name=table_name, columns=column_str, foreign_keys=foreign_key_str
        )

    def insert_into_table(self, table_name: str, data: dict) -> None:
        """Insert one row, given as a column-name -> value dict, into a table.

        The table is looked up in the reflected metadata, so an unknown
        ``table_name`` raises ``KeyError``. The insert runs in its own
        transaction.
        """
        table = self._metadata.tables[table_name]
        stmt = insert(table).values(**data)
        with self._engine.begin() as connection:
            connection.execute(stmt)

    def run_sql(self, command: str) -> Tuple[str, Dict]:
        """Execute a SQL statement and return its results.

        If the statement returns rows, the result is ``(str(rows), payload)``
        where ``payload`` holds the rows under ``"result"`` and the column names
        under ``"col_keys"``.
        If the statement returns no rows, ``("", {})`` is returned.

        Raises:
            NotImplementedError: If the database rejects the statement with a
                ProgrammingError or OperationalError.
        """
        with self._engine.begin() as connection:
            try:
                cursor = connection.execute(text(command))
            except (ProgrammingError, OperationalError) as exc:
                raise NotImplementedError(
                    f"Statement {command!r} is invalid SQL."
                ) from exc
            if cursor.returns_rows:
                result = cursor.fetchall()
                return str(result), {"result": result, "col_keys": list(cursor.keys())}
        return "", {}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters