Skip to content

Commit

Permalink
Add truncate table (before copy) option to S3ToRedshiftOperator (apac…
Browse files Browse the repository at this point in the history
…he#9246)

- add table arg to jinja template fields
- change ui_color

Co-authored-by: javier.lopez <[email protected]>
  • Loading branch information
JavierLopezT and JavierLTPromofarma authored Oct 28, 2020
1 parent 555c574 commit db121f7
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 21 deletions.
42 changes: 26 additions & 16 deletions airflow/providers/amazon/aws/transfers/s3_to_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,16 @@ class S3ToRedshiftOperator(BaseOperator):
:type verify: bool or str
:param copy_options: reference to a list of COPY options
:type copy_options: list
:param truncate_table: whether or not to truncate the destination table before the copy
:type truncate_table: bool
"""

template_fields = ('s3_key',)
template_fields = (
's3_key',
'table',
)
template_ext = ()
ui_color = '#ededed'
ui_color = '#99e699'

@apply_defaults
def __init__(
Expand All @@ -75,6 +80,7 @@ def __init__(
verify: Optional[Union[bool, str]] = None,
copy_options: Optional[List] = None,
autocommit: bool = False,
truncate_table: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -87,29 +93,33 @@ def __init__(
self.verify = verify
self.copy_options = copy_options or []
self.autocommit = autocommit
self.truncate_table = truncate_table

def execute(self, context) -> None:
postgres_hook = PostgresHook(postgres_conn_id=self.redshift_conn_id)
s3_hook = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
credentials = s3_hook.get_credentials()
copy_options = '\n\t\t\t'.join(self.copy_options)

copy_query = """
COPY {schema}.{table}
FROM 's3://{s3_bucket}/{s3_key}'
copy_statement = f"""
COPY {self.schema}.{self.table}
FROM 's3://{self.s3_bucket}/{self.s3_key}'
with credentials
'aws_access_key_id={access_key};aws_secret_access_key={secret_key}'
'aws_access_key_id={credentials.access_key};aws_secret_access_key={credentials.secret_key}'
{copy_options};
""".format(
schema=self.schema,
table=self.table,
s3_bucket=self.s3_bucket,
s3_key=self.s3_key,
access_key=credentials.access_key,
secret_key=credentials.secret_key,
copy_options=copy_options,
)
"""

if self.truncate_table:
truncate_statement = f'TRUNCATE TABLE {self.schema}.{self.table};'
sql = f"""
BEGIN;
{truncate_statement}
{copy_statement}
COMMIT
"""
else:
sql = copy_statement

self.log.info('Executing COPY command...')
postgres_hook.run(copy_query, self.autocommit)
postgres_hook.run(sql, self.autocommit)
self.log.info("COPY command complete...")
48 changes: 43 additions & 5 deletions tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,21 +53,59 @@ def test_execute(self, mock_run, mock_session):
)
op.execute(None)

copy_query = """
copy_query = f"""
COPY {schema}.{table}
FROM 's3://{s3_bucket}/{s3_key}'
with credentials
'aws_access_key_id={access_key};aws_secret_access_key={secret_key}'
{copy_options};
""".format(
"""

assert mock_run.call_count == 1
assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], copy_query)

@mock.patch("boto3.session.Session")
@mock.patch("airflow.providers.postgres.hooks.postgres.PostgresHook.run")
def test_truncate(self, mock_run, mock_session):
access_key = "aws_access_key_id"
secret_key = "aws_secret_access_key"
mock_session.return_value = Session(access_key, secret_key)

schema = "schema"
table = "table"
s3_bucket = "bucket"
s3_key = "key"
copy_options = ""

op = S3ToRedshiftOperator(
schema=schema,
table=table,
s3_bucket=s3_bucket,
s3_key=s3_key,
access_key=access_key,
secret_key=secret_key,
copy_options=copy_options,
truncate_table=True,
redshift_conn_id="redshift_conn_id",
aws_conn_id="aws_conn_id",
task_id="task_id",
dag=None,
)
op.execute(None)

copy_statement = f"""
COPY {schema}.{table}
FROM 's3://{s3_bucket}/{s3_key}'
with credentials
'aws_access_key_id={access_key};aws_secret_access_key={secret_key}'
{copy_options};
"""

truncate_statement = f'TRUNCATE TABLE {schema}.{table};'
transaction = f"""
BEGIN;
{truncate_statement}
{copy_statement}
COMMIT
"""
assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], transaction)

assert mock_run.call_count == 1
assert_equal_ignore_multiple_spaces(self, mock_run.call_args[0][0], copy_query)

0 comments on commit db121f7

Please sign in to comment.