Skip to content

Commit

Permalink
feat(DB engine spec): get_catalog_names (apache#23447)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored Mar 22, 2023
1 parent fb270cb commit 8588f81
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 0 deletions.
22 changes: 22 additions & 0 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,14 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# a custom `adjust_engine_params` method.
supports_dynamic_schema = False

# Does the DB support catalogs? A catalog here is a group of schemas, and has
# different names depending on the DB: BigQuery calles it a "project", Postgres calls
# it a "database", Trino calls it a "catalog", etc.
supports_catalog = False

# Can the catalog be changed on a per-query basis?
supports_dynamic_catalog = False

@classmethod
def supports_url(cls, url: URL) -> bool:
"""
Expand Down Expand Up @@ -1091,6 +1099,20 @@ def patch(cls) -> None:
TODO: Improve docstring and refactor implementation in Hive
"""

@classmethod
def get_catalog_names( # pylint: disable=unused-argument
cls,
database: Database,
inspector: Inspector,
) -> List[str]:
"""
Get all catalogs from database.
This needs to be implemented per database, since SQLAlchemy doesn't offer an
abstraction.
"""
return []

@classmethod
def get_schema_names(cls, inspector: Inspector) -> List[str]:
"""
Expand Down
22 changes: 22 additions & 0 deletions superset/db_engine_specs/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from flask_babel import gettext as __
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
from sqlalchemy.dialects.postgresql.base import PGInspector
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.types import Date, DateTime, String

Expand Down Expand Up @@ -291,6 +292,27 @@ def query_cost_formatter(
) -> List[Dict[str, str]]:
return [{k: str(v) for k, v in row.items()} for row in raw_cost]

@classmethod
def get_catalog_names(
cls,
database: "Database",
inspector: Inspector,
) -> List[str]:
"""
Return all catalogs.
In Postgres, a catalog is called a "database".
"""
return sorted(
catalog
for (catalog,) in inspector.bind.execute(
"""
SELECT datname FROM pg_database
WHERE datistemplate = false;
"""
).fetchall()
)

@classmethod
def get_table_names(
cls, database: "Database", inspector: PGInspector, schema: Optional[str]
Expand Down
18 changes: 18 additions & 0 deletions tests/integration_tests/db_engine_specs/postgres_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
from textwrap import dedent
from unittest import mock

from flask.ctx import AppContext
from sqlalchemy import column, literal_column
from sqlalchemy.dialects import postgresql

from superset.db_engine_specs import load_engine_specs
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
from superset.utils.database import get_example_database
from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec
from tests.integration_tests.fixtures.certificates import ssl_certificate
from tests.integration_tests.fixtures.database import default_db_extra
Expand Down Expand Up @@ -514,3 +516,19 @@ def test_base_parameters_mixin():
},
"required": ["database", "host", "port", "username"],
}


def test_get_catalog_names(app_context: AppContext) -> None:
"""
Test the ``get_catalog_names`` method.
"""
database = get_example_database()

if database.backend != "postgresql":
return

with database.get_inspector_with_context() as inspector:
assert PostgresEngineSpec.get_catalog_names(database, inspector) == [
"postgres",
"superset",
]

0 comments on commit 8588f81

Please sign in to comment.