Skip to content

Commit

Permalink
Pass X-Presto-Client-Info in presto hook (apache#22416)
Browse files Browse the repository at this point in the history
  • Loading branch information
pingzh authored Mar 24, 2022
1 parent b060416 commit 05b4409
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 1 deletion.
31 changes: 31 additions & 0 deletions airflow/providers/presto/hooks/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import os
import warnings
from typing import Any, Callable, Iterable, Optional, overload
Expand All @@ -27,6 +28,34 @@
from airflow.configuration import conf
from airflow.hooks.dbapi import DbApiHook
from airflow.models import Connection
from airflow.utils.operator_helpers import AIRFLOW_VAR_NAME_FORMAT_MAPPING

try:
from airflow.utils.operator_helpers import DEFAULT_FORMAT_PREFIX
except ImportError:
# DEFAULT_FORMAT_PREFIX normally comes from airflow.utils.operator_helpers.
# For the sake of provider backward compatibility, it is hardcoded here when the import fails:
# https://github.com/apache/airflow/pull/22416#issuecomment-1075531290
DEFAULT_FORMAT_PREFIX = 'airflow.ctx.'


def generate_presto_client_info() -> str:
    """Return a JSON string describing the currently running Airflow task.

    For every entry of ``AIRFLOW_VAR_NAME_FORMAT_MAPPING`` the environment variable named by
    its ``env_var_format`` is read (empty string when unset) and stored under the entry's
    ``default`` name with ``DEFAULT_FORMAT_PREFIX`` removed.

    :return: JSON object, keys sorted, with ``dag_id``, ``task_id``, ``execution_date``,
        ``try_number``, ``dag_run_id`` and ``dag_owner``. A value is ``''`` when the
        environment variable is unset or when the mapping has no entry for that key.
    """
    context_var = {
        format_map['default'].replace(DEFAULT_FORMAT_PREFIX, ''): os.environ.get(
            format_map['env_var_format'], ''
        )
        for format_map in AIRFLOW_VAR_NAME_FORMAT_MAPPING.values()
    }
    # The mapping comes from the installed Airflow core, which may be older than this
    # provider and lack some entries (e.g. try_number). Fall back to '' instead of raising
    # KeyError so that opening a connection never fails because of this header.
    reported_keys = ('dag_id', 'task_id', 'execution_date', 'try_number', 'dag_run_id', 'dag_owner')
    task_info = {key: context_var.get(key, '') for key in reported_keys}
    return json.dumps(task_info, sort_keys=True)


class PrestoException(Exception):
Expand Down Expand Up @@ -83,11 +112,13 @@ def get_conn(self) -> Connection:
ca_bundle=extra.get('kerberos__ca_bundle'),
)

http_headers = {"X-Presto-Client-Info": generate_presto_client_info()}
presto_conn = prestodb.dbapi.connect(
host=db.host,
port=db.port,
user=db.login,
source=db.extra_dejson.get('source', 'airflow'),
http_headers=http_headers,
http_scheme=db.extra_dejson.get('protocol', 'http'),
catalog=db.extra_dejson.get('catalog', 'hive'),
schema=db.schema,
Expand Down
73 changes: 72 additions & 1 deletion tests/providers/presto/hooks/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,31 @@

from airflow import AirflowException
from airflow.models import Connection
from airflow.providers.presto.hooks.presto import PrestoHook
from airflow.providers.presto.hooks.presto import PrestoHook, generate_presto_client_info


def test_generate_airflow_presto_client_info_header():
    """Client info JSON must mirror the Airflow context environment variables."""
    expected_payload = {
        "dag_id": "dag_id",
        "execution_date": "2022-01-01T00:00:00",
        "task_id": "task_id",
        "try_number": "1",
        "dag_run_id": "dag_run_id",
        "dag_owner": "dag_owner",
    }
    airflow_ctx_env = {
        'AIRFLOW_CTX_DAG_ID': 'dag_id',
        'AIRFLOW_CTX_EXECUTION_DATE': '2022-01-01T00:00:00',
        'AIRFLOW_CTX_TASK_ID': 'task_id',
        'AIRFLOW_CTX_TRY_NUMBER': '1',
        'AIRFLOW_CTX_DAG_RUN_ID': 'dag_run_id',
        'AIRFLOW_CTX_DAG_OWNER': 'dag_owner',
    }

    # Only the call needs the patched environment; the comparison can happen afterwards.
    with patch.dict('os.environ', airflow_ctx_env):
        generated = generate_presto_client_info()

    assert generated == json.dumps(expected_payload, sort_keys=True)


class TestPrestoHookConn(unittest.TestCase):
Expand All @@ -45,6 +69,7 @@ def test_get_conn_basic_auth(self, mock_get_connection, mock_connect, mock_basic
catalog='hive',
host='host',
port=None,
http_headers=mock.ANY,
http_scheme='http',
schema='hive',
source='airflow',
Expand Down Expand Up @@ -98,6 +123,7 @@ def test_get_conn_kerberos_auth(self, mock_get_connection, mock_connect, mock_au
catalog='hive',
host='host',
port=None,
http_headers=mock.ANY,
http_scheme='http',
schema='hive',
source='airflow',
Expand All @@ -118,6 +144,51 @@ def test_get_conn_kerberos_auth(self, mock_get_connection, mock_connect, mock_au
)
assert mock_connect.return_value == conn

@patch('airflow.providers.presto.hooks.presto.generate_presto_client_info')
@patch('airflow.providers.presto.hooks.presto.prestodb.auth.BasicAuthentication')
@patch('airflow.providers.presto.hooks.presto.prestodb.dbapi.connect')
@patch('airflow.providers.presto.hooks.presto.PrestoHook.get_connection')
def test_http_headers(
    self,
    mock_get_connection,
    mock_connect,
    mock_basic_auth,
    mocked_generate_airflow_presto_client_info_header,
):
    """The generated client info must reach ``connect`` as the X-Presto-Client-Info header."""
    # Canned client-info payload returned by the patched generator.
    client_info = json.dumps(
        {
            "dag_id": "dag-id",
            "execution_date": "2022-01-01T00:00:00",
            "task_id": "task-id",
            "try_number": "1",
            "dag_run_id": "dag-run-id",
            "dag_owner": "dag-owner",
        },
        sort_keys=True,
    )
    mocked_generate_airflow_presto_client_info_header.return_value = client_info
    mock_get_connection.return_value = Connection(
        login='login', password='password', host='host', schema='hive'
    )

    conn = PrestoHook().get_conn()

    # Keyword arguments the hook is expected to hand to prestodb.dbapi.connect.
    expected_connect_kwargs = {
        'catalog': 'hive',
        'host': 'host',
        'port': None,
        'http_headers': {'X-Presto-Client-Info': client_info},
        'http_scheme': 'http',
        'schema': 'hive',
        'source': 'airflow',
        'user': 'login',
        'isolation_level': 0,
        'auth': mock_basic_auth.return_value,
    }
    mock_connect.assert_called_once_with(**expected_connect_kwargs)
    mock_basic_auth.assert_called_once_with('login', 'password')
    assert mock_connect.return_value == conn

@parameterized.expand(
[
('False', False),
Expand Down

0 comments on commit 05b4409

Please sign in to comment.