Skip to content

Commit

Permalink
Refactor string splitting (apache#34185)
Browse files Browse the repository at this point in the history
  • Loading branch information
eumiro authored Oct 20, 2023
1 parent 570d88b commit f816237
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 14 deletions.
8 changes: 4 additions & 4 deletions airflow/providers/amazon/aws/hooks/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,12 @@ def parse_s3_url(s3url: str) -> tuple[str, str]:
elif format[0] == "https:":
temp_split = format[1].split(".")
if temp_split[0] == "s3":
split_url = format[1].split("/")
bucket_name = split_url[1]
key = "/".join(split_url[2:])
# "https://s3.region-code.amazonaws.com/bucket-name/key-name"
_, bucket_name, key = format[1].split("/", 2)
elif temp_split[1] == "s3":
# "https://bucket-name.s3.region-code.amazonaws.com/key-name"
bucket_name = temp_split[0]
key = "/".join(format[1].split("/")[1:])
key = format[1].partition("/")[-1]
else:
raise S3HookUriParseFailure(
"Please provide a bucket name using a valid virtually hosted format which should "
Expand Down
2 changes: 1 addition & 1 deletion dev/breeze/src/airflow_breeze/commands/setup_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def get_commands() -> list[str]:
for line in content.splitlines():
strip_line = line.strip()
if strip_line and not strip_line.startswith("#"):
results.append(":".join(strip_line.split(":")[:-1]))
results.append(strip_line.rpartition(":")[0])
return results


Expand Down
10 changes: 9 additions & 1 deletion dev/breeze/src/airflow_breeze/params/doc_build_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,15 @@
from dataclasses import dataclass

from airflow_breeze.branch_defaults import AIRFLOW_BRANCH
from airflow_breeze.utils.general_utils import get_provider_name_from_short_hand

providers_prefix = "apache-airflow-providers-"


def get_provider_name_from_short_hand(short_form_providers: tuple[str]):
return tuple(
providers_prefix + short_form_provider.replace(".", "-")
for short_form_provider in short_form_providers
)


@dataclass
Expand Down
6 changes: 3 additions & 3 deletions scripts/in_container/verify_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def onerror(_):
console.print(f"Skipping module: {modinfo.name}")
continue
if print_imports:
package_to_print = ".".join(modinfo.name.split(".")[:-1])
package_to_print = modinfo.name.rpartition(".")[0]
if package_to_print not in printed_packages:
printed_packages.add(package_to_print)
console.print(f"Importing package: {package_to_print}")
Expand Down Expand Up @@ -247,7 +247,7 @@ def is_imported_from_same_module(the_class: str, imported_name: str) -> bool:
:param imported_name: name of the imported class
:return: true if the class was imported from another module
"""
return ".".join(imported_name.split(".")[:-1]) == the_class.__module__
return imported_name.rpartition(":")[0] == the_class.__module__


def is_example_dag(imported_name: str) -> bool:
Expand Down Expand Up @@ -360,7 +360,7 @@ def convert_class_name_to_url(base_url: str, class_name) -> str:
:param class_name: name of the class
:return: URL to the class
"""
return base_url + os.path.sep.join(class_name.split(".")[:-1]) + ".py"
return base_url + class_name.rpartition(".")[0].replace(".", "/") + ".py"


def get_class_code_link(base_package: str, class_name: str, git_tag: str) -> str:
Expand Down
6 changes: 2 additions & 4 deletions tests/providers/amazon/aws/hooks/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import unittest
from unittest import mock, mock as async_mock
from unittest.mock import MagicMock, Mock, patch
from urllib.parse import parse_qs

import boto3
import pytest
Expand Down Expand Up @@ -1052,10 +1053,7 @@ def test_generate_presigned_url(self, s3_bucket):
presigned_url = hook.generate_presigned_url(
client_method="get_object", params={"Bucket": s3_bucket, "Key": "my_key"}
)

url = presigned_url.split("?")[1]
params = {x[0]: x[1] for x in [x.split("=") for x in url[0:].split("&")]}

params = parse_qs(presigned_url.partition("?")[-1])
assert {"AWSAccessKeyId", "Signature", "Expires"}.issubset(set(params.keys()))

def test_should_throw_error_if_extra_args_is_not_dict(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils/get_all_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def print_all_cases(xunit_test_file_path):
modules = set()

for test_case in test_cases:
the_module = ".".join(test_case.get("classname").split(".")[:-1])
the_module = test_case["classname"].rpartition(".")[0]
the_class = last_replace(test_case.get("classname"), ".", ":", 1)
test_method = test_case.get("name")
modules.add(the_module)
Expand Down

0 comments on commit f816237

Please sign in to comment.