Apply "unify bucket and key" before "provide bucket" (apache#28710)
If we unify first, then we (desirably) get the bucket from the key. If we provide the bucket first, something undesirable happens: the bucket from the connection is used, which means we don't respect the full key provided. Further, the full key is not made relative, which would cause the actual request to fail. For this reason we want to apply unify first. And because in every case where the outcome differs the request would previously have failed anyway, this should not pose any backcompat concerns.
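
A minimal, self-contained sketch of why the decorator order matters. The decorators and names here (CONN_BUCKET, head_object_new, head_object_old) are toy stand-ins for illustration, not the real hook implementation: with the new ordering the bucket is taken from the full s3:// key and the key is made relative, while the old ordering silently substitutes the connection's bucket and leaves the key absolute.

# Toy reproduction of the ordering problem; hypothetical names, not the Airflow code.
from functools import wraps
from urllib.parse import urlsplit

CONN_BUCKET = "conn-bucket"  # stand-in for the bucket configured on the connection


def provide_bucket_name(func):
    @wraps(func)
    def wrapper(key, bucket_name=None):
        # Fall back to the connection's bucket when none is passed explicitly.
        return func(key, bucket_name if bucket_name is not None else CONN_BUCKET)

    return wrapper


def unify_bucket_name_and_key(func):
    @wraps(func)
    def wrapper(key, bucket_name=None):
        # Split a full s3:// URL into (bucket, relative key) when no bucket was given.
        if bucket_name is None and key.startswith("s3://"):
            parsed = urlsplit(key)
            bucket_name, key = parsed.netloc, parsed.path.lstrip("/")
        return func(key, bucket_name)

    return wrapper


@unify_bucket_name_and_key  # new ordering: unify runs first, bucket comes from the key
@provide_bucket_name
def head_object_new(key, bucket_name=None):
    return bucket_name, key


@provide_bucket_name  # old ordering: connection bucket wins, key stays absolute
@unify_bucket_name_and_key
def head_object_old(key, bucket_name=None):
    return bucket_name, key


print(head_object_new("s3://data-bucket/path/file.csv"))  # ('data-bucket', 'path/file.csv')
print(head_object_old("s3://data-bucket/path/file.csv"))  # ('conn-bucket', 's3://data-bucket/path/file.csv')

With the old ordering the request would be made against the connection's bucket with an absolute s3:// key, which is exactly the failure mode described above.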

Co-authored-by: Felix Uellendall <[email protected]>
dstandish and feluelle authored Jan 6, 2023
1 parent bda3918 commit 3eee33a
Showing 4 changed files with 116 additions and 103 deletions.
6 changes: 6 additions & 0 deletions airflow/providers/amazon/aws/exceptions.py
@@ -17,6 +17,8 @@
# under the License.
from __future__ import annotations

from airflow import AirflowException

# Note: Any AirflowException raised is expected to cause the TaskInstance
# to be marked in an ERROR state

@@ -42,3 +44,7 @@ def __init__(self, failures: list, message: str):

def __reduce__(self):
return EcsOperatorError, (self.failures, self.message)


class S3HookUriParseFailure(AirflowException):
"""When parse_s3_url fails to parse URL, this error is thrown."""
47 changes: 30 additions & 17 deletions airflow/providers/amazon/aws/hooks/s3.py
@@ -21,8 +21,10 @@
import fnmatch
import gzip as gz
import io
import logging
import re
import shutil
from contextlib import suppress
from copy import deepcopy
from datetime import datetime
from functools import wraps
@@ -38,17 +40,22 @@
from botocore.exceptions import ClientError

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.exceptions import S3HookUriParseFailure
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.utils.helpers import chunks

T = TypeVar("T", bound=Callable)

logger = logging.getLogger(__name__)


def provide_bucket_name(func: T) -> T:
"""
Function decorator that provides a bucket name taken from the connection
in case no bucket name has been passed to the function.
"""
if hasattr(func, "_unify_bucket_name_and_key_wrapped"):
logger.warning("`unify_bucket_name_and_key` should wrap `provide_bucket_name`.")
function_signature = signature(func)

@wraps(func)
@@ -84,12 +91,18 @@ def wrapper(*args, **kwargs) -> T:
raise ValueError("Missing key parameter!")

if "bucket_name" not in bound_args.arguments:
bound_args.arguments["bucket_name"], bound_args.arguments[key_name] = S3Hook.parse_s3_url(
bound_args.arguments[key_name]
)
with suppress(S3HookUriParseFailure):
bound_args.arguments["bucket_name"], bound_args.arguments[key_name] = S3Hook.parse_s3_url(
bound_args.arguments[key_name]
)

return func(*bound_args.args, **bound_args.kwargs)

# set attr _unify_bucket_name_and_key_wrapped so that we can check at
# class definition that unify is the first decorator applied
# if provide_bucket_name is applied first, and there's a bucket defined in conn
# then if user supplies full key, bucket in key is not respected
wrapper._unify_bucket_name_and_key_wrapped = True # type: ignore[attr-defined]
return cast(T, wrapper)


@@ -153,7 +166,7 @@ def parse_s3_url(s3url: str) -> tuple[str, str]:
if re.match(r"s3[na]?:", format[0], re.IGNORECASE):
parsed_url = urlsplit(s3url)
if not parsed_url.netloc:
raise AirflowException(f'Please provide a bucket name using a valid format: "{s3url}"')
raise S3HookUriParseFailure(f'Please provide a bucket name using a valid format: "{s3url}"')

bucket_name = parsed_url.netloc
key = parsed_url.path.lstrip("/")
@@ -167,7 +180,7 @@ def parse_s3_url(s3url: str) -> tuple[str, str]:
bucket_name = temp_split[0]
key = "/".join(format[1].split("/")[1:])
else:
raise AirflowException(f'Please provide a bucket name using a valid format: "{s3url}"')
raise S3HookUriParseFailure(f'Please provide a bucket name using a valid format: "{s3url}"')
return bucket_name, key

@staticmethod
@@ -437,8 +450,8 @@ def get_file_metadata(
files += page["Contents"]
return files

@provide_bucket_name
@unify_bucket_name_and_key
@provide_bucket_name
def head_object(self, key: str, bucket_name: str | None = None) -> dict | None:
"""
Retrieves metadata of an object
@@ -455,8 +468,8 @@ def head_object(self, key: str, bucket_name: str | None = None) -> dict | None:
else:
raise e

@provide_bucket_name
@unify_bucket_name_and_key
@provide_bucket_name
def check_for_key(self, key: str, bucket_name: str | None = None) -> bool:
"""
Checks if a key exists in a bucket
@@ -468,8 +481,8 @@ def check_for_key(self, key: str, bucket_name: str | None = None) -> bool:
obj = self.head_object(key, bucket_name)
return obj is not None

@provide_bucket_name
@unify_bucket_name_and_key
@provide_bucket_name
def get_key(self, key: str, bucket_name: str | None = None) -> S3Transfer:
"""
Returns a boto3.s3.Object
@@ -488,8 +501,8 @@ def get_key(self, key: str, bucket_name: str | None = None) -> S3Transfer:
obj.load()
return obj

@provide_bucket_name
@unify_bucket_name_and_key
@provide_bucket_name
def read_key(self, key: str, bucket_name: str | None = None) -> str:
"""
Reads a key from S3
@@ -501,8 +514,8 @@ def read_key(self, key: str, bucket_name: str | None = None) -> str:
obj = self.get_key(key, bucket_name)
return obj.get()["Body"].read().decode("utf-8")

@provide_bucket_name
@unify_bucket_name_and_key
@provide_bucket_name
def select_key(
self,
key: str,
@@ -548,8 +561,8 @@ def select_key(
event["Records"]["Payload"] for event in response["Payload"] if "Records" in event
).decode("utf-8")

@provide_bucket_name
@unify_bucket_name_and_key
@provide_bucket_name
def check_for_wildcard_key(
self, wildcard_key: str, bucket_name: str | None = None, delimiter: str = ""
) -> bool:
@@ -566,8 +579,8 @@ def check_for_wildcard_key(
is not None
)

@provide_bucket_name
@unify_bucket_name_and_key
@provide_bucket_name
def get_wildcard_key(
self, wildcard_key: str, bucket_name: str | None = None, delimiter: str = ""
) -> S3Transfer:
@@ -586,8 +599,8 @@ def get_wildcard_key(
return self.get_key(key_matches[0], bucket_name)
return None

@provide_bucket_name
@unify_bucket_name_and_key
@provide_bucket_name
def load_file(
self,
filename: Path | str,
@@ -632,8 +645,8 @@ def load_file(
client = self.get_conn()
client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args, Config=self.transfer_config)

@provide_bucket_name
@unify_bucket_name_and_key
@provide_bucket_name
def load_string(
self,
string_data: str,
@@ -682,8 +695,8 @@ def load_string(
self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt, acl_policy)
file_obj.close()

@provide_bucket_name
@unify_bucket_name_and_key
@provide_bucket_name
def load_bytes(
self,
bytes_data: bytes,
@@ -713,8 +726,8 @@ def load_bytes(
self._upload_file_obj(file_obj, key, bucket_name, replace, encrypt, acl_policy)
file_obj.close()

@provide_bucket_name
@unify_bucket_name_and_key
@provide_bucket_name
def load_file_obj(
self,
file_obj: BytesIO,
@@ -860,8 +873,8 @@ def delete_objects(self, bucket: str, keys: str | list) -> None:
errors_keys = [x["Key"] for x in response.get("Errors", [])]
raise AirflowException(f"Errors when deleting: {errors_keys}")

@provide_bucket_name
@unify_bucket_name_and_key
@provide_bucket_name
def download_file(
self,
key: str,