Add aws_conn_id to DynamoDBToS3Operator (apache#20363)
pingzh authored Dec 18, 2021
1 parent e9dfe0b commit 5769def
Showing 2 changed files with 53 additions and 5 deletions.
22 changes: 17 additions & 5 deletions airflow/providers/amazon/aws/transfers/dynamodb_to_s3.py
@@ -37,8 +37,10 @@ def _convert_item_to_json_bytes(item: Dict[str, Any]) -> bytes:
     return (json.dumps(item) + '\n').encode('utf-8')
 
 
-def _upload_file_to_s3(file_obj: IO, bucket_name: str, s3_key_prefix: str) -> None:
-    s3_client = S3Hook().get_conn()
+def _upload_file_to_s3(
+    file_obj: IO, bucket_name: str, s3_key_prefix: str, aws_conn_id: str = 'aws_default'
+) -> None:
+    s3_client = S3Hook(aws_conn_id=aws_conn_id).get_conn()
     file_obj.seek(0)
     s3_client.upload_file(
         Filename=file_obj.name,
@@ -94,6 +96,12 @@ class DynamoDBToS3Operator(BaseOperator):
     :type s3_key_prefix: Optional[str]
     :param process_func: How we transforms a dynamodb item to bytes. By default we dump the json
     :type process_func: Callable[[Dict[str, Any]], bytes]
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is None or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :type aws_conn_id: str
     """
 
     def __init__(
@@ -105,6 +113,7 @@ def __init__(
         dynamodb_scan_kwargs: Optional[Dict[str, Any]] = None,
         s3_key_prefix: str = '',
         process_func: Callable[[Dict[str, Any]], bytes] = _convert_item_to_json_bytes,
+        aws_conn_id: str = 'aws_default',
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -114,9 +123,12 @@ def __init__(
         self.dynamodb_scan_kwargs = dynamodb_scan_kwargs
         self.s3_bucket_name = s3_bucket_name
         self.s3_key_prefix = s3_key_prefix
+        self.aws_conn_id = aws_conn_id
 
     def execute(self, context) -> None:
-        table = AwsDynamoDBHook().get_conn().Table(self.dynamodb_table_name)
+        hook = AwsDynamoDBHook(aws_conn_id=self.aws_conn_id)
+        table = hook.get_conn().Table(self.dynamodb_table_name)
+
         scan_kwargs = copy(self.dynamodb_scan_kwargs) if self.dynamodb_scan_kwargs else {}
         err = None
         with NamedTemporaryFile() as f:
@@ -127,7 +139,7 @@ def execute(self, context) -> None:
                 raise e
             finally:
                 if err is None:
-                    _upload_file_to_s3(f, self.s3_bucket_name, self.s3_key_prefix)
+                    _upload_file_to_s3(f, self.s3_bucket_name, self.s3_key_prefix, self.aws_conn_id)
 
     def _scan_dynamodb_and_upload_to_s3(self, temp_file: IO, scan_kwargs: dict, table: Any) -> IO:
         while True:
@@ -145,7 +157,7 @@ def _scan_dynamodb_and_upload_to_s3(self, temp_file: IO, scan_kwargs: dict, table: Any) -> IO:
 
             # Upload the file to S3 if reach file size limit
             if getsize(temp_file.name) >= self.file_size:
-                _upload_file_to_s3(temp_file, self.s3_bucket_name, self.s3_key_prefix)
+                _upload_file_to_s3(temp_file, self.s3_bucket_name, self.s3_key_prefix, self.aws_conn_id)
                 temp_file.close()
 
                 temp_file = NamedTemporaryFile()
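For context, a minimal usage sketch of the operator after this change (not part of the commit itself): the DAG id, table name, bucket name, and connection id 'my_aws_conn' below are placeholders, and file_size mirrors the value used in the tests.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.transfers.dynamodb_to_s3 import DynamoDBToS3Operator

with DAG(
    dag_id='dynamodb_to_s3_example',        # placeholder DAG id
    start_date=datetime(2021, 12, 1),
    schedule_interval=None,
    catchup=False,
) as dag:
    backup_table = DynamoDBToS3Operator(
        task_id='backup_dynamodb_table',
        dynamodb_table_name='my_table',      # placeholder table name
        s3_bucket_name='my-backup-bucket',   # placeholder bucket name
        s3_key_prefix='backups/my_table/',
        file_size=4000,                      # flush to S3 once the temp file reaches ~4 KB, as in the tests
        aws_conn_id='my_aws_conn',           # new parameter; defaults to 'aws_default'
    )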
36 changes: 36 additions & 0 deletions tests/providers/amazon/aws/transfers/test_dynamodb_to_s3.py
@@ -63,3 +63,39 @@ def test_dynamodb_to_s3_success(self, mock_aws_dynamodb_hook, mock_s3_hook):
         dynamodb_to_s3_operator.execute(context={})
 
         assert [{'a': 1}, {'b': 2}, {'c': 3}] == self.output_queue
+
+    @patch('airflow.providers.amazon.aws.transfers.dynamodb_to_s3.S3Hook')
+    @patch('airflow.providers.amazon.aws.transfers.dynamodb_to_s3.AwsDynamoDBHook')
+    def test_dynamodb_to_s3_with_different_aws_conn_id(self, mock_aws_dynamodb_hook, mock_s3_hook):
+        responses = [
+            {
+                'Items': [{'a': 1}, {'b': 2}],
+                'LastEvaluatedKey': '123',
+            },
+            {
+                'Items': [{'c': 3}],
+            },
+        ]
+        table = MagicMock()
+        table.return_value.scan.side_effect = responses
+        mock_aws_dynamodb_hook.return_value.get_conn.return_value.Table = table
+
+        s3_client = MagicMock()
+        s3_client.return_value.upload_file = self.mock_upload_file
+        mock_s3_hook.return_value.get_conn = s3_client
+
+        aws_conn_id = "test-conn-id"
+        dynamodb_to_s3_operator = DynamoDBToS3Operator(
+            task_id='dynamodb_to_s3',
+            dynamodb_table_name='airflow_rocks',
+            s3_bucket_name='airflow-bucket',
+            file_size=4000,
+            aws_conn_id=aws_conn_id,
+        )
+
+        dynamodb_to_s3_operator.execute(context={})
+
+        assert [{'a': 1}, {'b': 2}, {'c': 3}] == self.output_queue
+
+        mock_s3_hook.assert_called_with(aws_conn_id=aws_conn_id)
+        mock_aws_dynamodb_hook.assert_called_with(aws_conn_id=aws_conn_id)
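As the new docstring notes, passing None (or an empty value) for aws_conn_id makes the hooks fall back to boto3's default credential resolution. A small illustrative sketch of the two modes, with 'my_aws' as a hypothetical connection id:

from airflow.providers.amazon.aws.hooks.s3 import S3Hook

# Credentials come from the Airflow connection 'my_aws' (hypothetical id).
s3_from_connection = S3Hook(aws_conn_id='my_aws').get_conn()

# With aws_conn_id=None the hook falls back to boto3's default credential chain
# (environment variables, instance profile, ...), which must then be available
# on every worker node when Airflow runs distributed.
s3_from_boto3_defaults = S3Hook(aws_conn_id=None).get_conn()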
