Skip to content

Commit

Permalink
Activate RUF019, which checks for unnecessary key checks (apache#38950)
Browse files Browse the repository at this point in the history
  • Loading branch information
hussein-awala authored Apr 15, 2024
1 parent 6520653 commit f810432
Show file tree
Hide file tree
Showing 8 changed files with 26 additions and 24 deletions.
2 changes: 1 addition & 1 deletion airflow/example_dags/example_params_trigger_ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def select_languages(**kwargs) -> list[str]:
dag_run: DagRun = ti.dag_run
selected_languages = []
for lang in ["english", "german", "french"]:
if lang in dag_run.conf and dag_run.conf[lang]:
if dag_run.conf.get(lang):
selected_languages.append(f"generate_{lang}_greeting")
return selected_languages

Expand Down
10 changes: 3 additions & 7 deletions airflow/models/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,8 +582,7 @@ def __init__(
if start_date and start_date.tzinfo:
tzinfo = None if start_date.tzinfo else settings.TIMEZONE
tz = pendulum.instance(start_date, tz=tzinfo).timezone
elif "start_date" in self.default_args and self.default_args["start_date"]:
date = self.default_args["start_date"]
elif date := self.default_args.get("start_date"):
if not isinstance(date, datetime):
date = timezone.parse(date)
self.default_args["start_date"] = date
Expand All @@ -594,11 +593,8 @@ def __init__(
self.timezone: Timezone | FixedTimezone = tz or settings.TIMEZONE

# Apply the timezone we settled on to end_date if it wasn't supplied
if "end_date" in self.default_args and self.default_args["end_date"]:
if isinstance(self.default_args["end_date"], str):
self.default_args["end_date"] = timezone.parse(
self.default_args["end_date"], timezone=self.timezone
)
if isinstance(_end_date := self.default_args.get("end_date"), str):
self.default_args["end_date"] = timezone.parse(_end_date, timezone=self.timezone)

self.start_date = timezone.convert_to_utc(start_date)
self.end_date = timezone.convert_to_utc(end_date)
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/fab/auth_manager/fab_auth_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,8 +350,8 @@ def get_url_login(self, **kwargs) -> str:
"""Return the login page url."""
if not self.security_manager.auth_view:
raise AirflowException("`auth_view` not defined in the security manager.")
if "next_url" in kwargs and kwargs["next_url"]:
return url_for(f"{self.security_manager.auth_view.endpoint}.login", next=kwargs["next_url"])
if next_url := kwargs.get("next_url"):
return url_for(f"{self.security_manager.auth_view.endpoint}.login", next=next_url)
else:
return url_for(f"{self.security_manager.auth_view.endpoint}.login")

Expand Down
6 changes: 2 additions & 4 deletions airflow/providers/google/cloud/hooks/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2849,11 +2849,10 @@ def next(self) -> list | None:
return None

query_results = self._get_query_result()
if "rows" in query_results and query_results["rows"]:
if rows := query_results.get("rows"):
self.page_token = query_results.get("pageToken")
fields = query_results["schema"]["fields"]
col_types = [field["type"] for field in fields]
rows = query_results["rows"]

for dict_row in rows:
typed_row = [bq_cast(vs["v"], col_types[idx]) for idx, vs in enumerate(dict_row["f"])]
Expand Down Expand Up @@ -3396,8 +3395,7 @@ def get_records(self, query_results: dict[str, Any], as_dict: bool = False) -> l
:param as_dict: if True returns the result as a list of dictionaries, otherwise as list of lists.
"""
buffer: list[Any] = []
if "rows" in query_results and query_results["rows"]:
rows = query_results["rows"]
if rows := query_results.get("rows"):
fields = query_results["schema"]["fields"]
fields_names = [field["name"] for field in fields]
col_types = [field["type"] for field in fields]
Expand Down
11 changes: 6 additions & 5 deletions airflow/providers/snowflake/hooks/snowflake_sql_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,11 +283,12 @@ def _process_response(self, status_code, resp):
elif status_code == 422:
return {"status": "error", "message": resp["message"]}
elif status_code == 200:
statement_handles = []
if "statementHandles" in resp and resp["statementHandles"]:
statement_handles = resp["statementHandles"]
elif "statementHandle" in resp and resp["statementHandle"]:
statement_handles.append(resp["statementHandle"])
if resp_statement_handles := resp.get("statementHandles"):
statement_handles = resp_statement_handles
elif resp_statement_handle := resp.get("statementHandle"):
statement_handles = [resp_statement_handle]
else:
statement_handles = []
return {
"status": "success",
"message": resp["message"],
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ extend-select = [
"PT", # flake8-pytest-style rules
"TID25", # flake8-tidy-imports rules
# Per rule enables
"RUF006", # Checks for asyncio dangling task
"RUF015", # Checks for unnecessary iterable allocation for first element
"RUF019", # Checks for unnecessary key check
"RUF100", # Unused noqa (auto-fixable)
# We ignore more pydocstyle than we enable, so be more selective at what we enable
"D101",
Expand All @@ -292,8 +295,6 @@ extend-select = [
"B019", # Use of functools.lru_cache or functools.cache on methods can lead to memory leaks
"B028", # No explicit stacklevel keyword argument found
"TRY002", # Prohibit use of `raise Exception`, use specific exceptions instead.
"RUF006", # Checks for asyncio dangling task
"RUF015", # Checks for unnecessary iterable allocation for first element
]
ignore = [
"D203",
Expand Down
8 changes: 6 additions & 2 deletions tests/providers/google/cloud/hooks/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -1794,7 +1794,9 @@ class TestBigQueryHookLegacySql(_BigQueryBaseTestClass):

@mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
def test_hook_uses_legacy_sql_by_default(self, mock_insert, _):
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryCursor._get_query_result")
def test_hook_uses_legacy_sql_by_default(self, mock_get_query_result, mock_insert, _):
mock_get_query_result.return_value = {}
self.hook.get_first("query")
_, kwargs = mock_insert.call_args
assert kwargs["configuration"]["query"]["useLegacySql"] is True
Expand All @@ -1805,9 +1807,11 @@ def test_hook_uses_legacy_sql_by_default(self, mock_insert, _):
)
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.build")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.insert_job")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryCursor._get_query_result")
def test_legacy_sql_override_propagates_properly(
self, mock_insert, mock_build, mock_get_creds_and_proj_id
self, mock_get_query_result, mock_insert, mock_build, mock_get_creds_and_proj_id
):
mock_get_query_result.return_value = {}
bq_hook = BigQueryHook(use_legacy_sql=False)
bq_hook.get_first("query")
_, kwargs = mock_insert.call_args
Expand Down
4 changes: 3 additions & 1 deletion tests/providers/google/cloud/triggers/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,12 +499,14 @@ def test_interval_check_trigger_serialization(self, interval_check_trigger):
@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_status")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_job_output")
@mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryAsyncHook.get_records")
async def test_interval_check_trigger_success(
self, mock_get_job_output, mock_job_status, interval_check_trigger
self, mock_get_records, mock_get_job_output, mock_job_status, interval_check_trigger
):
"""
Tests the BigQueryIntervalCheckTrigger only fires once the query execution reaches a successful state.
"""
mock_get_records.return_value = {}
mock_job_status.return_value = {"status": "success", "message": "Job completed"}
mock_get_job_output.return_value = ["0"]

Expand Down

0 comments on commit f810432

Please sign in to comment.