Makes changes to SqlToS3Operator method _fix_int_dtypes (apache#25083)
Convert DataFrame object columns to str to avoid errors when converting from a DataFrame to Parquet (see the sketch below).

Renamed the methods, since they no longer handle only int dtypes:
_fix_int_dtypes -> _fix_dtypes
test_fix_int_dtypes -> test_fix_dtypes

Co-authored-by: Paul Stanton <[email protected]>
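
For context, a minimal sketch of the failure this commit avoids (assuming pyarrow as the Parquet engine; the file path and column values are illustrative):

import pandas as pd

# A column holding mixed Python objects gets dtype "object".
df = pd.DataFrame({"mixed": [1, "a", 3.5]})

try:
    # Arrow cannot infer a single type for the column, so this
    # typically raises an ArrowInvalid/ArrowTypeError.
    df.to_parquet("/tmp/example.parquet")
except Exception as exc:
    print(f"Parquet conversion failed: {exc}")

# Casting object columns to str first, as this commit does, avoids the error.
df["mixed"] = df["mixed"].astype(str)
df.to_parquet("/tmp/example.parquet")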
pastanton authored Jul 18, 2022
1 parent fd6f537 commit 693fe60
Showing 3 changed files with 16 additions and 7 deletions.
15 changes: 12 additions & 3 deletions airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -127,9 +127,18 @@ def __init__(
             raise AirflowException(f"The argument file_format doesn't support {file_format} value.")
 
     @staticmethod
-    def _fix_int_dtypes(df: pd.DataFrame) -> None:
-        """Mutate DataFrame to set dtypes for int columns containing NaN values."""
+    def _fix_dtypes(df: pd.DataFrame) -> None:
+        """
+        Mutate DataFrame to set dtypes for float columns containing NaN values.
+        Set dtype of object to str to allow for downstream transformations.
+        """
         for col in df:
+
+            if df[col].dtype.name == 'object':
+                # if the type wasn't identified or converted, change it to a string so it can still be
+                # processed.
+                df[col] = df[col].astype(str)
+
             if "float" in df[col].dtype.name and df[col].hasnans:
                 # inspect values to determine if dtype of non-null values is int or float
                 notna_series = df[col].dropna().values
@@ -148,7 +157,7 @@ def execute(self, context: 'Context') -> None:
         data_df = sql_hook.get_pandas_df(sql=self.query, parameters=self.parameters)
         self.log.info("Data from SQL obtained")
 
-        self._fix_int_dtypes(data_df)
+        self._fix_dtypes(data_df)
         file_options = FILE_OPTIONS_MAP[self.file_format]
 
         with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:
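
The intended behavior of _fix_dtypes can be sketched as follows. Only the object-to-str branch is visible in the hunk above; the Int64 downcast for float columns is an assumption about the code hidden behind the truncated diff:

import numpy as np
import pandas as pd

df = pd.DataFrame(
    {
        "objects": [1, "a", 3.5],    # mixed values -> dtype "object"
        "ints": [1.0, 2.0, np.nan],  # a NaN forces float64 even for integer data
    }
)

# Visible branch: object columns become plain strings.
df["objects"] = df["objects"].astype(str)

# Assumed continuation of the truncated hunk: if every non-null value in a
# float column is a whole number, downcast to pandas' nullable Int64 dtype.
notna_series = df["ints"].dropna().values
if (notna_series == notna_series.astype(int)).all():
    df["ints"] = df["ints"].astype("Int64")

print(df.dtypes)  # objects -> object (str values), ints -> Int64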
4 changes: 2 additions & 2 deletions tests/providers/amazon/aws/transfers/test_mysql_to_s3.py
@@ -109,10 +109,10 @@ def test_execute_parquet(self, mock_s3_hook, temp_mock):
             filename=f.name, key=s3_key, bucket_name=s3_bucket, replace=False
         )
 
-    def test_fix_int_dtypes(self):
+    def test_fix_dtypes(self):
         from airflow.providers.amazon.aws.transfers.mysql_to_s3 import MySQLToS3Operator
 
         op = MySQLToS3Operator(query="query", s3_bucket="s3_bucket", s3_key="s3_key", task_id="task_id")
         dirty_df = pd.DataFrame({"strings": ["a", "b", "c"], "ints": [1, 2, None]})
-        op._fix_int_dtypes(df=dirty_df)
+        op._fix_dtypes(df=dirty_df)
         assert dirty_df["ints"].dtype.kind == "i"
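
A side note on why the test's "ints" column needs fixing at all: pandas promotes integer data to float64 the moment a null appears, e.g.:

import pandas as pd

s = pd.Series([1, 2, None])
print(s.dtype)       # float64 -- the None forces promotion to float
print(s.dtype.kind)  # "f" before _fix_dtypes, "i" after the downcast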
4 changes: 2 additions & 2 deletions tests/providers/amazon/aws/transfers/test_sql_to_s3.py
@@ -144,7 +144,7 @@ def test_execute_json(self, mock_s3_hook, temp_mock):
             replace=True,
         )
 
-    def test_fix_int_dtypes(self):
+    def test_fix_dtypes(self):
         op = SqlToS3Operator(
             query="query",
             s3_bucket="s3_bucket",
@@ -153,7 +162,7 @@ def test_fix_int_dtypes(self):
             sql_conn_id="mysql_conn_id",
         )
         dirty_df = pd.DataFrame({"strings": ["a", "b", "c"], "ints": [1, 2, None]})
-        op._fix_int_dtypes(df=dirty_df)
+        op._fix_dtypes(df=dirty_df)
         assert dirty_df["ints"].dtype.kind == "i"
 
     def test_invalid_file_format(self):