add persisted database
buremba committed Aug 6, 2024
1 parent 63dca25 commit d966bd3
Showing 7 changed files with 42 additions and 50 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -106,4 +106,5 @@ celerybeat.pid
.certs/*
.rill/*
venv/*
-.venv/*
+.venv/*
+.db/*
6 changes: 3 additions & 3 deletions README.md
@@ -147,11 +147,11 @@ By default, UniverSQL uses [your default Azure tenant](https://learn.microsoft.c

## Compute Strategies

-`auto` (default): Best effort to run the query locally, with the fallback option to run them on Snowflake.
+`auto` (default): Best effort to run the query locally, with the fallback option to run it on Snowflake.

-`local`: If the query requires a running warehouse on Snowflake, fails the query. Otherwise runs the query locally.
+`local`: If the query requires a running warehouse on Snowflake, fails the query. Otherwise, runs the query locally.

-`snowflake`: Runs the queries directly on Snowflake, use UniverSQL as a passthrough. Useful for rewriting queries on the fly, blocking queries based on conditions or re-routing warehouses based on custom logic.
+`snowflake`: Runs the queries directly on Snowflake and doesn't change the query. Useful for rewriting and blocking queries on the fly based on specific rules, or re-routing warehouses based on custom logic.

# Limitations

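The three strategies in the README hunk above boil down to a small routing decision. A minimal sketch under stated assumptions — `Compute` and `route` here are illustrative stand-ins, not UniverSQL's actual API:

```python
from enum import Enum

class Compute(Enum):
    AUTO = "auto"
    LOCAL = "local"
    SNOWFLAKE = "snowflake"

def route(compute: Compute, can_run_locally: bool) -> str:
    if compute == Compute.SNOWFLAKE:
        return "snowflake"  # passthrough: queries go straight to Snowflake
    if can_run_locally:
        return "duckdb"     # both `auto` and `local` prefer local execution
    if compute == Compute.LOCAL:
        raise RuntimeError("query needs a Snowflake warehouse; `local` refuses to fall back")
    return "snowflake"      # `auto` falls back to Snowflake
```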
7 changes: 0 additions & 7 deletions tests/sqlglot_tests.py
@@ -10,13 +10,6 @@
from universql.util import time_me
from universql.warehouse.duckdb import fix_snowflake_to_duckdb_types

cache = "/Users/bkabak/.universql/cache"
@time_me
def test_cache_size():
print(sum(f.stat().st_size for f in Path(cache).glob('**/*') if f.is_file()))

test_cache_size()
test_cache_size()
# queries = sqlglot.parse("""
# SET tables = (SHOW TABLES);
#
19 changes: 6 additions & 13 deletions universql/catalog/snow/show_iceberg_tables.py
@@ -19,8 +19,9 @@
MAX_LIMIT = 10000

logging.basicConfig(level=logging.INFO)
-cloud_logger = logging.getLogger("❄️(cloud services)")
+logger = logging.getLogger("❄️")

+queries_that_doesnt_need_warehouse = ["show"]

class SnowflakeIcebergCursor(Cursor):
    def __init__(self, query_id, cursor: SnowflakeCursor):
@@ -33,17 +34,12 @@ def execute(self, asts: typing.Optional[List[sqlglot.exp.Expression]], raw_query
        else:
            compiled_sql = ";".join([ast.sql(dialect="snowflake", pretty=True) for ast in asts])
        try:
-            self.cursor.execute(compiled_sql)
-            emoji = ""
-            if all(ast.key == 'show' for ast in asts):
-                logger = cloud_logger
-            else:
-                emoji = "💰"
-                logger = logging.getLogger(f"❄️{self.cursor.connection.warehouse}")
-
+            run_on_warehouse = not all(ast.name in queries_that_doesnt_need_warehouse for ast in asts)
+            emoji = "☁️(user cloud services)" if not run_on_warehouse else "💰(used warehouse)"
            logger.info(f"[{self.query_id}] Running on Snowflake.. {emoji}")
+            self.cursor.execute(compiled_sql)
        except DatabaseError as e:
-            message = f"Unable to run Snowflake query: \n {compiled_sql} \n {e.msg}"
+            message = f"Unable to run Snowflake query {e.sfqid}: \n {compiled_sql} \n {e.msg}"
            raise SnowflakeError(e.sfqid, message, e.sqlstate)

    def close(self):
@@ -202,9 +198,6 @@ def get_table_references(self, cursor: duckdb.DuckDBPyConnection, tables: List[s
        cur.execute(final_query, values)

        result = cur.fetchall()
-        used_tables = ",".join(set(table.sql() for table in tables))
-        logging.getLogger("❄️cloud").info(
-            f"[{self.query_id}] Executed metadata query to get Iceberg table locations for tables {used_tables}")
        return {table: SnowflakeShowIcebergTables._get_ref(json.loads(result[0][idx])) for idx, table in
                enumerate(tables)}

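The `run_on_warehouse` flag introduced in this file is just a membership test over the parsed statements. A sketch with hypothetical statement kinds standing in for the sqlglot AST nodes the real code inspects:

```python
queries_that_doesnt_need_warehouse = ["show"]

def run_on_warehouse(statement_kinds: list[str]) -> bool:
    # A batch needs a warehouse unless every statement is metadata-only.
    return not all(kind in queries_that_doesnt_need_warehouse for kind in statement_kinds)

print(run_on_warehouse(["show"]))            # False -> cloud services only ☁️
print(run_on_warehouse(["show", "select"]))  # True  -> warehouse required 💰
```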
2 changes: 2 additions & 0 deletions universql/main.py
@@ -57,6 +57,8 @@ def cli():
              type=str)
@click.option('--max-cache-size', type=str, default=DEFAULTS["max_cache_size"],
              help='DuckDB maximum cache used in local disk (default: 80% of total available disk)')
+@click.option('--database-path', type=click.Path(exists=False, writable=True), default=":memory:",
+              help='For persistent storage, provide a path to the DuckDB database file (default: :memory:)')
def snowflake(host, port, ssl_keyfile, ssl_certfile, account, catalog, compute, **kwargs):
    context__params = click.get_current_context().params
    auto_catalog_mode = catalog is None
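The new `--database-path` value is handed to `duckdb.connect` (see universql/warehouse/duckdb.py below), so standard DuckDB persistence semantics should apply. A quick sketch; the file name is hypothetical:

```python
import duckdb

ephemeral = duckdb.connect(":memory:")       # the default: state vanishes with the process
persistent = duckdb.connect("universql.db")  # hypothetical path: tables survive restarts

persistent.execute("CREATE TABLE IF NOT EXISTS greetings AS SELECT 'hi' AS word")
persistent.close()  # data is flushed to universql.db on disk
```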
23 changes: 11 additions & 12 deletions universql/server.py
@@ -48,14 +48,18 @@ async def login_request(request: Request) -> JSONResponse:
    client_environment = login_data.get('CLIENT_ENVIRONMENT')
    credentials = {key: client_environment[key] for key in ["schema", "warehouse", "role", "user", "database"] if
                   key in client_environment}
-    credentials['password'] = login_data.get('PASSWORD')
-
-    if "user" not in credentials:
+    if login_data.get('PASSWORD') is not None:
+        credentials['password'] = login_data.get('PASSWORD')
+    if "user" not in credentials and login_data.get('LOGIN_NAME') is not None:
        credentials["user"] = login_data.get("LOGIN_NAME")

+    params = request.query_params
    if "database" not in credentials:
-        credentials["database"] = request.query_params.get('databaseName')
-    if "warehouse" not in credentials:
-        credentials["warehouse"] = request.query_params.get('warehouse')
+        credentials["database"] = params.get('databaseName')
+    if "warehouse" not in credentials:
+        credentials["warehouse"] = params.get('warehouse')
+    credentials["role"] = params.get('roleName')

    token = str(uuid4())
    message = None
@@ -81,12 +85,7 @@ async def login_request(request: Request) -> JSONResponse:
"token": token,
"masterToken": token,
"parameters": parameters,
"sessionInfo": {
"databaseName": credentials.get('database'),
"schemaName": credentials.get('schema'),
"warehouseName": credentials.get('warehouse'),
"roleName": credentials.get('role')
},
"sessionInfo": {f'{k}Name': v for k, v in credentials.items()},
"idToken": None,
"idTokenValidityInSeconds": 0,
"responseData": None,
@@ -100,7 +99,7 @@ async def login_request(request: Request) -> JSONResponse:
"masterValidityInSeconds": 14400,
"displayUserName": "",
"serverVersion": "duck",
"firstLogin": True,
"firstLogin": False,
"remMeToken": None,
"remMeValidityInSeconds": 0,
"healthCheckInterval": 45,
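The comprehension that replaces the hand-written `sessionInfo` literal derives each response key from the credential name. A sketch with a hypothetical credentials dict; note that, unlike the old literal, it also surfaces any extra entries (such as `user`) collected during login:

```python
credentials = {"database": "MY_DB", "schema": "PUBLIC",
               "warehouse": "COMPUTE_WH", "role": "SYSADMIN"}
session_info = {f'{k}Name': v for k, v in credentials.items()}
# {'databaseName': 'MY_DB', 'schemaName': 'PUBLIC',
#  'warehouseName': 'COMPUTE_WH', 'roleName': 'SYSADMIN'}
```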
32 changes: 18 additions & 14 deletions universql/warehouse/duckdb.py
@@ -16,15 +16,14 @@
from sqlglot.optimizer.simplify import simplify

from universql.catalog import get_catalog
-from universql.catalog.snow.show_iceberg_tables import cloud_logger
+from universql.catalog.snow.show_iceberg_tables import logger as cloud_logger, queries_that_doesnt_need_warehouse
from universql.lake.cloud import s3, gcs
from universql.util import get_columns_for_duckdb, SnowflakeError, Compute, Catalog, get_friendly_time_since, \
    prepend_to_lines

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("🐥")

-queries_that_doesnt_need_warehouse = ["show"]


class UniverSQLSession:
@@ -35,7 +34,8 @@ def __init__(self, context, token, credentials: dict, session_parameters: dict):
        self.token = token
        self.catalog = get_catalog(context, self.token,
                                   self.credentials)
-        self.duckdb = duckdb.connect(read_only=False, config={
+        duckdb_path = context.get('database_path')
+        self.duckdb = duckdb.connect(duckdb_path, config={
            'max_memory': context.get('max_memory'),
            'temp_directory': os.path.join(context.get('cache_directory'), "duckdb-staging"),
            'max_temp_directory_size': context.get('max_cache_size'),
@@ -94,15 +94,14 @@ def _do_query(self, raw_query: str) -> (str, List, pyarrow.Table):
        queries = None

        should_run_locally = self.compute != Compute.SNOWFLAKE.value
-        can_run_locally = queries is not None
-        run_snowflake_already = False
+        last_compute = Compute.LOCAL if queries is not None else Compute.SNOWFLAKE

-        if can_run_locally and should_run_locally:
+        if last_compute == Compute.LOCAL and should_run_locally:
            for ast in queries:
-                if ast.key in queries_that_doesnt_need_warehouse and self.context.get(
+                if ast.name in queries_that_doesnt_need_warehouse and self.context.get(
                        'catalog') == Catalog.SNOWFLAKE.value:
                    self.do_snowflake_query(queries, raw_query, start_time, local_error_message)
-                    run_snowflake_already = True
+                    last_compute = Compute.SNOWFLAKE
                else:
                    tables = list(ast.find_all(sqlglot.exp.Table))
@@ -116,30 +115,35 @@ def _do_query(self, raw_query: str) -> (str, List, pyarrow.Table):
                    transformed_ast = self.sync_duckdb_catalog(locations,
                                                               simplify(ast)) if locations is not None else None
                    if transformed_ast is None:
-                        can_run_locally = False
+                        last_compute = None
                        break

                    sql = transformed_ast.sql(dialect="duckdb", pretty=True)
                    try:
-                        self.duckdb_emulator.execute(sql)
                        logger.info(f"[{self.token}] executing DuckDB query:\n{prepend_to_lines(sql)}")
+                        self.duckdb_emulator.execute(sql)
+                        last_compute = Compute.LOCAL
                    except duckdb.Error as e:
                        local_error_message = f"Unable to run the parse locally on DuckDB. {e.args}"
-                        can_run_locally = False
+                        last_compute = None
                        break
                    except DatabaseError as e:
                        local_error_message = f"Unable to run the query locally on DuckDB. {e.msg}"
-                        can_run_locally = False
+                        last_compute = None
                        break

-        if can_run_locally and not run_snowflake_already and should_run_locally:
+        if last_compute == Compute.SNOWFLAKE:
+            return self.get_snowflake_result()
+        elif last_compute == Compute.LOCAL:
            logger.info(f"[{self.token}] Run locally 🚀 ({get_friendly_time_since(start_time)})")
            return self.get_duckdb_result()
-        else:
+        elif last_compute is None:
            if local_error_message is not None:
                logger.error(f"[{self.token}] {local_error_message}")
            self.do_snowflake_query(queries, raw_query, start_time, local_error_message)
            return self.get_snowflake_result()
+        else:
+            raise SnowflakeError(self.token, f"Unsupported compute type. {last_compute}")

    def do_snowflake_query(self, queries, raw_query, start_time, local_error_message):
        try:
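This refactor collapses the old `can_run_locally`/`run_snowflake_already` booleans into a single `last_compute` tri-state. A simplified mirror of the new dispatch at the end of `_do_query`, with strings standing in for the real cursor results:

```python
from enum import Enum
from typing import Optional

class Compute(Enum):
    LOCAL = "local"
    SNOWFLAKE = "snowflake"

def finish(last_compute: Optional[Compute]) -> str:
    if last_compute == Compute.SNOWFLAKE:
        return "snowflake result"    # some statement already ran remotely (e.g. SHOW)
    if last_compute == Compute.LOCAL:
        return "duckdb result"       # every statement ran locally
    if last_compute is None:
        return "retry on snowflake"  # the local attempt failed; fall back
    raise ValueError(f"Unsupported compute type. {last_compute}")
```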
