Skip to content

Commit

Permalink
Make DbApiHook use get_uri from Connection (apache#21764)
Browse files Browse the repository at this point in the history
DBApi has its own get_uri method which does not deal
with quoting properly and neither with empty passwords.
Connection also has a get_uri method that deals properly
with the above issues.

This also fixes issues with RFC compliancy.
  • Loading branch information
bolkedebruin authored Feb 25, 2022
1 parent 7e03225 commit 59c450e
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 31 deletions.
11 changes: 2 additions & 9 deletions airflow/hooks/dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from contextlib import closing
from datetime import datetime
from typing import Any, Optional
from urllib.parse import quote_plus, urlunsplit

from sqlalchemy import create_engine

Expand Down Expand Up @@ -96,14 +95,8 @@ def get_uri(self) -> str:
:return: the extracted uri.
"""
conn = self.get_connection(getattr(self, self.conn_name_attr))
login = ''
if conn.login:
login = f'{quote_plus(conn.login)}:{quote_plus(conn.password)}@'
host = conn.host
if conn.port is not None:
host += f':{conn.port}'
schema = self.__schema or conn.schema or ''
return urlunsplit((conn.conn_type, f'{login}{host}', schema, '', ''))
conn.schema = self.__schema or conn.schema
return conn.get_uri()

def get_sqlalchemy_engine(self, engine_kwargs=None):
"""
Expand Down
6 changes: 6 additions & 0 deletions airflow/models/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ def _parse_from_uri(self, uri: str):

def get_uri(self) -> str:
"""Return connection in URI format"""
if '_' in self.conn_type:
self.log.warning(
f"Connection schemes (type: {str(self.conn_type)}) "
f"shall not contain '_' according to RFC3986."
)

uri = f"{str(self.conn_type).lower().replace('_', '-')}://"

authority_block = ''
Expand Down
8 changes: 0 additions & 8 deletions airflow/providers/mysql/hooks/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,6 @@ def get_conn(self) -> MySQLConnectionTypes:

raise ValueError('Unknown MySQL client name provided!')

def get_uri(self) -> str:
conn = self.get_connection(getattr(self, self.conn_name_attr))
uri = super().get_uri()
if conn.extra_dejson.get('charset', False):
charset = conn.extra_dejson["charset"]
return f"{uri}?charset={charset}"
return uri

def bulk_load(self, table: str, tmp_file: str) -> None:
"""Loads a tab-delimited file into a database table"""
conn = self.get_conn()
Expand Down
8 changes: 4 additions & 4 deletions airflow/providers/postgres/hooks/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,11 @@ def copy_expert(self, sql: str, filename: str) -> None:
conn.commit()

def get_uri(self) -> str:
conn = self.get_connection(getattr(self, self.conn_name_attr))
"""
Extract the URI from the connection.
:return: the extracted uri.
"""
uri = super().get_uri().replace("postgres://", "postgresql://")
if conn.extra_dejson.get('client_encoding', False):
charset = conn.extra_dejson["client_encoding"]
return f"{uri}?client_encoding={charset}"
return uri

def bulk_load(self, table: str, tmp_file: str) -> None:
Expand Down
61 changes: 51 additions & 10 deletions tests/hooks/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,49 +150,90 @@ def test_insert_rows_commit_every(self):
def test_get_uri_schema_not_none(self):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="conn_type",
conn_type="conn-type",
host="host",
login="login",
password="password",
schema="schema",
port=1,
)
)
assert "conn_type://login:password@host:1/schema" == self.db_hook.get_uri()
assert "conn-type://login:password@host:1/schema" == self.db_hook.get_uri()

def test_get_uri_schema_override(self):
self.db_hook_schema_override.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="conn_type",
conn_type="conn-type",
host="host",
login="login",
password="password",
schema="schema",
port=1,
)
)
assert "conn_type://login:password@host:1/schema-override" == self.db_hook_schema_override.get_uri()
assert "conn-type://login:password@host:1/schema-override" == self.db_hook_schema_override.get_uri()

def test_get_uri_schema_none(self):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="conn_type", host="host", login="login", password="password", schema=None, port=1
conn_type="conn-type", host="host", login="login", password="password", schema=None, port=1
)
)
assert "conn_type://login:password@host:1" == self.db_hook.get_uri()
assert "conn-type://login:password@host:1" == self.db_hook.get_uri()

def test_get_uri_special_characters(self):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="conn_type",
conn_type="conn-type",
host="host/",
login="lo/gi#! n",
password="pass*! word/",
schema="schema/",
port=1,
)
)
assert (
"conn-type://lo%2Fgi%23%21%20n:pass%2A%21%20word%2F@host%2F:1/schema%2F" == self.db_hook.get_uri()
)

def test_get_uri_login_none(self):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="conn-type",
host="host",
login=None,
password="password",
schema="schema",
port=1,
)
)
assert "conn-type://:password@host:1/schema" == self.db_hook.get_uri()

def test_get_uri_password_none(self):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="conn-type",
host="host",
login="login",
password=None,
schema="schema",
port=1,
)
)
assert "conn-type://login@host:1/schema" == self.db_hook.get_uri()

def test_get_uri_authority_none(self):
self.db_hook.get_connection = mock.MagicMock(
return_value=Connection(
conn_type="conn-type",
host="host",
login="logi#! n",
password="pass*! word",
login=None,
password=None,
schema="schema",
port=1,
)
)
assert "conn_type://logi%23%21+n:pass%2A%21+word@host:1/schema" == self.db_hook.get_uri()
assert "conn-type://host:1/schema" == self.db_hook.get_uri()

def test_run_log(self):
statement = 'SQL'
Expand Down
1 change: 1 addition & 0 deletions tests/providers/amazon/aws/hooks/test_base_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def test_get_credentials_from_gcp_credentials(self):
}
)
)
mock_connection.conn_type = 'aws'

# Store original __import__
orig_import = __import__
Expand Down
1 change: 1 addition & 0 deletions tests/providers/snowflake/hooks/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

BASE_CONNECTION_KWARGS: Dict = {
'login': 'user',
'conn_type': 'snowflake',
'password': 'pw',
'schema': 'public',
'extra': {
Expand Down

0 comments on commit 59c450e

Please sign in to comment.