Improve test coverage for the rest api modules (apache#35219)
ephraimbuddy authored Oct 31, 2023
1 parent 76847f8 commit 119cd67
Showing 10 changed files with 592 additions and 20 deletions.
3 changes: 1 addition & 2 deletions airflow/api_connexion/endpoints/connection_endpoint.py
@@ -205,5 +205,4 @@ def test_connection() -> APIResponse:
     except ValidationError as err:
         raise BadRequest(detail=str(err.messages))
     finally:
-        if conn_env_var in os.environ:
-            del os.environ[conn_env_var]
+        os.environ.pop(conn_env_var, None)
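The new line folds the membership check and del into a single call: dict.pop(key, default) removes the key when it is present and quietly returns the default when it is not, so the finally block can no longer raise KeyError and there is no separate branch left to cover. A minimal, self-contained sketch of the same pattern, using a made-up variable name rather than anything from the Airflow source:

import os

TEMP_ENV_VAR = "EXAMPLE_TEMP_VAR"  # hypothetical key, for illustration only

def run_with_temporary_env(value):
    os.environ[TEMP_ENV_VAR] = value
    try:
        ...  # code under test reads the variable here
    finally:
        # Works whether or not the key is still present; no KeyError possible.
        os.environ.pop(TEMP_ENV_VAR, None)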
13 changes: 5 additions & 8 deletions airflow/api_connexion/endpoints/pool_endpoint.py
@@ -92,14 +92,11 @@ def patch_pool(
     """Update a pool."""
     request_dict = get_json_request_dict()
     # Only slots and include_deferred can be modified in 'default_pool'
-    try:
-        if pool_name == Pool.DEFAULT_POOL_NAME and request_dict["name"] != Pool.DEFAULT_POOL_NAME:
-            if update_mask and all(mask.strip() in {"slots", "include_deferred"} for mask in update_mask):
-                pass
-            else:
-                raise BadRequest(detail="Default Pool's name can't be modified")
-    except KeyError:
-        pass
+    if pool_name == Pool.DEFAULT_POOL_NAME and request_dict.get("name", None) != Pool.DEFAULT_POOL_NAME:
+        if update_mask and all(mask.strip() in {"slots", "include_deferred"} for mask in update_mask):
+            pass
+        else:
+            raise BadRequest(detail="Default Pool's name can't be modified")
 
     pool = session.scalar(select(Pool).where(Pool.pool == pool_name).limit(1))
     if not pool:
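Replacing the try/except KeyError with request_dict.get("name", None) removes a handler that is awkward to exercise through the API, so the same behaviour now has one less branch to cover. A small sketch contrasting the two lookups on a hypothetical PATCH payload (not taken from the tests):

request_dict = {"slots": 3}  # hypothetical payload with no "name" field

# Before: a separate except branch that needs its own test to be covered.
try:
    name = request_dict["name"]
except KeyError:
    name = None

# After: a single expression with no extra branch.
name = request_dict.get("name", None)
assert name is None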
6 changes: 4 additions & 2 deletions airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -43,6 +43,7 @@
 )
 from airflow.api_connexion.security import get_readable_dags
 from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails
+from airflow.exceptions import TaskNotFound
 from airflow.models import SlaMiss
 from airflow.models.dagrun import DagRun as DR
 from airflow.models.operator import needs_expansion
@@ -192,8 +193,9 @@ def get_mapped_task_instances(
     if not dag:
         error_message = f"DAG {dag_id} not found"
         raise NotFound(error_message)
-    task = dag.get_task(task_id)
-    if not task:
+    try:
+        task = dag.get_task(task_id)
+    except TaskNotFound:
         error_message = f"Task id {task_id} not found"
         raise NotFound(error_message)
     if not needs_expansion(task):
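DAG.get_task raises TaskNotFound for an unknown task id instead of returning None, so the old "if not task:" branch could never fire; catching the exception is what makes the 404 path reachable (and testable, see the new listMapped test further down). A rough sketch of the shape of the new lookup, with a simplified helper name that is not part of the commit:

from airflow.api_connexion.exceptions import NotFound
from airflow.exceptions import TaskNotFound

def resolve_task(dag, task_id):
    # Illustrative only: translate the lookup failure into the API's 404 response.
    try:
        return dag.get_task(task_id)
    except TaskNotFound:
        raise NotFound(f"Task id {task_id} not found")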
3 changes: 0 additions & 3 deletions scripts/cov/restapi_coverage.py
@@ -28,10 +28,7 @@
restapi_files = ["tests/api_experimental", "tests/api_connexion", "tests/api_internal"]

files_not_fully_covered = [
"airflow/api_connexion/endpoints/dag_run_endpoint.py",
"airflow/api_connexion/endpoints/forward_to_fab_endpoint.py",
"airflow/api_connexion/endpoints/pool_endpoint.py",
"airflow/api_connexion/endpoints/task_endpoint.py",
"airflow/api_connexion/endpoints/task_instance_endpoint.py",
"airflow/api_connexion/endpoints/xcom_endpoint.py",
"airflow/api_connexion/exceptions.py",
7 changes: 7 additions & 0 deletions tests/api_connexion/endpoints/test_connection_endpoint.py
@@ -23,6 +23,7 @@

from airflow.api_connexion.exceptions import EXCEPTIONS_LINK_MAP
from airflow.models import Connection
from airflow.secrets.environment_variables import CONN_ENV_PREFIX
from airflow.security import permissions
from airflow.utils.session import provide_session
from tests.test_utils.api_connexion_utils import assert_401, create_user, delete_user
@@ -620,6 +621,12 @@ def test_should_respond_200(self):
"message": "Connection successfully tested",
}

@mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
def test_connection_env_is_cleaned_after_run(self):
payload = {"connection_id": "test-connection-id", "conn_type": "sqlite"}
self.client.post("/api/v1/connections/test", json=payload, environ_overrides={"REMOTE_USER": "test"})
assert not any([key.startswith(CONN_ENV_PREFIX) for key in os.environ.keys()])

@mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
def test_post_should_respond_400_for_invalid_payload(self):
payload = {
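The new cleanup test leans on mock.patch.dict(os.environ, {...}), which snapshots the mapping before the test and restores it afterwards, so asserting that no key starting with CONN_ENV_PREFIX (i.e. AIRFLOW_CONN_) survives the request does not leak state into other tests. A standalone illustration of that restore behaviour, with an invented variable name:

import os
from unittest import mock

@mock.patch.dict(os.environ, {"EXAMPLE_FLAG": "Enabled"})  # hypothetical variable
def run_inside_patch():
    # The patched value is visible while the function runs...
    assert os.environ["EXAMPLE_FLAG"] == "Enabled"

run_inside_patch()
# ...and os.environ is restored to its previous state afterwards.
assert "EXAMPLE_FLAG" not in os.environ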
164 changes: 160 additions & 4 deletions tests/api_connexion/endpoints/test_dag_run_endpoint.py
@@ -131,7 +131,7 @@ def _create_dag(self, dag_id):
self.app.dag_bag.bag_dag(dag, root_dag=dag)
return dag_instance

- def _create_test_dag_run(self, state="running", extra_dag=False, commit=True, idx_start=1):
+ def _create_test_dag_run(self, state=DagRunState.RUNNING, extra_dag=False, commit=True, idx_start=1):
dag_runs = []
dags = []

@@ -147,6 +147,7 @@ def _create_test_dag_run(self, state="running", extra_dag=False, commit=True, idx_start=1):
external_trigger=True,
state=state,
)
dagrun_model.updated_at = timezone.parse(self.default_time)
dag_runs.append(dagrun_model)

if extra_dag:
@@ -553,12 +554,32 @@ class TestGetDagRunsPaginationFilters(TestDagRunEndpoint):
"TEST_START_EXEC_DAY_19",
],
),
(
"api/v1/dags/TEST_DAG_ID/dagRuns?updated_at_lte=2020-06-13T18%3A00%3A00%2B00%3A00",
[
"TEST_START_EXEC_DAY_10",
"TEST_START_EXEC_DAY_11",
"TEST_START_EXEC_DAY_12",
"TEST_START_EXEC_DAY_13",
],
),
(
"api/v1/dags/TEST_DAG_ID/dagRuns?updated_at_gte=2020-06-16T18%3A00%3A00%2B00%3A00",
[
"TEST_START_EXEC_DAY_16",
"TEST_START_EXEC_DAY_17",
"TEST_START_EXEC_DAY_18",
"TEST_START_EXEC_DAY_19",
],
),
],
)
@provide_session
def test_date_filters_gte_and_lte(self, url, expected_dag_run_ids, session):
dagrun_models = self._create_dag_runs()
session.add_all(dagrun_models)
for d in dagrun_models:
d.updated_at = d.execution_date
session.commit()

response = self.client.get(url, environ_overrides={"REMOTE_USER": "test"})
@@ -589,7 +610,7 @@ def _create_dag_runs(self):
execution_date=timezone.parse(dates[i]),
start_date=timezone.parse(dates[i]),
external_trigger=True,
state="success",
state=DagRunState.SUCCESS,
)
for i in range(len(dates))
]
@@ -667,6 +688,21 @@ def test_should_respond_200(self):
"total_entries": 2,
}

def test_raises_validation_error_for_invalid_request(self):
self._create_test_dag_run()
response = self.client.post(
"api/v1/dags/~/dagRuns/list",
json={"dagids": ["TEST_DAG_ID"]},
environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 400
assert response.json == {
"detail": "{'dagids': ['Unknown field.']}",
"status": 400,
"title": "Bad Request",
"type": EXCEPTIONS_LINK_MAP[400],
}
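The expected error body comes from marshmallow's handling of unexpected keys: by default a Schema raises ValidationError for unknown fields, and the endpoints raise BadRequest(detail=str(err.messages)), as the connection_endpoint.py hunk above shows. A tiny illustration with a stand-in schema (the real request schema used by the endpoint is not shown here):

from marshmallow import Schema, ValidationError, fields

class DagRunsBatchFormSketch(Schema):
    # Simplified stand-in, not the schema the endpoint actually uses.
    dag_ids = fields.List(fields.Str())

try:
    DagRunsBatchFormSketch().load({"dagids": ["TEST_DAG_ID"]})
except ValidationError as err:
    assert str(err.messages) == "{'dagids': ['Unknown field.']}"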

def test_filter_by_state(self):
self._create_test_dag_run()
self._create_test_dag_run(state="queued", idx_start=3)
@@ -1092,6 +1128,41 @@ def test_should_respond_200(self, session, logical_date_field_name, dag_run_id,
}
_check_last_log(session, dag_id="TEST_DAG_ID", event="dag_run.create", execution_date=None)

def test_raises_validation_error_for_invalid_request(self):
self._create_dag("TEST_DAG_ID")
response = self.client.post(
"api/v1/dags/TEST_DAG_ID/dagRuns",
json={"executiondate": "2020-11-10T08:25:56Z"},
environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 400
assert response.json == {
"detail": "{'executiondate': ['Unknown field.']}",
"status": 400,
"title": "Bad Request",
"type": EXCEPTIONS_LINK_MAP[400],
}

@mock.patch("airflow.api_connexion.endpoints.dag_run_endpoint.get_airflow_app")
def test_dagrun_creation_exception_is_handled(self, mock_get_app, session):
self._create_dag("TEST_DAG_ID")
error_message = "Encountered Error"
mock_get_app.return_value.dag_bag.get_dag.return_value.create_dagrun.side_effect = ValueError(
error_message
)
response = self.client.post(
"api/v1/dags/TEST_DAG_ID/dagRuns",
json={"execution_date": "2020-11-10T08:25:56Z"},
environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 400
assert response.json == {
"detail": error_message,
"status": 400,
"title": "Bad Request",
"type": EXCEPTIONS_LINK_MAP[400],
}
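The error-path test above works through unittest.mock's side_effect: once the patched get_airflow_app() chain is reached, calling create_dagrun raises the configured exception instead of returning, which drives the endpoint into the branch that maps the failure onto a 400 response. A compact sketch of that mechanism in isolation, outside any endpoint:

from unittest import mock

fake_app = mock.MagicMock()
fake_app.dag_bag.get_dag.return_value.create_dagrun.side_effect = ValueError("Encountered Error")

dag = fake_app.dag_bag.get_dag("TEST_DAG_ID")
try:
    dag.create_dagrun()
except ValueError as err:
    # The real endpoint catches this and returns a Bad Request response instead.
    assert str(err) == "Encountered Error"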

def test_should_respond_404_if_a_dag_is_inactive(self, session):
dm = self._create_dag("TEST_INACTIVE_DAG_ID")
dm.is_active = False
@@ -1382,6 +1453,27 @@ def test_should_respond_200(self, state, run_type, dag_maker, session):
"note": None,
}

def test_schema_validation_error_raises(self, dag_maker, session):
dag_id = "TEST_DAG_ID"
dag_run_id = "TEST_DAG_RUN_ID"
with dag_maker(dag_id) as dag:
EmptyOperator(task_id="task_id", dag=dag)
self.app.dag_bag.bag_dag(dag, root_dag=dag)
dag_maker.create_dagrun(run_id=dag_run_id)

response = self.client.patch(
f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}",
json={"states": "success"},
environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 400
assert response.json == {
"detail": "{'states': ['Unknown field.']}",
"status": 400,
"title": "Bad Request",
"type": EXCEPTIONS_LINK_MAP[400],
}

@pytest.mark.parametrize("invalid_state", ["running"])
@time_machine.travel(TestDagRunEndpoint.default_time)
def test_should_response_400_for_non_existing_dag_run_state(self, invalid_state, dag_maker):
@@ -1481,6 +1573,26 @@ def test_should_respond_200(self, dag_maker, session):
ti.refresh_from_db()
assert ti.state is None

def test_schema_validation_error_raises_for_invalid_fields(self, dag_maker, session):
dag_id = "TEST_DAG_ID"
dag_run_id = "TEST_DAG_RUN_ID"
with dag_maker(dag_id) as dag:
EmptyOperator(task_id="task_id", dag=dag)
self.app.dag_bag.bag_dag(dag, root_dag=dag)
dag_maker.create_dagrun(run_id=dag_run_id, state=DagRunState.FAILED)
response = self.client.post(
f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/clear",
json={"dryrun": False},
environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 400
assert response.json == {
"detail": "{'dryrun': ['Unknown field.']}",
"status": 400,
"title": "Bad Request",
"type": EXCEPTIONS_LINK_MAP[400],
}

def test_dry_run(self, dag_maker, session):
"""Test that dry_run being True returns TaskInstances without clearing DagRun"""
dag_id = "TEST_DAG_ID"
@@ -1648,11 +1760,10 @@ def test_should_raises_401_unauthenticated(self, session):

class TestSetDagRunNote(TestDagRunEndpoint):
def test_should_respond_200(self, dag_maker, session):
- dag_runs: list[DagRun] = self._create_test_dag_run("success")
+ dag_runs: list[DagRun] = self._create_test_dag_run(DagRunState.SUCCESS)
session.add_all(dag_runs)
session.commit()
created_dr: DagRun = dag_runs[0]

new_note_value = "My super cool DagRun notes"
response = self.client.patch(
f"api/v1/dags/{created_dr.dag_id}/dagRuns/{created_dr.run_id}/setNote",
@@ -1680,6 +1791,51 @@ def test_should_respond_200(self, dag_maker, session):
"note": new_note_value,
}
assert dr.dag_run_note.user_id is not None
# Update the note again
new_note_value = "My super cool DagRun notes 2"
response = self.client.patch(
f"api/v1/dags/{created_dr.dag_id}/dagRuns/{created_dr.run_id}/setNote",
json={"note": new_note_value},
environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 200
assert response.json == {
"conf": {},
"dag_id": dr.dag_id,
"dag_run_id": dr.run_id,
"end_date": dr.end_date.isoformat(),
"execution_date": self.default_time,
"external_trigger": True,
"logical_date": self.default_time,
"start_date": self.default_time,
"state": "success",
"data_interval_start": None,
"data_interval_end": None,
"last_scheduling_decision": None,
"run_type": dr.run_type,
"note": new_note_value,
}
assert dr.dag_run_note.user_id is not None

def test_schema_validation_error_raises(self, dag_maker, session):
dag_runs: list[DagRun] = self._create_test_dag_run(DagRunState.SUCCESS)
session.add_all(dag_runs)
session.commit()
created_dr: DagRun = dag_runs[0]

new_note_value = "My super cool DagRun notes"
response = self.client.patch(
f"api/v1/dags/{created_dr.dag_id}/dagRuns/{created_dr.run_id}/setNote",
json={"notes": new_note_value},
environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 400
assert response.json == {
"detail": "{'notes': ['Unknown field.']}",
"status": 400,
"title": "Bad Request",
"type": EXCEPTIONS_LINK_MAP[400],
}

def test_should_raises_401_unauthenticated(self, session):
response = self.client.patch(
tests/api_connexion/endpoints/test_mapped_task_instance_endpoint.py
@@ -353,6 +353,18 @@ def test_mapped_task_instances_state_order(self, one_task_with_many_mapped_tis,
assert list(range(5)) + list(range(25, 110)) + list(range(5, 15)) == [
ti["map_index"] for ti in response.json["task_instances"]
]
# State ascending
response = self.client.get(
"/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped"
"?order_by=state",
environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 200
assert response.json["total_entries"] == 110
assert len(response.json["task_instances"]) == 100
assert list(range(5, 25)) + list(range(90, 110)) + list(range(25, 85)) == [
ti["map_index"] for ti in response.json["task_instances"]
]

@provide_session
def test_mapped_task_instances_invalid_order(self, one_task_with_many_mapped_tis, session):
@@ -448,3 +460,11 @@ def test_mapped_task_instances_with_zero_mapped(self, one_task_with_zero_mapped_
assert response.status_code == 200
assert response.json["total_entries"] == 0
assert response.json["task_instances"] == []

def test_should_raise_404_not_found_for_nonexistent_task(self):
response = self.client.get(
"/api/v1/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/nonexistent_task/listMapped",
environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 404
assert response.json["title"] == "Task id nonexistent_task not found"
14 changes: 14 additions & 0 deletions tests/api_connexion/endpoints/test_pool_endpoint.py
@@ -429,6 +429,20 @@ def test_response_400(self, error_detail, request_json, session):
"type": EXCEPTIONS_LINK_MAP[400],
} == response.json

def test_not_found_when_no_pool_available(self):
response = self.client.patch(
"api/v1/pools/test_pool",
json={"name": "test_pool_a", "slots": 3},
environ_overrides={"REMOTE_USER": "test"},
)
assert response.status_code == 404
assert {
"detail": "Pool with name:'test_pool' not found",
"status": 404,
"title": "Not Found",
"type": EXCEPTIONS_LINK_MAP[404],
} == response.json

def test_should_raises_401_unauthenticated(self, session):
pool = Pool(pool="test_pool", slots=2, include_deferred=False)
session.add(pool)
8 changes: 8 additions & 0 deletions tests/api_connexion/endpoints/test_task_endpoint.py
@@ -237,6 +237,14 @@ def test_should_respond_404(self):
)
assert response.status_code == 404

def test_should_respond_404_when_dag_not_found(self):
dag_id = "xxxx_not_existing"
response = self.client.get(
f"/api/v1/dags/{dag_id}/tasks/{self.task_id}", environ_overrides={"REMOTE_USER": "test"}
)
assert response.status_code == 404
assert response.json["title"] == "DAG not found"

def test_should_raises_401_unauthenticated(self):
response = self.client.get(f"/api/v1/dags/{self.dag_id}/tasks/{self.task_id}")

(The diff for the remaining changed file did not load.)