Skip to content

Commit

Permalink
Add support for private key in connection for Snowflake (apache#22266)
Browse files Browse the repository at this point in the history
  • Loading branch information
mik-laj authored Mar 15, 2022
1 parent a66c072 commit d6ed9cb
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 24 deletions.
66 changes: 46 additions & 20 deletions airflow/providers/snowflake/hooks/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
from contextlib import closing
from io import StringIO
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

from cryptography.hazmat.backends import default_backend
Expand All @@ -28,6 +29,7 @@
from snowflake.sqlalchemy import URL
from sqlalchemy import create_engine

from airflow import AirflowException
from airflow.hooks.dbapi import DbApiHook
from airflow.utils.strings import to_boolean

Expand All @@ -44,8 +46,7 @@ class SnowflakeHook(DbApiHook):
This hook requires the snowflake_conn_id connection. The snowflake host, login,
and, password field must be setup in the connection. Other inputs can be defined
in the connection or hook instantiation. If used with the S3ToSnowflakeOperator
add 'aws_access_key_id' and 'aws_secret_access_key' to extra field in the connection.
in the connection or hook instantiation.
:param snowflake_conn_id: Reference to
:ref:`Snowflake connection id<howto/connection:snowflake>`
Expand All @@ -55,7 +56,7 @@ class SnowflakeHook(DbApiHook):
'externalbrowser' to authenticate using your web browser and
Okta, ADFS or any other SAML 2.0-compliant identify provider
(IdP) that has been defined for your account
'https://<your_okta_account_name>.okta.com' to authenticate
``https://<your_okta_account_name>.okta.com`` to authenticate
through native Okta.
:param warehouse: name of snowflake warehouse
:param database: name of snowflake database
Expand All @@ -69,7 +70,7 @@ class SnowflakeHook(DbApiHook):
<https://community.snowflake.com/s/article/How-to-turn-off-OCSP-checking-in-Snowflake-client-drivers>`__
.. note::
get_sqlalchemy_engine() depends on snowflake-sqlalchemy
``get_sqlalchemy_engine()`` depends on ``snowflake-sqlalchemy``
.. seealso::
For more information on how to use this Snowflake connection, take a look at the guide:
Expand All @@ -85,7 +86,7 @@ class SnowflakeHook(DbApiHook):
@staticmethod
def get_connection_form_widgets() -> Dict[str, Any]:
"""Returns connection widgets to add to connection form"""
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from flask_appbuilder.fieldwidgets import BS3PasswordFieldWidget, BS3TextFieldWidget
from flask_babel import lazy_gettext
from wtforms import BooleanField, StringField

Expand All @@ -97,6 +98,12 @@ def get_connection_form_widgets() -> Dict[str, Any]:
"extra__snowflake__database": StringField(lazy_gettext('Database'), widget=BS3TextFieldWidget()),
"extra__snowflake__region": StringField(lazy_gettext('Region'), widget=BS3TextFieldWidget()),
"extra__snowflake__role": StringField(lazy_gettext('Role'), widget=BS3TextFieldWidget()),
"extra__snowflake__private_key_file": StringField(
lazy_gettext('Private key (Path)'), widget=BS3TextFieldWidget()
),
"extra__snowflake__private_key_content": StringField(
lazy_gettext('Private key (Text)'), widget=BS3PasswordFieldWidget()
),
"extra__snowflake__insecure_mode": BooleanField(
label=lazy_gettext('Insecure mode'), description="Turns off OCSP certificate checks"
),
Expand Down Expand Up @@ -127,6 +134,8 @@ def get_ui_field_behaviour() -> Dict[str, Any]:
'extra__snowflake__database': 'snowflake db name',
'extra__snowflake__region': 'snowflake hosted region',
'extra__snowflake__role': 'snowflake role',
'extra__snowflake__private_key_file': 'Path of snowflake private key (PEM Format)',
'extra__snowflake__private_key_content': 'Content to snowflake private key (PEM format)',
'extra__snowflake__insecure_mode': 'insecure mode',
},
}
Expand Down Expand Up @@ -186,21 +195,38 @@ def _get_conn_params(self) -> Dict[str, Optional[str]]:
if insecure_mode:
conn_config['insecure_mode'] = insecure_mode

# If private_key_file is specified in the extra json, load the contents of the file as a private
# key and specify that in the connection configuration. The connection password then becomes the
# passphrase for the private key. If your private key file is not encrypted (not recommended), then
# leave the password empty.

