Skip to content

Commit

Permalink
[AIRFLOW-2826] Add GoogleCloudKMSHook (apache#3677)
Browse files Browse the repository at this point in the history
Adds a hook enabling encryption and decryption through Google Cloud KMS.
This should also contribute to AIRFLOW-2062.
  • Loading branch information
jakahn authored and Fokko committed Aug 8, 2018
1 parent 6fd4e60 commit acca61c
Show file tree
Hide file tree
Showing 2 changed files with 268 additions and 0 deletions.
108 changes: 108 additions & 0 deletions airflow/contrib/hooks/gcp_kms_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import base64

from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook

from apiclient.discovery import build


def _b64encode(s):
""" Base 64 encodes a bytes object to a string """
return base64.b64encode(s).decode('ascii')


def _b64decode(s):
""" Base 64 decodes a string to bytes. """
return base64.b64decode(s.encode('utf-8'))


class GoogleCloudKMSHook(GoogleCloudBaseHook):
    """
    Hook for encrypting and decrypting data with Google Cloud KMS.

    Authentication is handled through the Google Cloud Platform connection
    identified by ``gcp_conn_id``.
    """

    def __init__(self, gcp_conn_id='google_cloud_default', delegate_to=None):
        super(GoogleCloudKMSHook, self).__init__(
            gcp_conn_id, delegate_to=delegate_to)

    def get_conn(self):
        """
        Build a client for version ``v1`` of the ``cloudkms`` API.

        The client uses the authorized HTTP object obtained from
        ``self._authorize()`` and is built with ``cache_discovery=False``.

        :return: A KMS service object.
        :rtype: apiclient.discovery.Resource
        """
        authorized_http = self._authorize()
        service = build(
            'cloudkms', 'v1', http=authorized_http, cache_discovery=False)
        return service

    def encrypt(self, key_name, plaintext, authenticated_data=None):
        """
        Encrypt a plaintext message using Google Cloud KMS.

        :param key_name: The Resource Name for the key (or key version)
            to be used for encryption. Of the form
            ``projects/*/locations/*/keyRings/*/cryptoKeys/**``
        :type key_name: str
        :param plaintext: The message to be encrypted.
        :type plaintext: bytes
        :param authenticated_data: Optional additional authenticated data
            that must also be provided to decrypt the message.
        :type authenticated_data: bytes
        :return: The base 64 encoded ciphertext of the original message.
        :rtype: str
        """
        crypto_keys = (
            self.get_conn().projects().locations().keyRings().cryptoKeys())

        # Binary values are sent to the API as base64 text.
        request_body = {'plaintext': _b64encode(plaintext)}
        if authenticated_data:
            request_body['additionalAuthenticatedData'] = _b64encode(
                authenticated_data)

        response = crypto_keys.encrypt(
            name=key_name, body=request_body).execute()
        return response['ciphertext']

    def decrypt(self, key_name, ciphertext, authenticated_data=None):
        """
        Decrypt a ciphertext message using Google Cloud KMS.

        :param key_name: The Resource Name for the key to be used for
            decryption. Of the form
            ``projects/*/locations/*/keyRings/*/cryptoKeys/**``
        :type key_name: str
        :param ciphertext: The message to be decrypted.
        :type ciphertext: str
        :param authenticated_data: Any additional authenticated data that
            was provided when encrypting the message.
        :type authenticated_data: bytes
        :return: The original message.
        :rtype: bytes
        """
        crypto_keys = (
            self.get_conn().projects().locations().keyRings().cryptoKeys())

        # The ciphertext is passed through unchanged; only the additional
        # authenticated data is base64 encoded here.
        request_body = {'ciphertext': ciphertext}
        if authenticated_data:
            request_body['additionalAuthenticatedData'] = _b64encode(
                authenticated_data)

        response = crypto_keys.decrypt(
            name=key_name, body=request_body).execute()
        # The response carries the plaintext as base64 text; hand back bytes.
        return _b64decode(response['plaintext'])
160 changes: 160 additions & 0 deletions tests/contrib/hooks/test_gcp_kms_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

from __future__ import unicode_literals

import unittest
from base64 import b64encode

from airflow.contrib.hooks.gcp_kms_hook import GoogleCloudKMSHook

try:
from unittest import mock
except ImportError:
try:
import mock
except ImportError:
mock = None

# Dotted-path templates used to build ``mock.patch`` targets in the tests.
BASE_STRING = 'airflow.contrib.hooks.gcp_api_base_hook.{}'
KMS_STRING = 'airflow.contrib.hooks.gcp_kms_hook.{}'

# Components of the key resource name that the tests pass to the hook.
TEST_PROJECT = 'test-project'
TEST_LOCATION = 'global'
TEST_KEY_RING = 'test-key-ring'
TEST_KEY = 'test-key'

# Full resource name assembled from the components above.
TEST_KEY_ID = 'projects/{}/locations/{}/keyRings/{}/cryptoKeys/{}'.format(
    TEST_PROJECT,
    TEST_LOCATION,
    TEST_KEY_RING,
    TEST_KEY,
)


def mock_init(self, gcp_conn_id, delegate_to=None):
    """
    Do-nothing constructor patched in for ``GoogleCloudBaseHook.__init__``.

    It accepts the arguments the hook forwards to its base class and
    ignores them.
    """
    return None


class GoogleCloudKMSHookTest(unittest.TestCase):
    """Tests ``GoogleCloudKMSHook`` against a mocked KMS service object."""

    def setUp(self):
        # Replace the base hook constructor with a no-op while the hook
        # under test is instantiated.
        init_target = BASE_STRING.format('GoogleCloudBaseHook.__init__')
        with mock.patch(init_target, new=mock_init):
            self.kms_hook = GoogleCloudKMSHook(gcp_conn_id='test')

    @staticmethod
    def _crypto_keys(get_conn_mock):
        """
        Return the mock standing in for the ``cryptoKeys()`` resource, i.e.
        ``get_conn().projects().locations().keyRings().cryptoKeys()``.
        """
        return (get_conn_mock.return_value
                .projects.return_value
                .locations.return_value
                .keyRings.return_value
                .cryptoKeys.return_value)

    @mock.patch(KMS_STRING.format('GoogleCloudKMSHook.get_conn'))
    def test_encrypt(self, mock_service):
        plaintext = b'Test plaintext'
        ciphertext = 'Test ciphertext'
        expected_body = {'plaintext': b64encode(plaintext).decode('ascii')}

        encrypt_method = self._crypto_keys(mock_service).encrypt
        execute_method = encrypt_method.return_value.execute
        execute_method.return_value = {'ciphertext': ciphertext}

        result = self.kms_hook.encrypt(TEST_KEY_ID, plaintext)

        encrypt_method.assert_called_with(name=TEST_KEY_ID,
                                          body=expected_body)
        execute_method.assert_called_with()
        self.assertEqual(ciphertext, result)

    @mock.patch(KMS_STRING.format('GoogleCloudKMSHook.get_conn'))
    def test_encrypt_authdata(self, mock_service):
        plaintext = b'Test plaintext'
        auth_data = b'Test authdata'
        ciphertext = 'Test ciphertext'
        expected_body = {
            'plaintext': b64encode(plaintext).decode('ascii'),
            'additionalAuthenticatedData':
                b64encode(auth_data).decode('ascii'),
        }

        encrypt_method = self._crypto_keys(mock_service).encrypt
        execute_method = encrypt_method.return_value.execute
        execute_method.return_value = {'ciphertext': ciphertext}

        result = self.kms_hook.encrypt(TEST_KEY_ID, plaintext,
                                       authenticated_data=auth_data)

        encrypt_method.assert_called_with(name=TEST_KEY_ID,
                                          body=expected_body)
        execute_method.assert_called_with()
        self.assertEqual(ciphertext, result)

    @mock.patch(KMS_STRING.format('GoogleCloudKMSHook.get_conn'))
    def test_decrypt(self, mock_service):
        plaintext = b'Test plaintext'
        ciphertext = 'Test ciphertext'
        expected_body = {'ciphertext': ciphertext}

        decrypt_method = self._crypto_keys(mock_service).decrypt
        execute_method = decrypt_method.return_value.execute
        # The service returns the plaintext base64 encoded.
        execute_method.return_value = {
            'plaintext': b64encode(plaintext).decode('ascii'),
        }

        result = self.kms_hook.decrypt(TEST_KEY_ID, ciphertext)

        decrypt_method.assert_called_with(name=TEST_KEY_ID,
                                          body=expected_body)
        execute_method.assert_called_with()
        self.assertEqual(plaintext, result)

    @mock.patch(KMS_STRING.format('GoogleCloudKMSHook.get_conn'))
    def test_decrypt_authdata(self, mock_service):
        plaintext = b'Test plaintext'
        auth_data = b'Test authdata'
        ciphertext = 'Test ciphertext'
        expected_body = {
            'ciphertext': ciphertext,
            'additionalAuthenticatedData':
                b64encode(auth_data).decode('ascii'),
        }

        decrypt_method = self._crypto_keys(mock_service).decrypt
        execute_method = decrypt_method.return_value.execute
        # The service returns the plaintext base64 encoded.
        execute_method.return_value = {
            'plaintext': b64encode(plaintext).decode('ascii'),
        }

        result = self.kms_hook.decrypt(TEST_KEY_ID, ciphertext,
                                       authenticated_data=auth_data)

        decrypt_method.assert_called_with(name=TEST_KEY_ID,
                                          body=expected_body)
        execute_method.assert_called_with()
        self.assertEqual(plaintext, result)

0 comments on commit acca61c

Please sign in to comment.