Skip to content

Commit

Permalink
Use auth manager is_authorized_ APIs to check user permissions in R…
Browse files Browse the repository at this point in the history
…est API (apache#34317)
  • Loading branch information
vincbeck authored Oct 17, 2023
1 parent 85fd0e1 commit d72131f
Show file tree
Hide file tree
Showing 38 changed files with 887 additions and 601 deletions.
7 changes: 3 additions & 4 deletions airflow/api_connexion/endpoints/config_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from airflow.api_connexion.exceptions import NotFound, PermissionDenied
from airflow.api_connexion.schemas.config_schema import Config, ConfigOption, ConfigSection, config_schema
from airflow.configuration import conf
from airflow.security import permissions
from airflow.settings import json

LINE_SEP = "\n" # `\n` cannot appear in f-strings
Expand Down Expand Up @@ -66,7 +65,7 @@ def _config_to_json(config: Config) -> str:
return json.dumps(config_schema.dump(config), indent=4)


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)])
@security.requires_access_configuration("GET")
def get_config(*, section: str | None = None) -> Response:
"""Get current configuration."""
serializer = {
Expand Down Expand Up @@ -103,8 +102,8 @@ def get_config(*, section: str | None = None) -> Response:
)


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)])
def get_value(section: str, option: str) -> Response:
@security.requires_access_configuration("GET")
def get_value(*, section: str, option: str) -> Response:
serializer = {
"text/plain": _config_to_text,
"application/json": _config_to_json,
Expand Down
12 changes: 6 additions & 6 deletions airflow/api_connexion/endpoints/connection_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
RESOURCE_EVENT_PREFIX = "connection"


@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_CONNECTION)])
@security.requires_access_connection("DELETE")
@provide_session
@action_logging(
event=action_event_from_permission(
Expand All @@ -73,7 +73,7 @@ def delete_connection(*, connection_id: str, session: Session = NEW_SESSION) ->
return NoContent, HTTPStatus.NO_CONTENT


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION)])
@security.requires_access_connection("GET")
@provide_session
def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Get a connection entry."""
Expand All @@ -86,7 +86,7 @@ def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> API
return connection_schema.dump(connection)


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION)])
@security.requires_access_connection("GET")
@format_parameters({"limit": check_limit})
@provide_session
def get_connections(
Expand All @@ -109,7 +109,7 @@ def get_connections(
)


@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_CONNECTION)])
@security.requires_access_connection("PUT")
@provide_session
@action_logging(
event=action_event_from_permission(
Expand Down Expand Up @@ -147,7 +147,7 @@ def patch_connection(
return connection_schema.dump(connection)


@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION)])
@security.requires_access_connection("POST")
@provide_session
@action_logging(
event=action_event_from_permission(
Expand Down Expand Up @@ -176,7 +176,7 @@ def post_connection(*, session: Session = NEW_SESSION) -> APIResponse:
raise AlreadyExists(detail=f"Connection already exist. ID: {conn_id}")


@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION)])
@security.requires_access_connection("POST")
def test_connection() -> APIResponse:
"""
Test an API connection.
Expand Down
18 changes: 9 additions & 9 deletions airflow/api_connexion/endpoints/dag_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@
)
from airflow.exceptions import AirflowException, DagNotFound
from airflow.models.dag import DagModel, DagTag
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.db import get_query_count
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.www.extensions.init_auth_manager import get_auth_manager

