Skip to content

Commit

Permalink
Allow passing context to DruidDbApiHook (apache#34603)
Browse files Browse the repository at this point in the history
Druid's SQL API endpoint can accept context param to allow use of
various query functionality [described in documentation](https://druid.apache.org/docs/latest/querying/sql-query-context/).
This change enables passing context when using `DruidDbApiHook`.
  • Loading branch information
saulius authored Oct 3, 2023
1 parent 3064812 commit 89df63b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
9 changes: 9 additions & 0 deletions airflow/providers/apache/druid/hooks/druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ class DruidDbApiHook(DbApiHook):
This hook is purely for users to query druid broker.
For ingestion, please use druidHook.
:param context: Optional query context parameters to pass to the SQL endpoint.
Example: ``{"sqlFinalizeOuterSketches": True}``
See: https://druid.apache.org/docs/latest/querying/sql-query-context/
"""

conn_name_attr = "druid_broker_conn_id"
Expand All @@ -164,6 +168,10 @@ class DruidDbApiHook(DbApiHook):
hook_name = "Druid"
supports_autocommit = False

def __init__(self, context: dict | None = None, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.context = context or {}

def get_conn(self) -> connect:
"""Establish a connection to druid broker."""
conn = self.get_connection(getattr(self, self.conn_name_attr))
Expand All @@ -174,6 +182,7 @@ def get_conn(self) -> connect:
scheme=conn.extra_dejson.get("schema", "http"),
user=conn.login,
password=conn.password,
context=self.context,
)
self.log.info("Get the connection to druid broker on %s using user %s", conn.host, conn.login)
return druid_broker_conn
Expand Down
32 changes: 32 additions & 0 deletions tests/providers/apache/druid/hooks/test_druid.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,38 @@ def get_connection(self, conn_id):

self.db_hook = TestDruidDBApiHook

@patch("airflow.providers.apache.druid.hooks.druid.DruidDbApiHook.get_connection")
@patch("airflow.providers.apache.druid.hooks.druid.connect")
@pytest.mark.parametrize(
("specified_context", "passed_context"),
[
(None, {}),
({"query_origin": "airflow"}, {"query_origin": "airflow"}),
],
)
def test_get_conn_with_context(
self, mock_connect, mock_get_connection, specified_context, passed_context
):
get_conn_value = MagicMock()
get_conn_value.host = "test_host"
get_conn_value.conn_type = "https"
get_conn_value.login = "test_login"
get_conn_value.password = "test_password"
get_conn_value.port = 10000
get_conn_value.extra_dejson = {"endpoint": "/test/endpoint", "schema": "https"}
mock_get_connection.return_value = get_conn_value
hook = DruidDbApiHook(context=specified_context)
hook.get_conn()
mock_connect.assert_called_with(
host="test_host",
port=10000,
path="/test/endpoint",
scheme="https",
user="test_login",
password="test_password",
context=passed_context,
)

def test_get_uri(self):
db_hook = self.db_hook()
assert "druid://host:1000/druid/v2/sql" == db_hook.get_uri()
Expand Down

0 comments on commit 89df63b

Please sign in to comment.