Skip to content

Commit

Permalink
Migrate remaining API tests to pytest (apache#28311)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis authored Dec 13, 2022
1 parent 4dd6b2c commit 02c2b2c
Show file tree
Hide file tree
Showing 14 changed files with 239 additions and 248 deletions.
48 changes: 6 additions & 42 deletions tests/api/common/test_mark_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,9 +616,10 @@ def test_set_running_dag_run_to_activate_state(self, dag_run_alter_function: Cal
self._verify_task_instance_states_remain_default(dr)
self._verify_dag_run_dates(self.dag1, date, new_state, middle_time) # type: ignore

def test_set_success_dag_run_to_success(self):
@pytest.mark.parametrize("completed_state", [State.SUCCESS, State.FAILED])
def test_set_success_dag_run_to_success(self, completed_state):
date = self.execution_dates[0]
dr = self._create_test_dag_run(State.SUCCESS, date)
dr = self._create_test_dag_run(completed_state, date)
middle_time = timezone.utcnow()
self._set_default_task_instance_states(dr)

Expand All @@ -631,9 +632,10 @@ def test_set_success_dag_run_to_success(self):
self._verify_task_instance_states(self.dag1, date, State.SUCCESS)
self._verify_dag_run_dates(self.dag1, date, State.SUCCESS, middle_time)

def test_set_success_dag_run_to_failed(self):
@pytest.mark.parametrize("completed_state", [State.SUCCESS, State.FAILED])
def test_set_completed_dag_run_to_failed(self, completed_state):
date = self.execution_dates[0]
dr = self._create_test_dag_run(State.SUCCESS, date)
dr = self._create_test_dag_run(completed_state, date)
middle_time = timezone.utcnow()
self._set_default_task_instance_states(dr)

Expand Down Expand Up @@ -663,36 +665,6 @@ def test_set_success_dag_run_to_activate_state(self, dag_run_alter_function: Cal
self._verify_task_instance_states_remain_default(dr)
self._verify_dag_run_dates(self.dag1, date, new_state, middle_time) # type: ignore

def test_set_failed_dag_run_to_success(self):
date = self.execution_dates[0]
dr = self._create_test_dag_run(State.SUCCESS, date)
middle_time = timezone.utcnow()
self._set_default_task_instance_states(dr)

altered = set_dag_run_state_to_success(dag=self.dag1, run_id=dr.run_id, commit=True)

# All except the SUCCESS task should be altered.
expected = self._get_num_tasks_with_starting_state(State.SUCCESS, inclusion=False)
assert len(altered) == expected
self._verify_dag_run_state(self.dag1, date, State.SUCCESS)
self._verify_task_instance_states(self.dag1, date, State.SUCCESS)
self._verify_dag_run_dates(self.dag1, date, State.SUCCESS, middle_time)

def test_set_failed_dag_run_to_failed(self):
date = self.execution_dates[0]
dr = self._create_test_dag_run(State.SUCCESS, date)
middle_time = timezone.utcnow()
self._set_default_task_instance_states(dr)

altered = set_dag_run_state_to_failed(dag=self.dag1, run_id=dr.run_id, commit=True)

# Only non-completed tasks should be altered.
expected = self._get_num_tasks_with_non_completed_state()
assert len(altered) == expected
self._verify_dag_run_state(self.dag1, date, State.FAILED)
assert dr.get_task_instance("run_after_loop").state == State.FAILED
self._verify_dag_run_dates(self.dag1, date, State.FAILED, middle_time)

@pytest.mark.parametrize(
"dag_run_alter_function,state",
[(set_dag_run_state_to_running, State.RUNNING), (set_dag_run_state_to_queued, State.QUEUED)],
Expand Down Expand Up @@ -828,11 +800,3 @@ def test_set_dag_run_state_to_failed_no_running_tasks(self):
dr.get_task_instance(task.task_id).set_state(State.SUCCESS)

set_dag_run_state_to_failed(dag=self.dag1, run_id=dr.run_id)

def tearDown(self):
self.dag1.clear()
self.dag2.clear()

with create_session() as session:
session.query(models.DagRun).delete()
session.query(models.TaskInstance).delete()
25 changes: 14 additions & 11 deletions tests/api_connexion/endpoints/test_connection_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from __future__ import annotations

import pytest
from parameterized import parameterized

from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
from airflow.models import Connection
Expand Down Expand Up @@ -252,7 +251,8 @@ def test_should_raises_401_unauthenticated(self):


class TestGetConnectionsPagination(TestConnectionEndpoint):
@parameterized.expand(
@pytest.mark.parametrize(
"url, expected_conn_ids",
[
("/api/v1/connections?limit=1", ["TEST_CONN_ID1"]),
("/api/v1/connections?limit=2", ["TEST_CONN_ID1", "TEST_CONN_ID2"]),
Expand Down Expand Up @@ -287,7 +287,7 @@ class TestGetConnectionsPagination(TestConnectionEndpoint):
"/api/v1/connections?limit=2&offset=2",
["TEST_CONN_ID3", "TEST_CONN_ID4"],
),
]
],
)
@provide_session
def test_handle_limit_offset(self, url, expected_conn_ids, session):
Expand Down Expand Up @@ -354,11 +354,12 @@ def _create_connections(self, count):


class TestPatchConnection(TestConnectionEndpoint):
@parameterized.expand(
@pytest.mark.parametrize(
"payload",
[
({"connection_id": "test-connection-id", "conn_type": "test_type", "extra": "{'key': 'var'}"},),
({"extra": "{'key': 'var'}"},),
]
{"connection_id": "test-connection-id", "conn_type": "test_type", "extra": "{'key': 'var'}"},
{"extra": "{'key': 'var'}"},
],
)
@provide_session
def test_patch_should_respond_200(self, payload, session):
Expand Down Expand Up @@ -399,7 +400,8 @@ def test_patch_should_respond_200_with_update_mask(self, session):
"host": None,
}

@parameterized.expand(
@pytest.mark.parametrize(
"payload, update_mask, error_message",
[
(
{
Expand Down Expand Up @@ -443,7 +445,7 @@ def test_patch_should_respond_200_with_update_mask(self, session):
"", # not necessary
"The connection_id cannot be updated.",
),
]
],
)
@provide_session
def test_patch_should_respond_400_for_invalid_fields_in_update_mask(
Expand All @@ -458,7 +460,8 @@ def test_patch_should_respond_400_for_invalid_fields_in_update_mask(
assert response.status_code == 400
assert response.json["detail"] == error_message

@parameterized.expand(
@pytest.mark.parametrize(
"payload, error_message",
[
(
{
Expand All @@ -485,7 +488,7 @@ def test_patch_should_respond_400_for_invalid_fields_in_update_mask(
},
"_password",
),
]
],
)
@provide_session
def test_patch_should_respond_400_for_invalid_update(self, payload, error_message, session):
Expand Down
36 changes: 21 additions & 15 deletions tests/api_connexion/endpoints/test_dag_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from datetime import datetime

import pytest
from parameterized import parameterized

from airflow import DAG
from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
Expand Down Expand Up @@ -698,13 +697,14 @@ def test_only_active_false_returns_all_dags(self, url_safe_serializer):
"total_entries": 2,
} == response.json

@parameterized.expand(
@pytest.mark.parametrize(
"url, expected_dag_ids",
[
("api/v1/dags?tags=t1", ["TEST_DAG_1", "TEST_DAG_3"]),
("api/v1/dags?tags=t2", ["TEST_DAG_2", "TEST_DAG_3"]),
("api/v1/dags?tags=t1,t2", ["TEST_DAG_1", "TEST_DAG_2", "TEST_DAG_3"]),
("api/v1/dags", ["TEST_DAG_1", "TEST_DAG_2", "TEST_DAG_3", "TEST_DAG_4"]),
]
],
)
def test_filter_dags_by_tags_works(self, url, expected_dag_ids):
# test filter by tags
Expand All @@ -723,15 +723,16 @@ def test_filter_dags_by_tags_works(self, url, expected_dag_ids):

assert expected_dag_ids == dag_ids

@parameterized.expand(
@pytest.mark.parametrize(
"url, expected_dag_ids",
[
("api/v1/dags?dag_id_pattern=DAG_1", {"TEST_DAG_1", "SAMPLE_DAG_1"}),
("api/v1/dags?dag_id_pattern=SAMPLE_DAG", {"SAMPLE_DAG_1", "SAMPLE_DAG_2"}),
(
"api/v1/dags?dag_id_pattern=_DAG_",
{"TEST_DAG_1", "TEST_DAG_2", "SAMPLE_DAG_1", "SAMPLE_DAG_2"},
),
]
],
)
def test_filter_dags_by_dag_id_works(self, url, expected_dag_ids):
# test filter by tags
Expand Down Expand Up @@ -759,7 +760,8 @@ def test_should_respond_200_with_granular_dag_access(self):
assert len(response.json["dags"]) == 1
assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1"

@parameterized.expand(
@pytest.mark.parametrize(
"url, expected_dag_ids",
[
("api/v1/dags?limit=1", ["TEST_DAG_1"]),
("api/v1/dags?limit=2", ["TEST_DAG_1", "TEST_DAG_10"]),
Expand All @@ -785,7 +787,7 @@ def test_should_respond_200_with_granular_dag_access(self):
("api/v1/dags?limit=1&offset=5", ["TEST_DAG_5"]),
("api/v1/dags?limit=1&offset=1", ["TEST_DAG_10"]),
("api/v1/dags?limit=2&offset=2", ["TEST_DAG_2", "TEST_DAG_3"]),
]
],
)
def test_should_respond_200_and_handle_pagination(self, url, expected_dag_ids):
self._create_dag_models(10)
Expand Down Expand Up @@ -965,7 +967,8 @@ def test_should_respond_200_with_update_mask(self, url_safe_serializer):
}
assert response.json == expected_response

@parameterized.expand(
@pytest.mark.parametrize(
"payload, update_mask, error_message",
[
(
{
Expand All @@ -981,7 +984,7 @@ def test_should_respond_200_with_update_mask(self, url_safe_serializer):
"update_mask=schedule_interval, description",
"Only `is_paused` field can be updated through the REST API",
),
]
],
)
def test_should_respond_400_for_invalid_fields_in_update_mask(self, payload, update_mask, error_message):
dag_model = self._create_dag_model()
Expand Down Expand Up @@ -1241,13 +1244,14 @@ def test_only_active_false_returns_all_dags(self, url_safe_serializer):
"total_entries": 2,
} == response.json

@parameterized.expand(
@pytest.mark.parametrize(
"url, expected_dag_ids",
[
("api/v1/dags?tags=t1&dag_id_pattern=~", ["TEST_DAG_1", "TEST_DAG_3"]),
("api/v1/dags?tags=t2&dag_id_pattern=~", ["TEST_DAG_2", "TEST_DAG_3"]),
("api/v1/dags?tags=t1,t2&dag_id_pattern=~", ["TEST_DAG_1", "TEST_DAG_2", "TEST_DAG_3"]),
("api/v1/dags?dag_id_pattern=~", ["TEST_DAG_1", "TEST_DAG_2", "TEST_DAG_3", "TEST_DAG_4"]),
]
],
)
def test_filter_dags_by_tags_works(self, url, expected_dag_ids):
# test filter by tags
Expand All @@ -1271,15 +1275,16 @@ def test_filter_dags_by_tags_works(self, url, expected_dag_ids):

assert expected_dag_ids == dag_ids

@parameterized.expand(
@pytest.mark.parametrize(
"url, expected_dag_ids",
[
("api/v1/dags?dag_id_pattern=DAG_1", {"TEST_DAG_1", "SAMPLE_DAG_1"}),
("api/v1/dags?dag_id_pattern=SAMPLE_DAG", {"SAMPLE_DAG_1", "SAMPLE_DAG_2"}),
(
"api/v1/dags?dag_id_pattern=_DAG_",
{"TEST_DAG_1", "TEST_DAG_2", "SAMPLE_DAG_1", "SAMPLE_DAG_2"},
),
]
],
)
def test_filter_dags_by_dag_id_works(self, url, expected_dag_ids):
# test filter by tags
Expand Down Expand Up @@ -1317,7 +1322,8 @@ def test_should_respond_200_with_granular_dag_access(self):
assert len(response.json["dags"]) == 1
assert response.json["dags"][0]["dag_id"] == "TEST_DAG_1"

@parameterized.expand(
@pytest.mark.parametrize(
"url, expected_dag_ids",
[
("api/v1/dags?limit=1&dag_id_pattern=~", ["TEST_DAG_1"]),
("api/v1/dags?limit=2&dag_id_pattern=~", ["TEST_DAG_1", "TEST_DAG_10"]),
Expand All @@ -1343,7 +1349,7 @@ def test_should_respond_200_with_granular_dag_access(self):
("api/v1/dags?limit=1&offset=5&dag_id_pattern=~", ["TEST_DAG_5"]),
("api/v1/dags?limit=1&offset=1&dag_id_pattern=~", ["TEST_DAG_10"]),
("api/v1/dags?limit=2&offset=2&dag_id_pattern=~", ["TEST_DAG_2", "TEST_DAG_3"]),
]
],
)
def test_should_respond_200_and_handle_pagination(self, url, expected_dag_ids):
self._create_dag_models(10)
Expand Down
Loading

0 comments on commit 02c2b2c

Please sign in to comment.