diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index 185015d5ea0e0..02ba5c32735aa 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -453,7 +453,7 @@ Here is the list of packages and their extras:
 ========================== ===========================
 Package                    Extras
 ========================== ===========================
-amazon                     apache.hive,google,imap,mongo,postgres,ssh
+amazon                     apache.hive,google,imap,mongo,mysql,postgres,ssh
 apache.druid               apache.hive
 apache.hive                amazon,microsoft.mssql,mysql,presto,samba,vertica
 apache.livy                http
diff --git a/airflow/providers/amazon/aws/operators/mysql_to_s3.py b/airflow/providers/amazon/aws/operators/mysql_to_s3.py
new file mode 100644
index 0000000000000..748e0f99a844e
--- /dev/null
+++ b/airflow/providers/amazon/aws/operators/mysql_to_s3.py
@@ -0,0 +1,125 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import tempfile
+from typing import Optional, Union
+
+import numpy as np
+import pandas as pd
+
+from airflow.exceptions import AirflowException
+from airflow.models import BaseOperator
+from airflow.providers.amazon.aws.hooks.s3 import S3Hook
+from airflow.providers.mysql.hooks.mysql import MySqlHook
+from airflow.utils.decorators import apply_defaults
+
+
+class MySQLToS3Operator(BaseOperator):
+    """
+    Saves data from a specific MySQL query into a file in S3.
+
+    :param query: the SQL query to be executed. To execute a file instead, pass its
+        absolute path; the path must end with the .sql extension.
+    :type query: str
+    :param s3_bucket: bucket where the data will be stored
+    :type s3_bucket: str
+    :param s3_key: desired key for the file; it includes the name of the file
+    :type s3_key: str
+    :param mysql_conn_id: reference to a specific mysql database
+    :type mysql_conn_id: str
+    :param aws_conn_id: reference to a specific S3 connection
+    :type aws_conn_id: str
+    :param verify: Whether or not to verify SSL certificates for S3 connection.
+        By default SSL certificates are verified.
+        You can provide the following values:
+
+        - ``False``: do not validate SSL certificates. SSL will still be used
+          (unless use_ssl is False), but SSL certificates will not be verified.
+        - ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to use.
+          You can specify this argument if you want to use a different
+          CA cert bundle than the one used by botocore.
+    :type verify: bool or str
+    :param pd_csv_kwargs: arguments to include in pd.to_csv (header, index, columns...)
+    :type pd_csv_kwargs: dict
+    :param index: whether to include the DataFrame index in the CSV file
+    :type index: bool
+    :param header: whether to include a header row in the S3 file
+    :type header: bool
+    """
+
+    template_fields = ('s3_key', 'query',)
+    template_ext = ('.sql',)
+
+    @apply_defaults
+    def __init__(
+            self,
+            query: str,
+            s3_bucket: str,
+            s3_key: str,
+            mysql_conn_id: str = 'mysql_default',
+            aws_conn_id: str = 'aws_default',
+            verify: Optional[Union[bool, str]] = None,
+            pd_csv_kwargs: Optional[dict] = None,
+            index: Optional[bool] = False,
+            header: Optional[bool] = False,
+            *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.query = query
+        self.s3_bucket = s3_bucket
+        self.s3_key = s3_key
+        self.mysql_conn_id = mysql_conn_id
+        self.aws_conn_id = aws_conn_id
+        self.verify = verify
+
+        self.pd_csv_kwargs = pd_csv_kwargs or {}
+        if "path_or_buf" in self.pd_csv_kwargs:
+            raise AirflowException('The argument path_or_buf is not allowed, please remove it')
+        if "index" not in self.pd_csv_kwargs:
+            self.pd_csv_kwargs["index"] = index
+        if "header" not in self.pd_csv_kwargs:
+            self.pd_csv_kwargs["header"] = header
+
+    def _fix_int_dtypes(self, df):
+        """
+        Mutate DataFrame to set dtypes for int columns containing NaN values.
+        """
+        for col in df:
+            if "float" in df[col].dtype.name and df[col].hasnans:
+                # inspect values to determine if dtype of non-null values is int or float
+                notna_series = df[col].dropna().values
+                if np.isclose(notna_series, notna_series.astype(int)).all():
+                    # set to a dtype that retains integers and supports NaNs:
+                    # replace NaN with None, then cast to pandas' nullable Int64
+                    df[col] = np.where(df[col].isnull(), None, df[col])
+                    df[col] = df[col].astype(pd.Int64Dtype())
+
+    def execute(self, context):
+        mysql_hook = MySqlHook(mysql_conn_id=self.mysql_conn_id)
+        s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
+        data_df = mysql_hook.get_pandas_df(self.query)
+        self.log.info("Data from MySQL obtained")
+
+        self._fix_int_dtypes(data_df)
+        with tempfile.NamedTemporaryFile(mode='r+', suffix='.csv') as tmp_csv:
+            data_df.to_csv(tmp_csv.name, **self.pd_csv_kwargs)
+            s3_conn.load_file(filename=tmp_csv.name,
+                              key=self.s3_key,
+                              bucket_name=self.s3_bucket)
+
+        if s3_conn.check_for_key(self.s3_key, bucket_name=self.s3_bucket):
+            file_location = os.path.join(self.s3_bucket, self.s3_key)
+            self.log.info("File saved correctly in %s", file_location)
diff --git a/airflow/providers/dependencies.json b/airflow/providers/dependencies.json
index 2c0b2ead64659..90ce43481c960 100644
--- a/airflow/providers/dependencies.json
+++ b/airflow/providers/dependencies.json
@@ -4,6 +4,7 @@
     "google",
     "imap",
     "mongo",
+    "mysql",
     "postgres",
     "ssh"
   ],
diff --git a/docs/operators-and-hooks-ref.rst b/docs/operators-and-hooks-ref.rst
index cc76d7ff16c5c..909318ccafc5a 100644
--- a/docs/operators-and-hooks-ref.rst
+++ b/docs/operators-and-hooks-ref.rst
@@ -582,6 +582,11 @@ These integrations allow you to copy data from/to Amazon Web Services.
      -
      - :mod:`airflow.providers.amazon.aws.transfers.sftp_to_s3`
 
+   * - `MySQL <https://www.mysql.com/>`__
+     - `Amazon Simple Storage Service (S3) <https://aws.amazon.com/s3/>`_
+     -
+     - :mod:`airflow.providers.amazon.aws.operators.mysql_to_s3`
+
 :ref:`[1] `
     Those discovery-based operators use
    :class:`~airflow.providers.google.common.hooks.discovery_api.GoogleDiscoveryApiHook` to communicate with Google Services via the `Google API Python Client <https://github.com/googleapis/google-api-python-client>`__.
diff --git a/tests/providers/amazon/aws/operators/test_mysql_to_s3.py b/tests/providers/amazon/aws/operators/test_mysql_to_s3.py
new file mode 100644
index 0000000000000..8f8b9dc389482
--- /dev/null
+++ b/tests/providers/amazon/aws/operators/test_mysql_to_s3.py
@@ -0,0 +1,61 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import unittest
+from unittest import mock
+
+import pandas as pd
+
+from airflow.providers.amazon.aws.operators.mysql_to_s3 import MySQLToS3Operator
+
+
+class TestMySqlToS3Operator(unittest.TestCase):
+
+    @mock.patch("airflow.providers.amazon.aws.operators.mysql_to_s3.tempfile.NamedTemporaryFile")
+    @mock.patch("airflow.providers.amazon.aws.operators.mysql_to_s3.S3Hook")
+    @mock.patch("airflow.providers.amazon.aws.operators.mysql_to_s3.MySqlHook")
+    def test_execute(self, mock_mysql_hook, mock_s3_hook, temp_mock):
+        query = "query"
+        s3_bucket = "bucket"
+        s3_key = "key"
+
+        test_df = pd.DataFrame({'a': '1', 'b': '2'}, index=[0, 1])
+        get_pandas_df_mock = mock_mysql_hook.return_value.get_pandas_df
+        get_pandas_df_mock.return_value = test_df
+
+        # give the mocked temporary file a deterministic name before execute()
+        # so the load_file assertion below can verify the filename it received
+        filename = "file"
+        temp_mock.return_value.__enter__.return_value.name = filename
+
+        op = MySQLToS3Operator(query=query,
+                               s3_bucket=s3_bucket,
+                               s3_key=s3_key,
+                               mysql_conn_id="mysql_conn_id",
+                               aws_conn_id="aws_conn_id",
+                               task_id="task_id",
+                               pd_csv_kwargs={'index': False, 'header': False},
+                               dag=None
+                               )
+        op.execute(None)
+        mock_mysql_hook.assert_called_once_with(mysql_conn_id="mysql_conn_id")
+        mock_s3_hook.assert_called_once_with(aws_conn_id="aws_conn_id", verify=None)
+
+        get_pandas_df_mock.assert_called_once_with(query)
+
+        temp_mock.assert_called_once_with(mode='r+', suffix=".csv")
+        mock_s3_hook.return_value.load_file.assert_called_once_with(filename=filename,
+                                                                    key=s3_key,
+                                                                    bucket_name=s3_bucket)
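
Why _fix_int_dtypes exists: an integer column returned by get_pandas_df is upcast to float64 as soon as it contains a NULL, so a value such as 1 would land in the CSV as 1.0. The standalone sketch below mirrors that fix-up; it is not part of the patch, the column name and values are made up, and it assumes a pandas version with nullable Int64 support.

# Illustrative sketch of the dtype round trip handled by _fix_int_dtypes.
import numpy as np
import pandas as pd

df = pd.DataFrame({"user_id": [1.0, 2.0, np.nan]})  # INT column with NULLs comes back as float64
print(df["user_id"].dtype)                           # float64 -> would serialize as 1.0, 2.0

notna = df["user_id"].dropna().values
if np.isclose(notna, notna.astype(int)).all():       # non-null values are really integers
    # replace NaN with None, then cast to pandas' nullable Int64 to keep integer formatting
    df["user_id"] = np.where(df["user_id"].isnull(), None, df["user_id"])
    df["user_id"] = df["user_id"].astype(pd.Int64Dtype())

print(df["user_id"].dtype)        # Int64
print(df.to_csv(index=False))     # user_id / 1 / 2 / <empty> instead of 1.0 / 2.0 / <empty>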
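
For reviewers, a minimal DAG using the new operator might look like the sketch below. The dag_id, bucket, key, and query are illustrative placeholders rather than values taken from the patch, and the connection ids assume the stock mysql_default/aws_default connections exist.

# Illustrative usage sketch only; identifiers below are hypothetical.
from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.mysql_to_s3 import MySQLToS3Operator

with DAG(
    dag_id="example_mysql_to_s3",
    start_date=datetime(2020, 1, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    export_orders = MySQLToS3Operator(
        task_id="export_orders",
        query="SELECT * FROM orders",   # any SQL string, or a templated path ending in .sql
        s3_bucket="my-example-bucket",
        s3_key="exports/orders.csv",
        mysql_conn_id="mysql_default",
        aws_conn_id="aws_default",
        pd_csv_kwargs={"index": False, "header": True},
    )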