diff --git a/airflow/api_connexion/endpoints/connection_endpoint.py b/airflow/api_connexion/endpoints/connection_endpoint.py index 16d9afb5b9e33..9d3224dd8f57c 100644 --- a/airflow/api_connexion/endpoints/connection_endpoint.py +++ b/airflow/api_connexion/endpoints/connection_endpoint.py @@ -205,5 +205,4 @@ def test_connection() -> APIResponse: except ValidationError as err: raise BadRequest(detail=str(err.messages)) finally: - if conn_env_var in os.environ: - del os.environ[conn_env_var] + os.environ.pop(conn_env_var, None) diff --git a/airflow/api_connexion/endpoints/pool_endpoint.py b/airflow/api_connexion/endpoints/pool_endpoint.py index 0fbb2c8a23f4d..23fa23922f135 100644 --- a/airflow/api_connexion/endpoints/pool_endpoint.py +++ b/airflow/api_connexion/endpoints/pool_endpoint.py @@ -92,14 +92,11 @@ def patch_pool( """Update a pool.""" request_dict = get_json_request_dict() # Only slots and include_deferred can be modified in 'default_pool' - try: - if pool_name == Pool.DEFAULT_POOL_NAME and request_dict["name"] != Pool.DEFAULT_POOL_NAME: - if update_mask and all(mask.strip() in {"slots", "include_deferred"} for mask in update_mask): - pass - else: - raise BadRequest(detail="Default Pool's name can't be modified") - except KeyError: - pass + if pool_name == Pool.DEFAULT_POOL_NAME and request_dict.get("name", None) != Pool.DEFAULT_POOL_NAME: + if update_mask and all(mask.strip() in {"slots", "include_deferred"} for mask in update_mask): + pass + else: + raise BadRequest(detail="Default Pool's name can't be modified") pool = session.scalar(select(Pool).where(Pool.pool == pool_name).limit(1)) if not pool: diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index 0b942134acc80..946f8fc0dcb49 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -43,6 +43,7 @@ ) from airflow.api_connexion.security import get_readable_dags from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails +from airflow.exceptions import TaskNotFound from airflow.models import SlaMiss from airflow.models.dagrun import DagRun as DR from airflow.models.operator import needs_expansion @@ -192,8 +193,9 @@ def get_mapped_task_instances( if not dag: error_message = f"DAG {dag_id} not found" raise NotFound(error_message) - task = dag.get_task(task_id) - if not task: + try: + task = dag.get_task(task_id) + except TaskNotFound: error_message = f"Task id {task_id} not found" raise NotFound(error_message) if not needs_expansion(task): diff --git a/scripts/cov/restapi_coverage.py b/scripts/cov/restapi_coverage.py index 8d94894da9aec..9a3dc2a143c0d 100644 --- a/scripts/cov/restapi_coverage.py +++ b/scripts/cov/restapi_coverage.py @@ -28,10 +28,7 @@ restapi_files = ["tests/api_experimental", "tests/api_connexion", "tests/api_internal"] files_not_fully_covered = [ - "airflow/api_connexion/endpoints/dag_run_endpoint.py", "airflow/api_connexion/endpoints/forward_to_fab_endpoint.py", - "airflow/api_connexion/endpoints/pool_endpoint.py", - "airflow/api_connexion/endpoints/task_endpoint.py", "airflow/api_connexion/endpoints/task_instance_endpoint.py", "airflow/api_connexion/endpoints/xcom_endpoint.py", "airflow/api_connexion/exceptions.py", diff --git a/tests/api_connexion/endpoints/test_connection_endpoint.py b/tests/api_connexion/endpoints/test_connection_endpoint.py index 1dd3752873582..043dc0b5a17c2 100644 --- a/tests/api_connexion/endpoints/test_connection_endpoint.py +++ b/tests/api_connexion/endpoints/test_connection_endpoint.py @@ -23,6 +23,7 @@ from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP from airflow.models import Connection +from airflow.secrets.environment_variables import CONN_ENV_PREFIX from airflow.security import permissions from airflow.utils.session import provide_session from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user @@ -620,6 +621,12 @@ def test_should_respond_200(self): "message": "Connection successfully tested", } + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_connection_env_is_cleaned_after_run(self): + payload = {"connection_id": "test-connection-id", "conn_type": "sqlite"} + self.client.post("/api/v1/connections/test", json=payload, environ_overrides={"REMOTE_USER": "test"}) + assert not any([key.startswith(CONN_ENV_PREFIX) for key in os.environ.keys()]) + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) def test_post_should_respond_400_for_invalid_payload(self): payload = { diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index 04367402066e8..768c235b92373 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -131,7 +131,7 @@ def _create_dag(self, dag_id): self.app.dag_bag.bag_dag(dag, root_dag=dag) return dag_instance - def _create_test_dag_run(self, state="running", extra_dag=False, commit=True, idx_start=1): + def _create_test_dag_run(self, state=DagRunState.RUNNING, extra_dag=False, commit=True, idx_start=1): dag_runs = [] dags = [] @@ -147,6 +147,7 @@ def _create_test_dag_run(self, state="running", extra_dag=False, commit=True, id external_trigger=True, state=state, ) + dagrun_model.updated_at = timezone.parse(self.default_time) dag_runs.append(dagrun_model) if extra_dag: @@ -553,12 +554,32 @@ class TestGetDagRunsPaginationFilters(TestDagRunEndpoint): "TEST_START_EXEC_DAY_19", ], ), + ( + "api/v1/dags/TEST_DAG_ID/dagRuns?updated_at_lte=2020-06-13T18%3A00%3A00%2B00%3A00", + [ + "TEST_START_EXEC_DAY_10", + "TEST_START_EXEC_DAY_11", + "TEST_START_EXEC_DAY_12", + "TEST_START_EXEC_DAY_13", + ], + ), + ( + "api/v1/dags/TEST_DAG_ID/dagRuns?updated_at_gte=2020-06-16T18%3A00%3A00%2B00%3A00", + [ + "TEST_START_EXEC_DAY_16", + "TEST_START_EXEC_DAY_17", + "TEST_START_EXEC_DAY_18", + "TEST_START_EXEC_DAY_19", + ], + ), ], ) @provide_session def test_date_filters_gte_and_lte(self, url, expected_dag_run_ids, session): dagrun_models = self._create_dag_runs() session.add_all(dagrun_models) + for d in dagrun_models: + d.updated_at = d.execution_date session.commit() response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"}) @@ -589,7 +610,7 @@ def _create_dag_runs(self): execution_date=timezone.parse(dates[i]), start_date=timezone.parse(dates[i]), external_trigger=True, - state="success", + state=DagRunState.SUCCESS, ) for i in range(len(dates)) ] @@ -667,6 +688,21 @@ def test_should_respond_200(self): "total_entries": 2, } + def test_raises_validation_error_for_invalid_request(self): + self._create_test_dag_run() + response = self.client.post( + "api/v1/dags/~/dagRuns/list", + json={"dagids": ["TEST_DAG_ID"]}, + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 400 + assert response.json == { + "detail": "{'dagids': ['Unknown field.']}", + "status": 400, + "title": "Bad Request", + "type": EXCEPTIONS_LINK_MAP[400], + } + def test_filter_by_state(self): self._create_test_dag_run() self._create_test_dag_run(state="queued", idx_start=3) @@ -1092,6 +1128,41 @@ def test_should_respond_200(self, session, logical_date_field_name, dag_run_id, } _check_last_log(session, dag_id="TEST_DAG_ID", event="dag_run.create", execution_date=None) + def test_raises_validation_error_for_invalid_request(self): + self._create_dag("TEST_DAG_ID") + response = self.client.post( + "api/v1/dags/TEST_DAG_ID/dagRuns", + json={"executiondate": "2020-11-10T08:25:56Z"}, + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 400 + assert response.json == { + "detail": "{'executiondate': ['Unknown field.']}", + "status": 400, + "title": "Bad Request", + "type": EXCEPTIONS_LINK_MAP[400], + } + + @mock.patch("airflow.api_connexion.endpoints.dag_run_endpoint.get_airflow_app") + def test_dagrun_creation_exception_is_handled(self, mock_get_app, session): + self._create_dag("TEST_DAG_ID") + error_message = "Encountered Error" + mock_get_app.return_value.dag_bag.get_dag.return_value.create_dagrun.side_effect = ValueError( + error_message + ) + response = self.client.post( + "api/v1/dags/TEST_DAG_ID/dagRuns", + json={"execution_date": "2020-11-10T08:25:56Z"}, + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 400 + assert response.json == { + "detail": error_message, + "status": 400, + "title": "Bad Request", + "type": EXCEPTIONS_LINK_MAP[400], + } + def test_should_respond_404_if_a_dag_is_inactive(self, session): dm = self._create_dag("TEST_INACTIVE_DAG_ID") dm.is_active = False @@ -1382,6 +1453,27 @@ def test_should_respond_200(self, state, run_type, dag_maker, session): "note": None, } + def test_schema_validation_error_raises(self, dag_maker, session): + dag_id = "TEST_DAG_ID" + dag_run_id = "TEST_DAG_RUN_ID" + with dag_maker(dag_id) as dag: + EmptyOperator(task_id="task_id", dag=dag) + self.app.dag_bag.bag_dag(dag, root_dag=dag) + dag_maker.create_dagrun(run_id=dag_run_id) + + response = self.client.patch( + f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}", + json={"states": "success"}, + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 400 + assert response.json == { + "detail": "{'states': ['Unknown field.']}", + "status": 400, + "title": "Bad Request", + "type": EXCEPTIONS_LINK_MAP[400], + } + @pytest.mark.parametrize("invalid_state", ["running"]) @time_machine.travel(TestDagRunEndpoint.default_time) def test_should_response_400_for_non_existing_dag_run_state(self, invalid_state, dag_maker): @@ -1481,6 +1573,26 @@ def test_should_respond_200(self, dag_maker, session): ti.refresh_from_db() assert ti.state is None + def test_schema_validation_error_raises_for_invalid_fields(self, dag_maker, session): + dag_id = "TEST_DAG_ID" + dag_run_id = "TEST_DAG_RUN_ID" + with dag_maker(dag_id) as dag: + EmptyOperator(task_id="task_id", dag=dag) + self.app.dag_bag.bag_dag(dag, root_dag=dag) + dag_maker.create_dagrun(run_id=dag_run_id, state=DagRunState.FAILED) + response = self.client.post( + f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/clear", + json={"dryrun": False}, + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 400 + assert response.json == { + "detail": "{'dryrun': ['Unknown field.']}", + "status": 400, + "title": "Bad Request", + "type": EXCEPTIONS_LINK_MAP[400], + } + def test_dry_run(self, dag_maker, session): """Test that dry_run being True returns TaskInstances without clearing DagRun""" dag_id = "TEST_DAG_ID" @@ -1648,11 +1760,10 @@ def test_should_raises_401_unauthenticated(self, session): class TestSetDagRunNote(TestDagRunEndpoint): def test_should_respond_200(self, dag_maker, session): - dag_runs: list[DagRun] = self._create_test_dag_run("success") + dag_runs: list[DagRun] = self._create_test_dag_run(DagRunState.SUCCESS) session.add_all(dag_runs) session.commit() created_dr: DagRun = dag_runs[0] - new_note_value = "My super cool DagRun notes" response = self.client.patch( f"api/v1/dags/{created_dr.dag_id}/dagRuns/{created_dr.run_id}/setNote", @@ -1680,6 +1791,51 @@ def test_should_respond_200(self, dag_maker, session): "note": new_note_value, } assert dr.dag_run_note.user_id is not None + # Update the note again + new_note_value = "My super cool DagRun notes 2" + response = self.client.patch( + f"api/v1/dags/{created_dr.dag_id}/dagRuns/{created_dr.run_id}/setNote", + json={"note": new_note_value}, + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 200 + assert response.json == { + "conf": {}, + "dag_id": dr.dag_id, + "dag_run_id": dr.run_id, + "end_date": dr.end_date.isoformat(), + "execution_date": self.default_time, + "external_trigger": True, + "logical_date": self.default_time, + "start_date": self.default_time, + "state": "success", + "data_interval_start": None, + "data_interval_end": None, + "last_scheduling_decision": None, + "run_type": dr.run_type, + "note": new_note_value, + } + assert dr.dag_run_note.user_id is not None + + def test_schema_validation_error_raises(self, dag_maker, session): + dag_runs: list[DagRun] = self._create_test_dag_run(DagRunState.SUCCESS) + session.add_all(dag_runs) + session.commit() + created_dr: DagRun = dag_runs[0] + + new_note_value = "My super cool DagRun notes" + response = self.client.patch( + f"api/v1/dags/{created_dr.dag_id}/dagRuns/{created_dr.run_id}/setNote", + json={"notes": new_note_value}, + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 400 + assert response.json == { + "detail": "{'notes': ['Unknown field.']}", + "status": 400, + "title": "Bad Request", + "type": EXCEPTIONS_LINK_MAP[400], + } def test_should_raises_401_unauthenticated(self, session): response = self.client.patch( diff --git a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py index 49138cef2d4c1..9f6da68f0ef8b 100644 --- a/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py @@ -353,6 +353,18 @@ def test_mapped_task_instances_state_order(self, one_task_with_many_mapped_tis, assert list(range(5)) + list(range(25, 110)) + list(range(5, 15)) == [ ti["map_index"] for ti in response.json["task_instances"] ] + # State ascending + response = self.client.get( + "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped" + "?order_by=state", + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 200 + assert response.json["total_entries"] == 110 + assert len(response.json["task_instances"]) == 100 + assert list(range(5, 25)) + list(range(90, 110)) + list(range(25, 85)) == [ + ti["map_index"] for ti in response.json["task_instances"] + ] @provide_session def test_mapped_task_instances_invalid_order(self, one_task_with_many_mapped_tis, session): @@ -448,3 +460,11 @@ def test_mapped_task_instances_with_zero_mapped(self, one_task_with_zero_mapped_ assert response.status_code == 200 assert response.json["total_entries"] == 0 assert response.json["task_instances"] == [] + + def test_should_raise_404_not_found_for_nonexistent_task(self): + response = self.client.get( + "/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/nonexistent_task/listMapped", + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 404 + assert response.json["title"] == "Task id nonexistent_task not found" diff --git a/tests/api_connexion/endpoints/test_pool_endpoint.py b/tests/api_connexion/endpoints/test_pool_endpoint.py index 023dac388144b..23f487931b56f 100644 --- a/tests/api_connexion/endpoints/test_pool_endpoint.py +++ b/tests/api_connexion/endpoints/test_pool_endpoint.py @@ -429,6 +429,20 @@ def test_response_400(self, error_detail, request_json, session): "type": EXCEPTIONS_LINK_MAP[400], } == response.json + def test_not_found_when_no_pool_available(self): + response = self.client.patch( + "api/v1/pools/test_pool", + json={"name": "test_pool_a", "slots": 3}, + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 404 + assert { + "detail": "Pool with name:'test_pool' not found", + "status": 404, + "title": "Not Found", + "type": EXCEPTIONS_LINK_MAP[404], + } == response.json + def test_should_raises_401_unauthenticated(self, session): pool = Pool(pool="test_pool", slots=2, include_deferred=False) session.add(pool) diff --git a/tests/api_connexion/endpoints/test_task_endpoint.py b/tests/api_connexion/endpoints/test_task_endpoint.py index b2ebcdb7f9907..b8ef8dc0cf650 100644 --- a/tests/api_connexion/endpoints/test_task_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_endpoint.py @@ -237,6 +237,14 @@ def test_should_respond_404(self): ) assert response.status_code == 404 + def test_should_respond_404_when_dag_not_found(self): + dag_id = "xxxx_not_existing" + response = self.client.get( + f"/api/v1/dags/{dag_id}/tasks/{self.task_id}", environ_overrides={"REMOTE_USER": "test"} + ) + assert response.status_code == 404 + assert response.json["title"] == "DAG not found" + def test_should_raises_401_unauthenticated(self): response = self.client.get(f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}") diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py b/tests/api_connexion/endpoints/test_task_instance_endpoint.py index 0c445c155dc23..9f913ac9fecf3 100644 --- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py +++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py @@ -469,6 +469,14 @@ def test_should_raise_403_forbidden(self): ) assert response.status_code == 403 + def test_raises_404_for_nonexistent_task_instance(self): + response = self.client.get( + "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/nonexistent_task", + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 404 + assert response.json["title"] == "Task instance not found" + def test_unmapped_map_index_should_return_404(self, session): self.create_task_instances(session) response = self.client.get( @@ -1000,6 +1008,15 @@ def test_should_raise_400_for_no_json(self): assert response.status_code == 400 assert response.json["detail"] == "Request body must not be empty" + def test_should_raise_400_for_unknown_fields(self): + response = self.client.post( + "/api/v1/dags/~/dagRuns/~/taskInstances/list", + environ_overrides={"REMOTE_USER": "test"}, + json={"unknown_field": "unknown_value"}, + ) + assert response.status_code == 400 + assert response.json["detail"] == "{'unknown_field': ['Unknown field.']}" + @pytest.mark.parametrize( "payload, expected", [ @@ -1353,6 +1370,279 @@ def test_should_respond_200_with_reset_dag_run(self, session): assert 6 == len(response.json["task_instances"]) assert 0 == failed_dag_runs, 0 + def test_should_respond_200_with_dag_run_id(self, session): + dag_id = "example_python_operator" + payload = { + "dry_run": False, + "reset_dag_runs": False, + "only_failed": False, + "only_running": True, + "include_subdags": True, + "dag_run_id": "TEST_DAG_RUN_ID_0", + } + task_instances = [ + {"execution_date": DEFAULT_DATETIME_1, "state": State.RUNNING}, + { + "execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1), + "state": State.RUNNING, + }, + { + "execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2), + "state": State.RUNNING, + }, + { + "execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=3), + "state": State.RUNNING, + }, + { + "execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=4), + "state": State.RUNNING, + }, + { + "execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=5), + "state": State.RUNNING, + }, + ] + + self.create_task_instances( + session, + dag_id=dag_id, + task_instances=task_instances, + update_extras=False, + dag_run_state=State.FAILED, + ) + response = self.client.post( + f"/api/v1/dags/{dag_id}/clearTaskInstances", + environ_overrides={"REMOTE_USER": "test"}, + json=payload, + ) + assert 200 == response.status_code + expected_response = [ + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_0", + "execution_date": "2020-01-01T00:00:00+00:00", + "task_id": "print_the_context", + }, + ] + assert response.json["task_instances"] == expected_response + assert 1 == len(response.json["task_instances"]) + + def test_should_respond_200_with_include_past(self, session): + dag_id = "example_python_operator" + payload = { + "dry_run": False, + "reset_dag_runs": False, + "only_failed": False, + "include_past": True, + "only_running": True, + "include_subdags": True, + } + task_instances = [ + {"execution_date": DEFAULT_DATETIME_1, "state": State.RUNNING}, + { + "execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1), + "state": State.RUNNING, + }, + { + "execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2), + "state": State.RUNNING, + }, + { + "execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=3), + "state": State.RUNNING, + }, + { + "execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=4), + "state": State.RUNNING, + }, + { + "execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=5), + "state": State.RUNNING, + }, + ] + + self.create_task_instances( + session, + dag_id=dag_id, + task_instances=task_instances, + update_extras=False, + dag_run_state=State.FAILED, + ) + response = self.client.post( + f"/api/v1/dags/{dag_id}/clearTaskInstances", + environ_overrides={"REMOTE_USER": "test"}, + json=payload, + ) + assert 200 == response.status_code + expected_response = [ + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_0", + "execution_date": "2020-01-01T00:00:00+00:00", + "task_id": "print_the_context", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_1", + "execution_date": "2020-01-02T00:00:00+00:00", + "task_id": "log_sql_query", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_2", + "execution_date": "2020-01-03T00:00:00+00:00", + "task_id": "sleep_for_0", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_3", + "execution_date": "2020-01-04T00:00:00+00:00", + "task_id": "sleep_for_1", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_4", + "execution_date": "2020-01-05T00:00:00+00:00", + "task_id": "sleep_for_2", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_5", + "execution_date": "2020-01-06T00:00:00+00:00", + "task_id": "sleep_for_3", + }, + ] + for task_instance in expected_response: + assert task_instance in response.json["task_instances"] + assert 6 == len(response.json["task_instances"]) + + def test_should_respond_200_with_include_future(self, session): + dag_id = "example_python_operator" + payload = { + "dry_run": False, + "reset_dag_runs": False, + "only_failed": False, + "include_future": True, + "only_running": False, + } + task_instances = [ + {"execution_date": DEFAULT_DATETIME_1, "state": State.SUCCESS}, + { + "execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1), + "state": State.SUCCESS, + }, + { + "execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=2), + "state": State.SUCCESS, + }, + { + "execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=3), + "state": State.SUCCESS, + }, + { + "execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=4), + "state": State.SUCCESS, + }, + { + "execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=5), + "state": State.SUCCESS, + }, + ] + + self.create_task_instances( + session, + dag_id=dag_id, + task_instances=task_instances, + update_extras=False, + dag_run_state=State.FAILED, + ) + response = self.client.post( + f"/api/v1/dags/{dag_id}/clearTaskInstances", + environ_overrides={"REMOTE_USER": "test"}, + json=payload, + ) + + assert 200 == response.status_code + expected_response = [ + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_0", + "execution_date": "2020-01-01T00:00:00+00:00", + "task_id": "print_the_context", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_1", + "execution_date": "2020-01-02T00:00:00+00:00", + "task_id": "log_sql_query", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_2", + "execution_date": "2020-01-03T00:00:00+00:00", + "task_id": "sleep_for_0", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_3", + "execution_date": "2020-01-04T00:00:00+00:00", + "task_id": "sleep_for_1", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_4", + "execution_date": "2020-01-05T00:00:00+00:00", + "task_id": "sleep_for_2", + }, + { + "dag_id": "example_python_operator", + "dag_run_id": "TEST_DAG_RUN_ID_5", + "execution_date": "2020-01-06T00:00:00+00:00", + "task_id": "sleep_for_3", + }, + ] + for task_instance in expected_response: + assert task_instance in response.json["task_instances"] + assert 6 == len(response.json["task_instances"]) + + def test_should_respond_404_for_nonexistent_dagrun_id(self, session): + dag_id = "example_python_operator" + payload = { + "dry_run": False, + "reset_dag_runs": False, + "only_failed": False, + "only_running": True, + "include_subdags": True, + "dag_run_id": "TEST_DAG_RUN_ID_100", + } + task_instances = [ + {"execution_date": DEFAULT_DATETIME_1, "state": State.RUNNING}, + { + "execution_date": DEFAULT_DATETIME_1 + dt.timedelta(days=1), + "state": State.RUNNING, + }, + ] + + self.create_task_instances( + session, + dag_id=dag_id, + task_instances=task_instances, + update_extras=False, + dag_run_state=State.FAILED, + ) + response = self.client.post( + f"/api/v1/dags/{dag_id}/clearTaskInstances", + environ_overrides={"REMOTE_USER": "test"}, + json=payload, + ) + + assert 404 == response.status_code + assert ( + response.json["title"] + == "Dag Run id TEST_DAG_RUN_ID_100 not found in dag example_python_operator" + ) + def test_should_raises_401_unauthenticated(self): response = self.client.post( "/api/v1/dags/example_python_operator/clearTaskInstances", @@ -1414,6 +1704,21 @@ def test_should_raise_400_for_naive_and_bad_datetime(self, payload, expected, se assert response.status_code == 400 assert response.json["detail"] == expected + def test_raises_404_for_non_existent_dag(self): + response = self.client.post( + "/api/v1/dags/non-existent-dag/clearTaskInstances", + environ_overrides={"REMOTE_USER": "test"}, + json={ + "dry_run": False, + "reset_dag_runs": True, + "only_failed": False, + "only_running": True, + "include_subdags": True, + }, + ) + assert response.status_code == 404 + assert response.json["title"] == "Dag id non-existent-dag not found" + class TestPostSetTaskInstanceState(TestTaskInstanceEndpoint): @mock.patch("airflow.models.dag.DAG.set_task_instance_state") @@ -1788,7 +2093,6 @@ def test_should_not_call_mocked_api_for_dry_run(self, mock_set_task_instance_sta }, ) assert response.status_code == 200 - print(response.status_code) assert response.json == { "dag_id": "example_python_operator", "dag_run_id": "TEST_DAG_RUN_ID", @@ -1889,6 +2193,47 @@ def test_should_handle_errors(self, error, code, payload, session): assert response.status_code == code assert response.json["detail"] == error + def test_should_raise_400_for_unknown_fields(self, session): + self.create_task_instances(session) + response = self.client.patch( + self.ENDPOINT_URL, + environ_overrides={"REMOTE_USER": "test"}, + json={ + "dryrun": True, + "new_state": "failed", + }, + ) + assert response.status_code == 400 + assert response.json["detail"] == "{'dryrun': ['Unknown field.']}" + + def test_should_raise_404_for_non_existent_dag(self): + response = self.client.patch( + "/api/v1/dags/non-existent-dag/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context", + environ_overrides={"REMOTE_USER": "test"}, + json={ + "dry_run": False, + "new_state": "failed", + }, + ) + assert response.status_code == 404 + assert response.json["title"] == "DAG not found" + assert response.json["detail"] == "DAG 'non-existent-dag' not found" + + def test_should_raise_404_for_non_existent_task_in_dag(self): + response = self.client.patch( + "/api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/non_existent_task", + environ_overrides={"REMOTE_USER": "test"}, + json={ + "dry_run": False, + "new_state": "failed", + }, + ) + assert response.status_code == 404 + assert response.json["title"] == "Task not found" + assert ( + response.json["detail"] == "Task 'non_existent_task' not found in DAG 'example_python_operator'" + ) + def test_should_raises_401_unauthenticated(self): response = self.client.patch( self.ENDPOINT_URL, @@ -2068,6 +2413,33 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, session): "triggerer_job": None, } + def test_should_respond_200_when_note_is_empty(self, session): + tis = self.create_task_instances(session) + for ti in tis: + ti.task_instance_note = None + session.add(ti) + session.commit() + new_note_value = "My super cool TaskInstance note." + response = self.client.patch( + "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" + "print_the_context/setNote", + json={"note": new_note_value}, + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 200, response.text + assert response.json["note"] == new_note_value + + def test_should_raise_400_for_unknown_fields(self, session): + self.create_task_instances(session) + response = self.client.patch( + "api/v1/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" + "print_the_context/setNote", + json={"note": "a valid field", "not": "an unknown field"}, + environ_overrides={"REMOTE_USER": "test"}, + ) + assert response.status_code == 400 + assert response.json["detail"] == "{'not': ['Unknown field.']}" + def test_should_raises_401_unauthenticated(self): for map_index in ["", "/0"]: url = (