private_key_file = conn.extra_dejson.get('private_key_file')
if private_key_file:
with open(private_key_file, "rb") as key:
passphrase = None
if conn.password:
passphrase = conn.password.strip().encode()

p_key = serialization.load_pem_private_key(
key.read(), password=passphrase, backend=default_backend()
)
# If private_key_file is specified in the extra json, load the contents of the file as a private key.
# If private_key_content is specified in the extra json, use it as a private key.
# As a next step, specify this private key in the connection configuration.
# The connection password then becomes the passphrase for the private key.
# If your private key is not encrypted (not recommended), then leave the password empty.

private_key_file = conn.extra_dejson.get(
'extra__snowflake__private_key_file'
) or conn.extra_dejson.get('private_key_file')
private_key_content = conn.extra_dejson.get(
'extra__snowflake__private_key_content'
) or conn.extra_dejson.get('private_key_content')

private_key_pem = None
if private_key_content and private_key_file:
raise AirflowException(
"The private_key_file and private_key_content extra fields are mutually exclusive. "
"Please remove one."
)
elif private_key_file:
private_key_pem = Path(private_key_file).read_bytes()
elif private_key_content:
private_key_pem = private_key_content.encode()

if private_key_pem:
passphrase = None
if conn.password:
passphrase = conn.password.strip().encode()

p_key = serialization.load_pem_private_key(
private_key_pem, password=passphrase, backend=default_backend()
)

pkb = p_key.private_bytes(
encoding=serialization.Encoding.DER,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ Login
Specify the snowflake username.

Password
Specify the snowflake password.
Specify the snowflake password. For public key authentication, the passphrase for the private key.

Host (optional)
Specify the snowflake hostname.
Expand All @@ -61,6 +61,7 @@ Extra (optional)
* ``role``: Snowflake role.
* ``authenticator``: To connect using OAuth set this parameter ``oath``.
* ``private_key_file``: Specify the path to the private key file.
* ``private_key_content``: Specify the content of the private key file.
* ``session_parameters``: Specify `session level parameters <https://docs.snowflake.com/en/user-guide/python-connector-example.html#setting-session-parameters>`_.
* ``insecure_mode``: Turn off OCSP certificate checks. For details, see: `How To: Turn Off OCSP Checking in Snowflake Client Drivers - Snowflake Community <https://community.snowflake.com/s/article/How-to-turn-off-OCSP-checking-in-Snowflake-client-drivers>`_.

Expand Down
26 changes: 23 additions & 3 deletions tests/providers/snowflake/hooks/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import unittest
from copy import deepcopy
from pathlib import Path
from typing import Dict
from typing import Any, Dict
from unittest import mock

import pytest
Expand Down Expand Up @@ -48,7 +48,7 @@


@pytest.fixture()
def non_encrypted_temporary_private_key(tmp_path: Path):
def non_encrypted_temporary_private_key(tmp_path: Path) -> Path:
key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=2048)
private_key = key.private_bytes(
serialization.Encoding.PEM, serialization.PrivateFormat.PKCS8, serialization.NoEncryption()
Expand All @@ -59,7 +59,7 @@ def non_encrypted_temporary_private_key(tmp_path: Path):


@pytest.fixture()
def encrypted_temporary_private_key(tmp_path: Path):
def encrypted_temporary_private_key(tmp_path: Path) -> Path:
key = rsa.generate_private_key(backend=default_backend(), public_exponent=65537, key_size=2048)
private_key = key.private_bytes(
serialization.Encoding.PEM,
Expand Down Expand Up @@ -271,6 +271,26 @@ def test_hook_should_support_prepare_basic_conn_params_and_uri(
assert SnowflakeHook(snowflake_conn_id='test_conn').get_uri() == expected_uri
assert SnowflakeHook(snowflake_conn_id='test_conn')._get_conn_params() == expected_conn_params

def test_get_conn_params_should_support_private_auth_in_connection(
self, encrypted_temporary_private_key: Path
):
connection_kwargs: Any = {
**BASE_CONNECTION_KWARGS,
'password': _PASSWORD,
'extra': {
'database': 'db',
'account': 'airflow',
'warehouse': 'af_wh',
'region': 'af_region',
'role': 'af_role',
'private_key_content': str(encrypted_temporary_private_key.read_text()),
},
}
with unittest.mock.patch.dict(
'os.environ', AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()
):
assert 'private_key' in SnowflakeHook(snowflake_conn_id='test_conn')._get_conn_params()

def test_get_conn_params_should_support_private_auth_with_encrypted_key(
self, encrypted_temporary_private_key
):
Expand Down

0 comments on commit d6ed9cb

Please sign in to comment.