if TYPE_CHECKING:
from sqlalchemy.orm import Session
Expand All @@ -48,7 +48,7 @@
from airflow.api_connexion.types import APIResponse, UpdateMask


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)])
@security.requires_access_dag("GET")
@provide_session
def get_dag(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Get basic information about a DAG."""
Expand All @@ -60,7 +60,7 @@ def get_dag(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
return dag_schema.dump(dag)


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)])
@security.requires_access_dag("GET")
def get_dag_details(*, dag_id: str) -> APIResponse:
"""Get details of DAG."""
dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id)
Expand All @@ -69,7 +69,7 @@ def get_dag_details(*, dag_id: str) -> APIResponse:
return dag_detail_schema.dump(dag)


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)])
@security.requires_access_dag("GET")
@format_parameters({"limit": check_limit})
@provide_session
def get_dags(
Expand All @@ -96,7 +96,7 @@ def get_dags(
if dag_id_pattern:
dags_query = dags_query.where(DagModel.dag_id.ilike(f"%{dag_id_pattern}%"))

readable_dags = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)
readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user)

dags_query = dags_query.where(DagModel.dag_id.in_(readable_dags))
if tags:
Expand All @@ -110,7 +110,7 @@ def get_dags(
return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries))


@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)])
@security.requires_access_dag("PUT")
@provide_session
def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session = NEW_SESSION) -> APIResponse:
"""Update the specific DAG."""
Expand All @@ -132,7 +132,7 @@ def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session =
return dag_schema.dump(dag)


@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)])
@security.requires_access_dag("PUT")
@format_parameters({"limit": check_limit})
@provide_session
def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pattern=None, update_mask=None):
Expand All @@ -156,7 +156,7 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat
if dag_id_pattern == "~":
dag_id_pattern = "%"
dags_query = dags_query.where(DagModel.dag_id.ilike(f"%{dag_id_pattern}%"))
editable_dags = get_airflow_app().appbuilder.sm.get_editable_dag_ids(g.user)
editable_dags = get_auth_manager().get_permitted_dag_ids(methods=["PUT"], user=g.user)

dags_query = dags_query.where(DagModel.dag_id.in_(editable_dags))
if tags:
Expand All @@ -180,7 +180,7 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat
return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries))


@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG)])
@security.requires_access_dag("DELETE")
@provide_session
def delete_dag(dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Delete the specific DAG."""
Expand Down
74 changes: 15 additions & 59 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
TaskInstanceReferenceCollection,
task_instance_reference_collection_schema,
)
from airflow.auth.managers.models.resource_details import DagAccessEntity
from airflow.models import DagModel, DagRun
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
Expand All @@ -76,12 +77,7 @@
RESOURCE_EVENT_PREFIX = "dag_run"


@security.requires_access(
[
(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN),
],
)
@security.requires_access_dag("DELETE", DagAccessEntity.RUN)
@provide_session
def delete_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Delete a DAG Run."""
Expand All @@ -93,12 +89,7 @@ def delete_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSI
return NoContent, HTTPStatus.NO_CONTENT


@security.requires_access(
[
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
],
)
@security.requires_access_dag("GET", DagAccessEntity.RUN)
@provide_session
def get_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Get a DAG Run."""
Expand All @@ -111,13 +102,8 @@ def get_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION)
return dagrun_schema.dump(dag_run)


@security.requires_access(
[
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET),
],
)
@security.requires_access_dag("GET", DagAccessEntity.RUN)
@security.requires_access_dataset("GET")
@provide_session
def get_upstream_dataset_events(
*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION
Expand Down Expand Up @@ -194,12 +180,7 @@ def _fetch_dag_runs(
return session.scalars(query.offset(offset).limit(limit)).all(), total_entries


@security.requires_access(
[
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
],
)
@security.requires_access_dag("GET", DagAccessEntity.RUN)
@format_parameters(
{
"start_date_gte": format_datetime,
Expand Down Expand Up @@ -236,8 +217,9 @@ def get_dag_runs(

# This endpoint allows specifying ~ as the dag_id to retrieve DAG Runs for all DAGs.
if dag_id == "~":
appbuilder = get_airflow_app().appbuilder
query = query.where(DagRun.dag_id.in_(appbuilder.sm.get_readable_dag_ids(g.user)))
query = query.where(
DagRun.dag_id.in_(get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=g.user))
)
else:
query = query.where(DagRun.dag_id == dag_id)

Expand All @@ -262,12 +244,7 @@ def get_dag_runs(
return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_run, total_entries=total_entries))


@security.requires_access(
[
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN),
],
)
@security.requires_access_dag("GET", DagAccessEntity.RUN)
@provide_session
def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
"""Get list of DAG Runs."""
Expand All @@ -277,8 +254,7 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
except ValidationError as err:
raise BadRequest(detail=str(err.messages))

appbuilder = get_airflow_app().appbuilder
readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user)
readable_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=g.user)
query = select(DagRun)
if data.get("dag_ids"):
dag_ids = set(data["dag_ids"]) & set(readable_dag_ids)
Expand Down Expand Up @@ -307,12 +283,7 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries))


@security.requires_access(
[
(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN),
],
)
@security.requires_access_dag("POST", DagAccessEntity.RUN)
@provide_session
@action_logging(
event=action_event_from_permission(
Expand Down Expand Up @@ -378,12 +349,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
raise AlreadyExists(detail=f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: '{run_id}' already exists")


@security.requires_access(
[
(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
],
)
@security.requires_access_dag("PUT", DagAccessEntity.RUN)
@provide_session
def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Set a state of a dag run."""
Expand All @@ -410,12 +376,7 @@ def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW
return dagrun_schema.dump(dag_run)


@security.requires_access(
[
(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
],
)
@security.requires_access_dag("PUT", DagAccessEntity.RUN)
@provide_session
def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Clear a dag run."""
Expand Down Expand Up @@ -461,12 +422,7 @@ def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSIO
return dagrun_schema.dump(dag_run)


@security.requires_access(
[
(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
],
)
@security.requires_access_dag("PUT", DagAccessEntity.RUN)
@provide_session
def set_dag_run_note(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Set the note for a dag run."""
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_connexion/endpoints/dag_source_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
from airflow.api_connexion import security
from airflow.api_connexion.exceptions import NotFound
from airflow.api_connexion.schemas.dag_source_schema import dag_source_schema
from airflow.auth.managers.models.resource_details import DagAccessEntity
from airflow.models.dagcode import DagCode
from airflow.security import permissions


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)])
@security.requires_access_dag("GET", DagAccessEntity.CODE)
def get_dag_source(*, file_token: str) -> Response:
"""Get source code using file token."""
secret_key = current_app.config["SECRET_KEY"]
Expand Down
9 changes: 2 additions & 7 deletions airflow/api_connexion/endpoints/dag_warning_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
DagWarningCollection,
dag_warning_collection_schema,
)
from airflow.auth.managers.models.resource_details import DagAccessEntity
from airflow.models.dagwarning import DagWarning as DagWarningModel
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.db import get_query_count
from airflow.utils.session import NEW_SESSION, provide_session
Expand All @@ -39,12 +39,7 @@
from airflow.api_connexion.types import APIResponse


@security.requires_access(
[
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING),
]
)
@security.requires_access_dag("GET", DagAccessEntity.WARNING)
@format_parameters({"limit": check_limit})
@provide_session
def get_dag_warnings(
Expand Down
9 changes: 4 additions & 5 deletions airflow/api_connexion/endpoints/dataset_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
dataset_schema,
)
from airflow.models.dataset import DatasetEvent, DatasetModel
from airflow.security import permissions
from airflow.utils.db import get_query_count
from airflow.utils.session import NEW_SESSION, provide_session

Expand All @@ -42,9 +41,9 @@
from airflow.api_connexion.types import APIResponse


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)])
@security.requires_access_dataset("GET")
@provide_session
def get_dataset(uri: str, session: Session = NEW_SESSION) -> APIResponse:
def get_dataset(*, uri: str, session: Session = NEW_SESSION) -> APIResponse:
"""Get a Dataset."""
dataset = session.scalar(
select(DatasetModel)
Expand All @@ -59,7 +58,7 @@ def get_dataset(uri: str, session: Session = NEW_SESSION) -> APIResponse:
return dataset_schema.dump(dataset)


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)])
@security.requires_access_dataset("GET")
@format_parameters({"limit": check_limit})
@provide_session
def get_datasets(
Expand All @@ -86,7 +85,7 @@ def get_datasets(
return dataset_collection_schema.dump(DatasetCollection(datasets=datasets, total_entries=total_entries))


@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)])
@security.requires_access_dataset("GET")
@provide_session
@format_parameters({"limit": check_limit})
def get_dataset_events(
Expand Down
Loading

0 comments on commit d72131f

Please sign in to comment.