From d72131f952836a3134c90805ef7c3bcf82ea93e9 Mon Sep 17 00:00:00 2001 From: Vincent <97131062+vincbeck@users.noreply.github.com> Date: Tue, 17 Oct 2023 12:43:09 -0400 Subject: [PATCH] Use auth manager `is_authorized_` APIs to check user permissions in Rest API (#34317) --- .../endpoints/config_endpoint.py | 7 +- .../endpoints/connection_endpoint.py | 12 +- .../api_connexion/endpoints/dag_endpoint.py | 18 +- .../endpoints/dag_run_endpoint.py | 74 ++---- .../endpoints/dag_source_endpoint.py | 4 +- .../endpoints/dag_warning_endpoint.py | 9 +- .../endpoints/dataset_endpoint.py | 9 +- .../endpoints/event_log_endpoint.py | 6 +- .../endpoints/extra_link_endpoint.py | 10 +- .../endpoints/import_error_endpoint.py | 6 +- .../api_connexion/endpoints/log_endpoint.py | 10 +- .../endpoints/plugin_endpoint.py | 3 +- .../api_connexion/endpoints/pool_endpoint.py | 11 +- .../endpoints/provider_endpoint.py | 3 +- .../api_connexion/endpoints/task_endpoint.py | 16 +- .../endpoints/task_instance_endpoint.py | 84 +------ .../endpoints/variable_endpoint.py | 10 +- .../api_connexion/endpoints/xcom_endpoint.py | 25 +- airflow/api_connexion/security.py | 217 +++++++++++++++++- airflow/auth/managers/base_auth_manager.py | 82 ++++++- airflow/auth/managers/fab/decorators/auth.py | 30 +++ airflow/auth/managers/fab/fab_auth_manager.py | 187 +++++++++++---- .../managers/fab/security_manager/override.py | 122 +++++++++- .../auth/managers/models/resource_details.py | 36 ++- airflow/www/auth.py | 12 +- airflow/www/extensions/init_jinja_globals.py | 4 +- airflow/www/security_manager.py | 181 +-------------- airflow/www/templates/airflow/dag.html | 7 +- airflow/www/views.py | 66 +++--- .../endpoints/test_event_log_endpoint.py | 17 +- .../endpoints/test_log_endpoint.py | 3 +- .../endpoints/test_xcom_endpoint.py | 4 - .../managers/fab/test_fab_auth_manager.py | 32 ++- tests/auth/managers/test_base_auth_manager.py | 37 ++- tests/www/test_security.py | 84 ++++--- tests/www/views/test_views_acl.py | 5 + tests/www/views/test_views_decorators.py | 44 +--- tests/www/views/test_views_tasks.py | 1 + 38 files changed, 887 insertions(+), 601 deletions(-) diff --git a/airflow/api_connexion/endpoints/config_endpoint.py b/airflow/api_connexion/endpoints/config_endpoint.py index a6fc67beb7bf5..cbb8acdfce166 100644 --- a/airflow/api_connexion/endpoints/config_endpoint.py +++ b/airflow/api_connexion/endpoints/config_endpoint.py @@ -24,7 +24,6 @@ from airflow.api_connexion.exceptions import NotFound, PermissionDenied from airflow.api_connexion.schemas.config_schema import Config, ConfigOption, ConfigSection, config_schema from airflow.configuration import conf -from airflow.security import permissions from airflow.settings import json LINE_SEP = "\n" # `\n` cannot appear in f-strings @@ -66,7 +65,7 @@ def _config_to_json(config: Config) -> str: return json.dumps(config_schema.dump(config), indent=4) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)]) +@security.requires_access_configuration("GET") def get_config(*, section: str | None = None) -> Response: """Get current configuration.""" serializer = { @@ -103,8 +102,8 @@ def get_config(*, section: str | None = None) -> Response: ) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONFIG)]) -def get_value(section: str, option: str) -> Response: +@security.requires_access_configuration("GET") +def get_value(*, section: str, option: str) -> Response: serializer = { "text/plain": _config_to_text, "application/json": _config_to_json, diff --git a/airflow/api_connexion/endpoints/connection_endpoint.py b/airflow/api_connexion/endpoints/connection_endpoint.py index 1444421a8443c..16d9afb5b9e33 100644 --- a/airflow/api_connexion/endpoints/connection_endpoint.py +++ b/airflow/api_connexion/endpoints/connection_endpoint.py @@ -53,7 +53,7 @@ RESOURCE_EVENT_PREFIX = "connection" -@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_CONNECTION)]) +@security.requires_access_connection("DELETE") @provide_session @action_logging( event=action_event_from_permission( @@ -73,7 +73,7 @@ def delete_connection(*, connection_id: str, session: Session = NEW_SESSION) -> return NoContent, HTTPStatus.NO_CONTENT -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION)]) +@security.requires_access_connection("GET") @provide_session def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> APIResponse: """Get a connection entry.""" @@ -86,7 +86,7 @@ def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> API return connection_schema.dump(connection) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION)]) +@security.requires_access_connection("GET") @format_parameters({"limit": check_limit}) @provide_session def get_connections( @@ -109,7 +109,7 @@ def get_connections( ) -@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_CONNECTION)]) +@security.requires_access_connection("PUT") @provide_session @action_logging( event=action_event_from_permission( @@ -147,7 +147,7 @@ def patch_connection( return connection_schema.dump(connection) -@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION)]) +@security.requires_access_connection("POST") @provide_session @action_logging( event=action_event_from_permission( @@ -176,7 +176,7 @@ def post_connection(*, session: Session = NEW_SESSION) -> APIResponse: raise AlreadyExists(detail=f"Connection already exist. ID: {conn_id}") -@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION)]) +@security.requires_access_connection("POST") def test_connection() -> APIResponse: """ Test an API connection. diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py index 5aac030ecbae0..21a61a0dddac9 100644 --- a/airflow/api_connexion/endpoints/dag_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_endpoint.py @@ -36,10 +36,10 @@ ) from airflow.exceptions import AirflowException, DagNotFound from airflow.models.dag import DagModel, DagTag -from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, provide_session +from airflow.www.extensions.init_auth_manager import get_auth_manager if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -48,7 +48,7 @@ from airflow.api_connexion.types import APIResponse, UpdateMask -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)]) +@security.requires_access_dag("GET") @provide_session def get_dag(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: """Get basic information about a DAG.""" @@ -60,7 +60,7 @@ def get_dag(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: return dag_schema.dump(dag) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)]) +@security.requires_access_dag("GET") def get_dag_details(*, dag_id: str) -> APIResponse: """Get details of DAG.""" dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) @@ -69,7 +69,7 @@ def get_dag_details(*, dag_id: str) -> APIResponse: return dag_detail_schema.dump(dag) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)]) +@security.requires_access_dag("GET") @format_parameters({"limit": check_limit}) @provide_session def get_dags( @@ -96,7 +96,7 @@ def get_dags( if dag_id_pattern: dags_query = dags_query.where(DagModel.dag_id.ilike(f"%{dag_id_pattern}%")) - readable_dags = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) + readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user) dags_query = dags_query.where(DagModel.dag_id.in_(readable_dags)) if tags: @@ -110,7 +110,7 @@ def get_dags( return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries)) -@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)]) +@security.requires_access_dag("PUT") @provide_session def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session = NEW_SESSION) -> APIResponse: """Update the specific DAG.""" @@ -132,7 +132,7 @@ def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session = return dag_schema.dump(dag) -@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)]) +@security.requires_access_dag("PUT") @format_parameters({"limit": check_limit}) @provide_session def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pattern=None, update_mask=None): @@ -156,7 +156,7 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat if dag_id_pattern == "~": dag_id_pattern = "%" dags_query = dags_query.where(DagModel.dag_id.ilike(f"%{dag_id_pattern}%")) - editable_dags = get_airflow_app().appbuilder.sm.get_editable_dag_ids(g.user) + editable_dags = get_auth_manager().get_permitted_dag_ids(methods=["PUT"], user=g.user) dags_query = dags_query.where(DagModel.dag_id.in_(editable_dags)) if tags: @@ -180,7 +180,7 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat return dags_collection_schema.dump(DAGCollection(dags=dags, total_entries=total_entries)) -@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG)]) +@security.requires_access_dag("DELETE") @provide_session def delete_dag(dag_id: str, session: Session = NEW_SESSION) -> APIResponse: """Delete the specific DAG.""" diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py index 1a9cb03418885..45e064764c62e 100644 --- a/airflow/api_connexion/endpoints/dag_run_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py @@ -56,6 +56,7 @@ TaskInstanceReferenceCollection, task_instance_reference_collection_schema, ) +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.models import DagModel, DagRun from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app @@ -76,12 +77,7 @@ RESOURCE_EVENT_PREFIX = "dag_run" -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("DELETE", DagAccessEntity.RUN) @provide_session def delete_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: """Delete a DAG Run.""" @@ -93,12 +89,7 @@ def delete_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSI return NoContent, HTTPStatus.NO_CONTENT -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.RUN) @provide_session def get_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: """Get a DAG Run.""" @@ -111,13 +102,8 @@ def get_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) return dagrun_schema.dump(dag_run) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.RUN) +@security.requires_access_dataset("GET") @provide_session def get_upstream_dataset_events( *, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION @@ -194,12 +180,7 @@ def _fetch_dag_runs( return session.scalars(query.offset(offset).limit(limit)).all(), total_entries -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.RUN) @format_parameters( { "start_date_gte": format_datetime, @@ -236,8 +217,9 @@ def get_dag_runs( # This endpoint allows specifying ~ as the dag_id to retrieve DAG Runs for all DAGs. if dag_id == "~": - appbuilder = get_airflow_app().appbuilder - query = query.where(DagRun.dag_id.in_(appbuilder.sm.get_readable_dag_ids(g.user))) + query = query.where( + DagRun.dag_id.in_(get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=g.user)) + ) else: query = query.where(DagRun.dag_id == dag_id) @@ -262,12 +244,7 @@ def get_dag_runs( return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_run, total_entries=total_entries)) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.RUN) @provide_session def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse: """Get list of DAG Runs.""" @@ -277,8 +254,7 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse: except ValidationError as err: raise BadRequest(detail=str(err.messages)) - appbuilder = get_airflow_app().appbuilder - readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user) + readable_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=g.user) query = select(DagRun) if data.get("dag_ids"): dag_ids = set(data["dag_ids"]) & set(readable_dag_ids) @@ -307,12 +283,7 @@ def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse: return dagrun_collection_schema.dump(DAGRunCollection(dag_runs=dag_runs, total_entries=total_entries)) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_CREATE, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("POST", DagAccessEntity.RUN) @provide_session @action_logging( event=action_event_from_permission( @@ -378,12 +349,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: raise AlreadyExists(detail=f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: '{run_id}' already exists") -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.RUN) @provide_session def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: """Set a state of a dag run.""" @@ -410,12 +376,7 @@ def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW return dagrun_schema.dump(dag_run) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.RUN) @provide_session def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: """Clear a dag run.""" @@ -461,12 +422,7 @@ def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSIO return dagrun_schema.dump(dag_run) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.RUN) @provide_session def set_dag_run_note(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse: """Set the note for a dag run.""" diff --git a/airflow/api_connexion/endpoints/dag_source_endpoint.py b/airflow/api_connexion/endpoints/dag_source_endpoint.py index b191630815004..3ee80ee857a4c 100644 --- a/airflow/api_connexion/endpoints/dag_source_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_source_endpoint.py @@ -24,11 +24,11 @@ from airflow.api_connexion import security from airflow.api_connexion.exceptions import NotFound from airflow.api_connexion.schemas.dag_source_schema import dag_source_schema +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.models.dagcode import DagCode -from airflow.security import permissions -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)]) +@security.requires_access_dag("GET", DagAccessEntity.CODE) def get_dag_source(*, file_token: str) -> Response: """Get source code using file token.""" secret_key = current_app.config["SECRET_KEY"] diff --git a/airflow/api_connexion/endpoints/dag_warning_endpoint.py b/airflow/api_connexion/endpoints/dag_warning_endpoint.py index c9d8207b0f65f..3e0db58dc9ad5 100644 --- a/airflow/api_connexion/endpoints/dag_warning_endpoint.py +++ b/airflow/api_connexion/endpoints/dag_warning_endpoint.py @@ -27,8 +27,8 @@ DagWarningCollection, dag_warning_collection_schema, ) +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.models.dagwarning import DagWarning as DagWarningModel -from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, provide_session @@ -39,12 +39,7 @@ from airflow.api_connexion.types import APIResponse -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_WARNING), - ] -) +@security.requires_access_dag("GET", DagAccessEntity.WARNING) @format_parameters({"limit": check_limit}) @provide_session def get_dag_warnings( diff --git a/airflow/api_connexion/endpoints/dataset_endpoint.py b/airflow/api_connexion/endpoints/dataset_endpoint.py index 81fe872fca72a..152ac6eecb2cf 100644 --- a/airflow/api_connexion/endpoints/dataset_endpoint.py +++ b/airflow/api_connexion/endpoints/dataset_endpoint.py @@ -32,7 +32,6 @@ dataset_schema, ) from airflow.models.dataset import DatasetEvent, DatasetModel -from airflow.security import permissions from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, provide_session @@ -42,9 +41,9 @@ from airflow.api_connexion.types import APIResponse -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)]) +@security.requires_access_dataset("GET") @provide_session -def get_dataset(uri: str, session: Session = NEW_SESSION) -> APIResponse: +def get_dataset(*, uri: str, session: Session = NEW_SESSION) -> APIResponse: """Get a Dataset.""" dataset = session.scalar( select(DatasetModel) @@ -59,7 +58,7 @@ def get_dataset(uri: str, session: Session = NEW_SESSION) -> APIResponse: return dataset_schema.dump(dataset) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)]) +@security.requires_access_dataset("GET") @format_parameters({"limit": check_limit}) @provide_session def get_datasets( @@ -86,7 +85,7 @@ def get_datasets( return dataset_collection_schema.dump(DatasetCollection(datasets=datasets, total_entries=total_entries)) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)]) +@security.requires_access_dataset("GET") @provide_session @format_parameters({"limit": check_limit}) def get_dataset_events( diff --git a/airflow/api_connexion/endpoints/event_log_endpoint.py b/airflow/api_connexion/endpoints/event_log_endpoint.py index 99ec8eedaed74..b5bca5cc23272 100644 --- a/airflow/api_connexion/endpoints/event_log_endpoint.py +++ b/airflow/api_connexion/endpoints/event_log_endpoint.py @@ -28,8 +28,8 @@ event_log_collection_schema, event_log_schema, ) +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.models import Log -from airflow.security import permissions from airflow.utils import timezone from airflow.utils.session import NEW_SESSION, provide_session @@ -40,7 +40,7 @@ from airflow.api_connexion.types import APIResponse -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)]) +@security.requires_access_dag("GET", DagAccessEntity.AUDIT_LOG) @provide_session def get_event_log(*, event_log_id: int, session: Session = NEW_SESSION) -> APIResponse: """Get a log entry.""" @@ -50,7 +50,7 @@ def get_event_log(*, event_log_id: int, session: Session = NEW_SESSION) -> APIRe return event_log_schema.dump(event_log) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)]) +@security.requires_access_dag("GET", DagAccessEntity.AUDIT_LOG) @format_parameters({"limit": check_limit}) @provide_session def get_event_logs( diff --git a/airflow/api_connexion/endpoints/extra_link_endpoint.py b/airflow/api_connexion/endpoints/extra_link_endpoint.py index ec92dd51ee7ad..2e9954587c071 100644 --- a/airflow/api_connexion/endpoints/extra_link_endpoint.py +++ b/airflow/api_connexion/endpoints/extra_link_endpoint.py @@ -22,8 +22,8 @@ from airflow.api_connexion import security from airflow.api_connexion.exceptions import NotFound +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.exceptions import TaskNotFound -from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.session import NEW_SESSION, provide_session @@ -35,13 +35,7 @@ from airflow.models.dagbag import DagBag -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_extra_links( *, diff --git a/airflow/api_connexion/endpoints/import_error_endpoint.py b/airflow/api_connexion/endpoints/import_error_endpoint.py index f2b9a88311f37..81459b604e9ee 100644 --- a/airflow/api_connexion/endpoints/import_error_endpoint.py +++ b/airflow/api_connexion/endpoints/import_error_endpoint.py @@ -28,8 +28,8 @@ import_error_collection_schema, import_error_schema, ) +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.models.errors import ImportError as ImportErrorModel -from airflow.security import permissions from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: @@ -38,7 +38,7 @@ from airflow.api_connexion.types import APIResponse -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)]) +@security.requires_access_dag("GET", DagAccessEntity.IMPORT_ERRORS) @provide_session def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) -> APIResponse: """Get an import error.""" @@ -52,7 +52,7 @@ def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) -> return import_error_schema.dump(error) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)]) +@security.requires_access_dag("GET", DagAccessEntity.IMPORT_ERRORS) @format_parameters({"limit": check_limit}) @provide_session def get_import_errors( diff --git a/airflow/api_connexion/endpoints/log_endpoint.py b/airflow/api_connexion/endpoints/log_endpoint.py index 126b8634e3cfd..239f08ecdaf40 100644 --- a/airflow/api_connexion/endpoints/log_endpoint.py +++ b/airflow/api_connexion/endpoints/log_endpoint.py @@ -27,9 +27,9 @@ from airflow.api_connexion import security from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.schemas.log_schema import LogResponseObject, logs_schema +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.exceptions import TaskNotFound from airflow.models import TaskInstance, Trigger -from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.log.log_reader import TaskLogReader from airflow.utils.session import NEW_SESSION, provide_session @@ -40,13 +40,7 @@ from airflow.api_connexion.types import APIResponse -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_LOGS) @provide_session def get_log( *, diff --git a/airflow/api_connexion/endpoints/plugin_endpoint.py b/airflow/api_connexion/endpoints/plugin_endpoint.py index 02ba435d52a9b..500bd65749062 100644 --- a/airflow/api_connexion/endpoints/plugin_endpoint.py +++ b/airflow/api_connexion/endpoints/plugin_endpoint.py @@ -22,13 +22,12 @@ from airflow.api_connexion.parameters import check_limit, format_parameters from airflow.api_connexion.schemas.plugin_schema import PluginCollection, plugin_collection_schema from airflow.plugins_manager import get_plugin_info -from airflow.security import permissions if TYPE_CHECKING: from airflow.api_connexion.types import APIResponse -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_PLUGIN)]) +@security.requires_access_website() @format_parameters({"limit": check_limit}) def get_plugins(*, limit: int, offset: int = 0) -> APIResponse: """Get plugins endpoint.""" diff --git a/airflow/api_connexion/endpoints/pool_endpoint.py b/airflow/api_connexion/endpoints/pool_endpoint.py index 735d777e4ca75..0fbb2c8a23f4d 100644 --- a/airflow/api_connexion/endpoints/pool_endpoint.py +++ b/airflow/api_connexion/endpoints/pool_endpoint.py @@ -30,7 +30,6 @@ from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters from airflow.api_connexion.schemas.pool_schema import PoolCollection, pool_collection_schema, pool_schema from airflow.models.pool import Pool -from airflow.security import permissions from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: @@ -39,7 +38,7 @@ from airflow.api_connexion.types import APIResponse, UpdateMask -@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_POOL)]) +@security.requires_access_pool("DELETE") @provide_session def delete_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse: """Delete a pool.""" @@ -52,7 +51,7 @@ def delete_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIRespons return Response(status=HTTPStatus.NO_CONTENT) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_POOL)]) +@security.requires_access_pool("GET") @provide_session def get_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse: """Get a pool.""" @@ -62,7 +61,7 @@ def get_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse: return pool_schema.dump(obj) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_POOL)]) +@security.requires_access_pool("GET") @format_parameters({"limit": check_limit}) @provide_session def get_pools( @@ -82,7 +81,7 @@ def get_pools( return pool_collection_schema.dump(PoolCollection(pools=pools, total_entries=total_entries)) -@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_POOL)]) +@security.requires_access_pool("PUT") @provide_session def patch_pool( *, @@ -138,7 +137,7 @@ def patch_pool( return pool_schema.dump(pool) -@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_POOL)]) +@security.requires_access_pool("POST") @provide_session def post_pool(*, session: Session = NEW_SESSION) -> APIResponse: """Create a pool.""" diff --git a/airflow/api_connexion/endpoints/provider_endpoint.py b/airflow/api_connexion/endpoints/provider_endpoint.py index 75bba31218d05..a64368dce3587 100644 --- a/airflow/api_connexion/endpoints/provider_endpoint.py +++ b/airflow/api_connexion/endpoints/provider_endpoint.py @@ -27,7 +27,6 @@ provider_collection_schema, ) from airflow.providers_manager import ProvidersManager -from airflow.security import permissions if TYPE_CHECKING: from airflow.api_connexion.types import APIResponse @@ -46,7 +45,7 @@ def _provider_mapper(provider: ProviderInfo) -> Provider: ) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_PROVIDER)]) +@security.requires_access_website() def get_providers() -> APIResponse: """Get providers.""" providers = [_provider_mapper(d) for d in ProvidersManager().providers.values()] diff --git a/airflow/api_connexion/endpoints/task_endpoint.py b/airflow/api_connexion/endpoints/task_endpoint.py index 70b6e4b8aba41..4c5954d2ac5f0 100644 --- a/airflow/api_connexion/endpoints/task_endpoint.py +++ b/airflow/api_connexion/endpoints/task_endpoint.py @@ -22,8 +22,8 @@ from airflow.api_connexion import security from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.schemas.task_schema import TaskCollection, task_collection_schema, task_schema +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.exceptions import TaskNotFound -from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app if TYPE_CHECKING: @@ -31,12 +31,7 @@ from airflow.api_connexion.types import APIResponse -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK) def get_task(*, dag_id: str, task_id: str) -> APIResponse: """Get simplified representation of a task.""" dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) @@ -50,12 +45,7 @@ def get_task(*, dag_id: str, task_id: str) -> APIResponse: return task_schema.dump(task) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK) def get_tasks(*, dag_id: str, order_by: str = "task_id") -> APIResponse: """Get tasks for DAG.""" dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id) diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py b/airflow/api_connexion/endpoints/task_instance_endpoint.py index 612167d3d7662..0b942134acc80 100644 --- a/airflow/api_connexion/endpoints/task_instance_endpoint.py +++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py @@ -42,11 +42,11 @@ task_instance_schema, ) from airflow.api_connexion.security import get_readable_dags +from airflow.auth.managers.models.resource_details import DagAccessEntity, DagDetails from airflow.models import SlaMiss from airflow.models.dagrun import DagRun as DR from airflow.models.operator import needs_expansion from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances -from airflow.security import permissions from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, provide_session @@ -62,13 +62,7 @@ T = TypeVar("T") -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_task_instance( *, @@ -110,13 +104,7 @@ def get_task_instance( return task_instance_schema.dump(task_instance) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_mapped_task_instance( *, @@ -162,13 +150,7 @@ def get_mapped_task_instance( "updated_at_lte": format_datetime, }, ) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_mapped_task_instances( *, @@ -306,13 +288,7 @@ def _apply_range_filter(query: Select, key: ClauseElement, value_range: tuple[T, "updated_at_lte": format_datetime, }, ) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_task_instances( *, @@ -389,13 +365,7 @@ def get_task_instances( ) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE) @provide_session def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: """Get list of task instances.""" @@ -408,7 +378,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: if dag_ids: cannot_access_dag_ids = set() for id in dag_ids: - if not get_airflow_app().appbuilder.sm.can_read_dag(id, g.user): + if not get_auth_manager().is_authorized_dag(method="GET", details=DagDetails(id=id), user=g.user): cannot_access_dag_ids.add(id) if cannot_access_dag_ids: raise PermissionDenied( @@ -464,13 +434,7 @@ def get_task_instances_batch(session: Session = NEW_SESSION) -> APIResponse: ) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: """Clear task instances.""" @@ -530,13 +494,7 @@ def post_clear_task_instances(*, dag_id: str, session: Session = NEW_SESSION) -> ) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def post_set_task_instances_state(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse: """Set a state of task instances.""" @@ -603,13 +561,7 @@ def set_mapped_task_instance_note( return set_task_instance_note(dag_id=dag_id, dag_run_id=dag_run_id, task_id=task_id, map_index=map_index) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def patch_task_instance( *, dag_id: str, dag_run_id: str, task_id: str, map_index: int = -1, session: Session = NEW_SESSION @@ -649,13 +601,7 @@ def patch_task_instance( return task_instance_reference_schema.dump(ti) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def patch_mapped_task_instance( *, dag_id: str, dag_run_id: str, task_id: str, map_index: int, session: Session = NEW_SESSION @@ -666,13 +612,7 @@ def patch_mapped_task_instance( ) -@security.requires_access( - [ - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_TASK_INSTANCE), - ], -) +@security.requires_access_dag("PUT", DagAccessEntity.TASK_INSTANCE) @provide_session def set_task_instance_note( *, dag_id: str, dag_run_id: str, task_id: str, map_index: int = -1, session: Session = NEW_SESSION diff --git a/airflow/api_connexion/endpoints/variable_endpoint.py b/airflow/api_connexion/endpoints/variable_endpoint.py index 54d5ac744b6c3..05157298e7181 100644 --- a/airflow/api_connexion/endpoints/variable_endpoint.py +++ b/airflow/api_connexion/endpoints/variable_endpoint.py @@ -43,7 +43,7 @@ RESOURCE_EVENT_PREFIX = "variable" -@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_VARIABLE)]) +@security.requires_access_variable("DELETE") @action_logging( event=action_event_from_permission( prefix=RESOURCE_EVENT_PREFIX, @@ -57,7 +57,7 @@ def delete_variable(*, variable_key: str) -> Response: return Response(status=HTTPStatus.NO_CONTENT) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE)]) +@security.requires_access_variable("DELETE") @provide_session def get_variable(*, variable_key: str, session: Session = NEW_SESSION) -> Response: """Get a variable by key.""" @@ -67,7 +67,7 @@ def get_variable(*, variable_key: str, session: Session = NEW_SESSION) -> Respon return variable_schema.dump(var) -@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_VARIABLE)]) +@security.requires_access_variable("GET") @format_parameters({"limit": check_limit}) @provide_session def get_variables( @@ -92,7 +92,7 @@ def get_variables( ) -@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_VARIABLE)]) +@security.requires_access_variable("PUT") @provide_session @action_logging( event=action_event_from_permission( @@ -126,7 +126,7 @@ def patch_variable( return variable_schema.dump(variable) -@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE)]) +@security.requires_access_variable("POST") @action_logging( event=action_event_from_permission( prefix=RESOURCE_EVENT_PREFIX, diff --git a/airflow/api_connexion/endpoints/xcom_endpoint.py b/airflow/api_connexion/endpoints/xcom_endpoint.py index 73bdd8562e9a5..d5eb6ed19bfb4 100644 --- a/airflow/api_connexion/endpoints/xcom_endpoint.py +++ b/airflow/api_connexion/endpoints/xcom_endpoint.py @@ -26,12 +26,12 @@ from airflow.api_connexion.exceptions import BadRequest, NotFound from airflow.api_connexion.parameters import check_limit, format_parameters from airflow.api_connexion.schemas.xcom_schema import XComCollection, xcom_collection_schema, xcom_schema +from airflow.auth.managers.models.resource_details import DagAccessEntity from airflow.models import DagRun as DR, XCom -from airflow.security import permissions from airflow.settings import conf -from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.db import get_query_count from airflow.utils.session import NEW_SESSION, provide_session +from airflow.www.extensions.init_auth_manager import get_auth_manager if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -39,14 +39,7 @@ from airflow.api_connexion.types import APIResponse -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.XCOM) @format_parameters({"limit": check_limit}) @provide_session def get_xcom_entries( @@ -63,8 +56,7 @@ def get_xcom_entries( """Get all XCom values.""" query = select(XCom) if dag_id == "~": - appbuilder = get_airflow_app().appbuilder - readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user) + readable_dag_ids = get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=g.user) query = query.where(XCom.dag_id.in_(readable_dag_ids)) query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.run_id == DR.run_id)) else: @@ -85,14 +77,7 @@ def get_xcom_entries( return xcom_collection_schema.dump(XComCollection(xcom_entries=query, total_entries=total_entries)) -@security.requires_access( - [ - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_RUN), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_TASK_INSTANCE), - (permissions.ACTION_CAN_READ, permissions.RESOURCE_XCOM), - ], -) +@security.requires_access_dag("GET", DagAccessEntity.XCOM) @provide_session def get_xcom_entry( *, diff --git a/airflow/api_connexion/security.py b/airflow/api_connexion/security.py index b19f15257c18a..6da171aa629dc 100644 --- a/airflow/api_connexion/security.py +++ b/airflow/api_connexion/security.py @@ -16,13 +16,28 @@ # under the License. from __future__ import annotations +import warnings from functools import wraps -from typing import Callable, Sequence, TypeVar, cast +from typing import TYPE_CHECKING, Callable, Sequence, TypeVar, cast from flask import Response, g from airflow.api_connexion.exceptions import PermissionDenied, Unauthenticated +from airflow.auth.managers.models.resource_details import ( + ConfigurationDetails, + ConnectionDetails, + DagAccessEntity, + DagDetails, + DatasetDetails, + PoolDetails, + VariableDetails, +) +from airflow.exceptions import RemovedInAirflow3Warning from airflow.utils.airflow_flask_app import get_airflow_app +from airflow.www.extensions.init_auth_manager import get_auth_manager + +if TYPE_CHECKING: + from airflow.auth.managers.base_auth_manager import ResourceMethod T = TypeVar("T", bound=Callable) @@ -39,18 +54,202 @@ def check_authentication() -> None: def requires_access(permissions: Sequence[tuple[str, str]] | None = None) -> Callable[[T], T]: - """Check current user's permissions against required permissions.""" - appbuilder = get_airflow_app().appbuilder - if appbuilder.update_perms: - appbuilder.sm.sync_resource_permissions(permissions) + """ + Check current user's permissions against required permissions. + + Deprecated. Do not use this decorator, use one of the decorator `has_access_*` defined in + airflow/api_connexion/security.py instead. + This decorator will only work with FAB authentication and not with other auth providers. + + This decorator might be used in user plugins, do not remove it. + """ + warnings.warn( + "The 'requires_access' decorator is deprecated. Please use one of the decorator `requires_access_*`" + "defined in airflow/api_connexion/security.py instead.", + RemovedInAirflow3Warning, + stacklevel=2, + ) + from airflow.auth.managers.fab.decorators.auth import _requires_access_fab + + return _requires_access_fab(permissions) + + +def _requires_access(*, is_authorized_callback: Callable[[], bool], func: Callable, args, kwargs) -> bool: + """ + Define the behavior whether the user is authorized to access the resource. + + :param is_authorized_callback: callback to execute to figure whether the user is authorized to access + the resource + :param func: the function to call if the user is authorized + :param args: the arguments of ``func`` + :param kwargs: the keyword arguments ``func`` + + :meta private: + """ + check_authentication() + if is_authorized_callback(): + return func(*args, **kwargs) + raise PermissionDenied() + + +def requires_authentication(func: T): + """Decorator for functions that require authentication.""" + + @wraps(func) + def decorated(*args, **kwargs): + check_authentication() + return func(*args, **kwargs) + + return cast(T, decorated) + + +def requires_access_configuration(method: ResourceMethod) -> Callable[[T], T]: + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + section: str | None = kwargs.get("section") + return _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_configuration( + method=method, details=ConfigurationDetails(section=section) + ), + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return requires_access_decorator + + +def requires_access_connection(method: ResourceMethod) -> Callable[[T], T]: + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + connection_id: str | None = kwargs.get("connection_id") + return _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_connection( + method=method, details=ConnectionDetails(conn_id=connection_id) + ), + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return requires_access_decorator + + +def requires_access_dag( + method: ResourceMethod, access_entity: DagAccessEntity | None = None +) -> Callable[[T], T]: + def _is_authorized_callback(dag_id: str): + def callback(): + access = get_auth_manager().is_authorized_dag( + method=method, + access_entity=access_entity, + details=DagDetails(id=dag_id), + ) + + # ``access`` means here: + # - if a DAG id is provided (``dag_id`` not None): is the user authorized to access this DAG + # - if no DAG id is provided: is the user authorized to access all DAGs + if dag_id or access: + return access + + # No DAG id is provided and the user is not authorized to access all DAGs + # If method is "GET", return whether the user has read access to any DAGs + # If method is "PUT", return whether the user has edit access to any DAGs + return (method == "GET" and any(get_auth_manager().get_permitted_dag_ids(methods=["GET"]))) or ( + method == "PUT" and any(get_auth_manager().get_permitted_dag_ids(methods=["PUT"])) + ) + + return callback + + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + dag_id: str | None = kwargs.get("dag_id") if kwargs.get("dag_id") != "~" else None + return _requires_access( + is_authorized_callback=_is_authorized_callback(dag_id), + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return requires_access_decorator + + +def requires_access_dataset(method: ResourceMethod) -> Callable[[T], T]: + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + uri: str | None = kwargs.get("uri") + return _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_dataset( + method=method, details=DatasetDetails(uri=uri) + ), + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return requires_access_decorator + + +def requires_access_pool(method: ResourceMethod) -> Callable[[T], T]: + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + pool_name: str | None = kwargs.get("pool_name") + return _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_pool( + method=method, details=PoolDetails(name=pool_name) + ), + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return requires_access_decorator + + +def requires_access_variable(method: ResourceMethod) -> Callable[[T], T]: + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + variable_key: str | None = kwargs.get("variable_key") + return _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_variable( + method=method, details=VariableDetails(key=variable_key) + ), + func=func, + args=args, + kwargs=kwargs, + ) + + return cast(T, decorated) + + return requires_access_decorator + +def requires_access_website() -> Callable[[T], T]: def requires_access_decorator(func: T): @wraps(func) def decorated(*args, **kwargs): - check_authentication() - if appbuilder.sm.check_authorization(permissions, kwargs.get("dag_id")): - return func(*args, **kwargs) - raise PermissionDenied() + return _requires_access( + is_authorized_callback=lambda: get_auth_manager().is_authorized_website(), + func=func, + args=args, + kwargs=kwargs, + ) return cast(T, decorated) diff --git a/airflow/auth/managers/base_auth_manager.py b/airflow/auth/managers/base_auth_manager.py index 29695dae12aef..07003380693c0 100644 --- a/airflow/auth/managers/base_auth_manager.py +++ b/airflow/auth/managers/base_auth_manager.py @@ -18,19 +18,28 @@ from __future__ import annotations from abc import abstractmethod -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Container, Literal + +from sqlalchemy import select from airflow.exceptions import AirflowException +from airflow.models import DagModel from airflow.utils.log.logging_mixin import LoggingMixin +from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: from flask import Flask + from sqlalchemy.orm import Session from airflow.auth.managers.models.base_user import BaseUser from airflow.auth.managers.models.resource_details import ( + ConfigurationDetails, ConnectionDetails, DagAccessEntity, DagDetails, + DatasetDetails, + PoolDetails, + VariableDetails, ) from airflow.cli.cli_config import CLICommand from airflow.www.security_manager import AirflowSecurityManagerV2 @@ -82,12 +91,14 @@ def is_authorized_configuration( self, *, method: ResourceMethod, + details: ConfigurationDetails | None = None, user: BaseUser | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on configuration. :param method: the method to perform + :param details: optional details about the configuration :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @@ -110,14 +121,14 @@ def is_authorized_connection( self, *, method: ResourceMethod, - connection_details: ConnectionDetails | None = None, + details: ConnectionDetails | None = None, user: BaseUser | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a connection. :param method: the method to perform - :param connection_details: optional details about the connection + :param details: optional details about the connection :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @@ -126,17 +137,17 @@ def is_authorized_dag( self, *, method: ResourceMethod, - dag_access_entity: DagAccessEntity | None = None, - dag_details: DagDetails | None = None, + access_entity: DagAccessEntity | None = None, + details: DagDetails | None = None, user: BaseUser | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a DAG. :param method: the method to perform - :param dag_access_entity: the kind of DAG information the authorization request is about. + :param access_entity: the kind of DAG information the authorization request is about. If not provided, the authorization request is about the DAG itself - :param dag_details: optional details about the DAG + :param details: optional details about the DAG :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @@ -145,12 +156,30 @@ def is_authorized_dataset( self, *, method: ResourceMethod, + details: DatasetDetails | None = None, user: BaseUser | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a dataset. :param method: the method to perform + :param details: optional details about the dataset + :param user: the user to perform the action on. If not provided (or None), it uses the current user + """ + + @abstractmethod + def is_authorized_pool( + self, + *, + method: ResourceMethod, + details: PoolDetails | None = None, + user: BaseUser | None = None, + ) -> bool: + """ + Return whether the user is authorized to perform a given action on a pool. + + :param method: the method to perform + :param details: optional details about the pool :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @@ -159,12 +188,14 @@ def is_authorized_variable( self, *, method: ResourceMethod, + details: VariableDetails | None = None, user: BaseUser | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a variable. :param method: the method to perform + :param details: optional details about the variable :param user: the user to perform the action on. If not provided (or None), it uses the current user """ @@ -182,6 +213,43 @@ def is_authorized_website( :param user: the user to perform the action on. If not provided (or None), it uses the current user """ + @provide_session + def get_permitted_dag_ids( + self, + *, + methods: Container[ResourceMethod] | None = None, + user=None, + session: Session = NEW_SESSION, + ) -> set[str]: + """ + Get readable or writable DAGs for user. + + By default, reads all the DAGs and check individually if the user has permissions to access the DAG. + Can lead to some poor performance. It is recommended to override this method in the auth manager + implementation to provide a more efficient implementation. + """ + if not methods: + methods = ["PUT", "GET"] + + dag_ids = {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} + + if ("GET" in methods and self.is_authorized_dag(method="GET", user=user)) or ( + "PUT" in methods and self.is_authorized_dag(method="PUT", user=user) + ): + # If user is authorized to read/edit all DAGs, return all DAGs + return dag_ids + + def _is_permitted_dag_id(method: ResourceMethod, methods: Container[ResourceMethod], dag_id: str): + return method in methods and self.is_authorized_dag( + method=method, details=DagDetails(id=dag_id), user=user + ) + + return { + dag_id + for dag_id in dag_ids + if _is_permitted_dag_id("GET", methods, dag_id) or _is_permitted_dag_id("PUT", methods, dag_id) + } + @abstractmethod def get_url_login(self, **kwargs) -> str: """Return the login page url.""" diff --git a/airflow/auth/managers/fab/decorators/auth.py b/airflow/auth/managers/fab/decorators/auth.py index 5f0f16147075b..583e18e2a7343 100644 --- a/airflow/auth/managers/fab/decorators/auth.py +++ b/airflow/auth/managers/fab/decorators/auth.py @@ -23,7 +23,10 @@ from flask import current_app, render_template, request +from airflow.api_connexion.exceptions import PermissionDenied +from airflow.api_connexion.security import check_authentication from airflow.configuration import conf +from airflow.utils.airflow_flask_app import get_airflow_app from airflow.utils.net import get_hostname from airflow.www.auth import _has_access from airflow.www.extensions.init_auth_manager import get_auth_manager @@ -33,6 +36,33 @@ log = logging.getLogger(__name__) +def _requires_access_fab(permissions: Sequence[tuple[str, str]] | None = None) -> Callable[[T], T]: + """ + Check current user's permissions against required permissions. + + This decorator is only kept for backward compatible reasons. The decorator + ``airflow.api_connexion.security.requires_access``, which redirects to this decorator, might be used in + user plugins. Thus, we need to keep it. + + :meta private: + """ + appbuilder = get_airflow_app().appbuilder + if appbuilder.update_perms: + appbuilder.sm.sync_resource_permissions(permissions) + + def requires_access_decorator(func: T): + @wraps(func) + def decorated(*args, **kwargs): + check_authentication() + if appbuilder.sm.check_authorization(permissions, kwargs.get("dag_id")): + return func(*args, **kwargs) + raise PermissionDenied() + + return cast(T, decorated) + + return requires_access_decorator + + def _has_access_fab(permissions: Sequence[tuple[str, str]] | None = None) -> Callable[[T], T]: """ Factory for decorator that checks current user's permissions against required permissions. diff --git a/airflow/auth/managers/fab/fab_auth_manager.py b/airflow/auth/managers/fab/fab_auth_manager.py index 9c2c5643b60f5..6c942babbfc61 100644 --- a/airflow/auth/managers/fab/fab_auth_manager.py +++ b/airflow/auth/managers/fab/fab_auth_manager.py @@ -18,10 +18,11 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Container from flask import url_for from sqlalchemy import select +from sqlalchemy.orm import Session, joinedload from airflow.auth.managers.base_auth_manager import BaseAuthManager, ResourceMethod from airflow.auth.managers.fab.cli_commands.definition import ( @@ -29,12 +30,22 @@ SYNC_PERM_COMMAND, USERS_COMMANDS, ) -from airflow.auth.managers.models.resource_details import ConnectionDetails, DagAccessEntity, DagDetails +from airflow.auth.managers.fab.models import Permission, Role, User +from airflow.auth.managers.models.resource_details import ( + ConfigurationDetails, + ConnectionDetails, + DagAccessEntity, + DagDetails, + DatasetDetails, + PoolDetails, + VariableDetails, +) from airflow.cli.cli_config import ( GroupCommand, ) from airflow.exceptions import AirflowException from airflow.models import DagModel +from airflow.security import permissions from airflow.security.permissions import ( ACTION_CAN_ACCESS_MENU, ACTION_CAN_CREATE, @@ -50,37 +61,51 @@ RESOURCE_DAG_DEPENDENCIES, RESOURCE_DAG_PREFIX, RESOURCE_DAG_RUN, + RESOURCE_DAG_WARNING, RESOURCE_DATASET, + RESOURCE_IMPORT_ERROR, + RESOURCE_PLUGIN, + RESOURCE_POOL, + RESOURCE_PROVIDER, RESOURCE_TASK_INSTANCE, RESOURCE_TASK_LOG, + RESOURCE_TRIGGER, RESOURCE_VARIABLE, RESOURCE_WEBSITE, RESOURCE_XCOM, ) +from airflow.utils.session import NEW_SESSION, provide_session if TYPE_CHECKING: - from airflow.auth.managers.fab.models import User + from airflow.auth.managers.models.base_user import BaseUser from airflow.cli.cli_config import ( CLICommand, ) -_MAP_METHOD_NAME_TO_FAB_ACTION_NAME: dict[ResourceMethod, str] = { +MAP_METHOD_NAME_TO_FAB_ACTION_NAME: dict[ResourceMethod, str] = { "POST": ACTION_CAN_CREATE, "GET": ACTION_CAN_READ, "PUT": ACTION_CAN_EDIT, "DELETE": ACTION_CAN_DELETE, } -_MAP_DAG_ACCESS_ENTITY_TO_FAB_RESOURCE_TYPE = { - DagAccessEntity.AUDIT_LOG: RESOURCE_AUDIT_LOG, - DagAccessEntity.CODE: RESOURCE_DAG_CODE, - DagAccessEntity.DATASET: RESOURCE_DATASET, - DagAccessEntity.DEPENDENCIES: RESOURCE_DAG_DEPENDENCIES, - DagAccessEntity.RUN: RESOURCE_DAG_RUN, - DagAccessEntity.TASK_INSTANCE: RESOURCE_TASK_INSTANCE, - DagAccessEntity.TASK_LOGS: RESOURCE_TASK_LOG, - DagAccessEntity.XCOM: RESOURCE_XCOM, +_MAP_DAG_ACCESS_ENTITY_TO_FAB_RESOURCE_TYPE: dict[DagAccessEntity, tuple[str, ...]] = { + DagAccessEntity.AUDIT_LOG: (RESOURCE_AUDIT_LOG,), + DagAccessEntity.CODE: (RESOURCE_DAG_CODE,), + DagAccessEntity.DEPENDENCIES: (RESOURCE_DAG_DEPENDENCIES,), + DagAccessEntity.IMPORT_ERRORS: (RESOURCE_IMPORT_ERROR,), + DagAccessEntity.RUN: (RESOURCE_DAG_RUN,), + # RESOURCE_TASK_INSTANCE has been originally misused. RESOURCE_TASK_INSTANCE referred to task definition + # AND task instances without making the difference + # To be backward compatible, we translate DagAccessEntity.TASK_INSTANCE to RESOURCE_TASK_INSTANCE AND + # RESOURCE_DAG_RUN + # See https://github.com/apache/airflow/pull/34317#discussion_r1355917769 + DagAccessEntity.TASK: (RESOURCE_TASK_INSTANCE,), + DagAccessEntity.TASK_INSTANCE: (RESOURCE_DAG_RUN, RESOURCE_TASK_INSTANCE), + DagAccessEntity.TASK_LOGS: (RESOURCE_TASK_LOG,), + DagAccessEntity.WARNING: (RESOURCE_DAG_WARNING,), + DagAccessEntity.XCOM: (RESOURCE_XCOM,), } @@ -139,7 +164,13 @@ def is_logged_in(self) -> bool: """Return whether the user is logged in.""" return not self.get_user().is_anonymous - def is_authorized_configuration(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: + def is_authorized_configuration( + self, + *, + method: ResourceMethod, + details: ConfigurationDetails | None = None, + user: BaseUser | None = None, + ) -> bool: return self._is_authorized(method=method, resource_type=RESOURCE_CONFIG, user=user) def is_authorized_cluster_activity(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: @@ -149,7 +180,7 @@ def is_authorized_connection( self, *, method: ResourceMethod, - connection_details: ConnectionDetails | None = None, + details: ConnectionDetails | None = None, user: BaseUser | None = None, ) -> bool: return self._is_authorized(method=method, resource_type=RESOURCE_CONNECTION, user=user) @@ -158,8 +189,8 @@ def is_authorized_dag( self, *, method: ResourceMethod, - dag_access_entity: DagAccessEntity | None = None, - dag_details: DagDetails | None = None, + access_entity: DagAccessEntity | None = None, + details: DagDetails | None = None, user: BaseUser | None = None, ) -> bool: """ @@ -171,34 +202,111 @@ def is_authorized_dag( entity (e.g. DAG runs). 2. ``dag_access`` is provided which means the user wants to access a sub entity of the DAG (e.g. DAG runs). - a. If ``method`` is GET, then check the user has READ permissions on the DAG and the sub entity - b. Else, check the user has EDIT permissions on the DAG and ``method`` on the sub entity + a. If ``method`` is GET, then check the user has READ permissions on the DAG and the sub entity. + b. Else, check the user has EDIT permissions on the DAG and ``method`` on the sub entity. + + However, if no specific DAG is targeted, just check the sub entity. :param method: The method to authorize. - :param dag_access_entity: The dag access entity. - :param dag_details: The dag details. + :param access_entity: The dag access entity. + :param details: The dag details. :param user: The user. """ - if not dag_access_entity: + if not access_entity: # Scenario 1 - return self._is_authorized_dag(method=method, dag_details=dag_details, user=user) + return self._is_authorized_dag(method=method, details=details, user=user) else: # Scenario 2 - resource_type = self._get_fab_resource_type(dag_access_entity) + resource_types = self._get_fab_resource_types(access_entity) dag_method: ResourceMethod = "GET" if method == "GET" else "PUT" - return self._is_authorized_dag( - method=dag_method, dag_details=dag_details, user=user - ) and self._is_authorized(method=method, resource_type=resource_type, user=user) + if (details and details.id) and not self._is_authorized_dag( + method=dag_method, details=details, user=user + ): + return False + + return all( + self._is_authorized(method=method, resource_type=resource_type, user=user) + for resource_type in resource_types + ) - def is_authorized_dataset(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: + def is_authorized_dataset( + self, *, method: ResourceMethod, details: DatasetDetails | None = None, user: BaseUser | None = None + ) -> bool: return self._is_authorized(method=method, resource_type=RESOURCE_DATASET, user=user) - def is_authorized_variable(self, *, method: ResourceMethod, user: BaseUser | None = None) -> bool: + def is_authorized_pool( + self, *, method: ResourceMethod, details: PoolDetails | None = None, user: BaseUser | None = None + ) -> bool: + return self._is_authorized(method=method, resource_type=RESOURCE_POOL, user=user) + + def is_authorized_variable( + self, *, method: ResourceMethod, details: VariableDetails | None = None, user: BaseUser | None = None + ) -> bool: return self._is_authorized(method=method, resource_type=RESOURCE_VARIABLE, user=user) def is_authorized_website(self, *, user: BaseUser | None = None) -> bool: - return self._is_authorized(method="GET", resource_type=RESOURCE_WEBSITE, user=user) + return ( + self._is_authorized(method="GET", resource_type=RESOURCE_PLUGIN, user=user) + or self._is_authorized(method="GET", resource_type=RESOURCE_PROVIDER, user=user) + or self._is_authorized(method="GET", resource_type=RESOURCE_TRIGGER, user=user) + or self._is_authorized(method="GET", resource_type=RESOURCE_WEBSITE, user=user) + ) + + @provide_session + def get_permitted_dag_ids( + self, + *, + methods: Container[ResourceMethod] | None = None, + user=None, + session: Session = NEW_SESSION, + ) -> set[str]: + if not methods: + methods = ["PUT", "GET"] + + if not user: + user = self.get_user() + + if not self.is_logged_in(): + roles = user.roles + else: + if ("GET" in methods and self.is_authorized_dag(method="GET", user=user)) or ( + "PUT" in methods and self.is_authorized_dag(method="PUT", user=user) + ): + # If user is authorized to read/edit all DAGs, return all DAGs + return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} + user_query = session.scalar( + select(User) + .options( + joinedload(User.roles) + .subqueryload(Role.permissions) + .options(joinedload(Permission.action), joinedload(Permission.resource)) + ) + .where(User.id == user.id) + ) + roles = user_query.roles + + map_fab_action_name_to_method_name = {v: k for k, v in MAP_METHOD_NAME_TO_FAB_ACTION_NAME.items()} + map_fab_action_name_to_method_name[ACTION_CAN_ACCESS_MENU] = "GET" + resources = set() + for role in roles: + for permission in role.permissions: + action = permission.action.name + if ( + action in map_fab_action_name_to_method_name + and map_fab_action_name_to_method_name[action] in methods + ): + resource = permission.resource.name + if resource == permissions.RESOURCE_DAG: + return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} + if resource.startswith(permissions.RESOURCE_DAG_PREFIX): + resources.add(resource[len(permissions.RESOURCE_DAG_PREFIX) :]) + else: + resources.add(resource) + return { + dag.dag_id + for dag in session.execute(select(DagModel.dag_id).where(DagModel.dag_id.in_(resources))) + } def get_security_manager_override_class(self) -> type: """Return the security manager override.""" @@ -270,14 +378,14 @@ def _is_authorized( def _is_authorized_dag( self, method: ResourceMethod, - dag_details: DagDetails | None = None, + details: DagDetails | None = None, user: BaseUser | None = None, ) -> bool: """ Return whether the user is authorized to perform a given action on a DAG. :param method: the method to perform - :param dag_details: optional details about the DAG + :param details: optional details about the DAG :param user: the user to perform the action on. If not provided (or None), it uses the current user :meta private: @@ -286,9 +394,9 @@ def _is_authorized_dag( if is_global_authorized: return True - if dag_details and dag_details.id: + if details and details.id: # Check whether the user has permissions to access a specific DAG - resource_dag_name = self._resource_name_for_dag(dag_details.id) + resource_dag_name = self._resource_name_for_dag(details.id) return self._is_authorized(method=method, resource_type=resource_dag_name, user=user) return False @@ -302,14 +410,14 @@ def _get_fab_action(method: ResourceMethod) -> str: :meta private: """ - if method not in _MAP_METHOD_NAME_TO_FAB_ACTION_NAME: + if method not in MAP_METHOD_NAME_TO_FAB_ACTION_NAME: raise AirflowException(f"Unknown method: {method}") - return _MAP_METHOD_NAME_TO_FAB_ACTION_NAME[method] + return MAP_METHOD_NAME_TO_FAB_ACTION_NAME[method] @staticmethod - def _get_fab_resource_type(dag_access_entity: DagAccessEntity): + def _get_fab_resource_types(dag_access_entity: DagAccessEntity) -> tuple[str, ...]: """ - Convert a DAG access entity to a FAB resource type. + Convert a DAG access entity to a tuple of FAB resource type. :param dag_access_entity: the DAG access entity @@ -361,8 +469,7 @@ def _get_root_dag_id(self, dag_id: str) -> str: :meta private: """ if "." in dag_id: - dm = self.security_manager.appbuilder.get_session.scalar( + return self.security_manager.appbuilder.get_session.scalar( select(DagModel.dag_id, DagModel.root_dag_id).where(DagModel.dag_id == dag_id).limit(1) ) - return dm.root_dag_id or dm.dag_id return dag_id diff --git a/airflow/auth/managers/fab/security_manager/override.py b/airflow/auth/managers/fab/security_manager/override.py index 2e5bf313d9ae0..cd5cb868048ec 100644 --- a/airflow/auth/managers/fab/security_manager/override.py +++ b/airflow/auth/managers/fab/security_manager/override.py @@ -25,7 +25,7 @@ import uuid import warnings from functools import cached_property -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Container, Iterable, Sequence import re2 from flask import flash, g, session @@ -42,13 +42,21 @@ from sqlalchemy.exc import MultipleResultsFound from werkzeug.security import generate_password_hash +from airflow.auth.managers.fab.fab_auth_manager import MAP_METHOD_NAME_TO_FAB_ACTION_NAME from airflow.auth.managers.fab.models import Action, Permission, RegisterUser, Resource, Role from airflow.auth.managers.fab.models.anonymous_user import AnonymousUser -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, RemovedInAirflow3Warning +from airflow.models import DagModel +from airflow.security import permissions +from airflow.utils.session import NEW_SESSION, provide_session +from airflow.www.extensions.init_auth_manager import get_auth_manager from airflow.www.security_manager import AirflowSecurityManagerV2 from airflow.www.session import AirflowDatabaseSessionInterface if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from airflow.auth.managers.base_auth_manager import ResourceMethod from airflow.auth.managers.fab.models import User log = logging.getLogger(__name__) @@ -502,6 +510,91 @@ def create_db(self): log.error(const.LOGMSG_ERR_SEC_CREATE_DB, e) exit(1) + def get_readable_dags(self, user) -> Iterable[DagModel]: + """Get the DAGs readable by authenticated user.""" + warnings.warn( + "`get_readable_dags` has been deprecated. Please use `get_auth_manager().get_permitted_dag_ids` " + "instead.", + RemovedInAirflow3Warning, + stacklevel=2, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RemovedInAirflow3Warning) + return self.get_accessible_dags([permissions.ACTION_CAN_READ], user) + + def get_editable_dags(self, user) -> Iterable[DagModel]: + """Get the DAGs editable by authenticated user.""" + warnings.warn( + "`get_editable_dags` has been deprecated. Please use `get_auth_manager().get_permitted_dag_ids` " + "instead.", + RemovedInAirflow3Warning, + stacklevel=2, + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RemovedInAirflow3Warning) + return self.get_accessible_dags([permissions.ACTION_CAN_EDIT], user) + + @provide_session + def get_accessible_dags( + self, + user_actions: Container[str] | None, + user, + session: Session = NEW_SESSION, + ) -> Iterable[DagModel]: + warnings.warn( + "`get_accessible_dags` has been deprecated. Please use " + "`get_auth_manager().get_permitted_dag_ids` instead.", + RemovedInAirflow3Warning, + stacklevel=3, + ) + + dag_ids = self.get_accessible_dag_ids(user, user_actions, session) + return session.scalars(select(DagModel).where(DagModel.dag_id.in_(dag_ids))) + + @provide_session + def get_accessible_dag_ids( + self, + user, + user_actions: Container[str] | None = None, + session: Session = NEW_SESSION, + ) -> set[str]: + warnings.warn( + "`get_accessible_dag_ids` has been deprecated. Please use " + "`get_auth_manager().get_permitted_dag_ids` instead.", + RemovedInAirflow3Warning, + stacklevel=3, + ) + if not user_actions: + user_actions = [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ] + fab_action_name_to_method_name = {v: k for k, v in MAP_METHOD_NAME_TO_FAB_ACTION_NAME.items()} + user_methods: Container[ResourceMethod] = [ + fab_action_name_to_method_name[action] + for action in fab_action_name_to_method_name + if action in user_actions + ] + return get_auth_manager().get_permitted_dag_ids(user=user, methods=user_methods, session=session) + + @staticmethod + def get_readable_dag_ids(user=None) -> set[str]: + """Get the DAG IDs readable by authenticated user.""" + return get_auth_manager().get_permitted_dag_ids(methods=["GET"], user=user) + + @staticmethod + def get_editable_dag_ids(user=None) -> set[str]: + """Get the DAG IDs editable by authenticated user.""" + return get_auth_manager().get_permitted_dag_ids(methods=["PUT"], user=user) + + def can_access_some_dags(self, action: str, dag_id: str | None = None) -> bool: + """Check if user has read or write access to some dags.""" + if dag_id and dag_id != "~": + root_dag_id = self._get_root_dag_id(dag_id) + return self.has_access(action, permissions.resource_name_for_dag(root_dag_id)) + + user = g.user + if action == permissions.ACTION_CAN_READ: + return any(self.get_readable_dag_ids(user)) + return any(self.get_editable_dag_ids(user)) + """ ----------- Role entity @@ -1071,6 +1164,31 @@ def oauth_token_getter(): log.debug("Token Get: %s", token) return token + def check_authorization( + self, + perms: Sequence[tuple[str, str]] | None = None, + dag_id: str | None = None, + ) -> bool: + """Checks that the logged in user has the specified permissions.""" + if not perms: + return True + + for perm in perms: + if perm in ( + (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), + (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG), + ): + can_access_all_dags = self.has_access(*perm) + if not can_access_all_dags: + action = perm[0] + if not self.can_access_some_dags(action, dag_id): + return False + elif not self.has_access(*perm): + return False + + return True + @staticmethod def _azure_parse_jwt(token): """ diff --git a/airflow/auth/managers/models/resource_details.py b/airflow/auth/managers/models/resource_details.py index 51cdc5979392d..1f98ba72ce019 100644 --- a/airflow/auth/managers/models/resource_details.py +++ b/airflow/auth/managers/models/resource_details.py @@ -21,18 +21,46 @@ from enum import Enum +@dataclass +class ConfigurationDetails: + """Represents the details of a configuration.""" + + section: str | None = None + + @dataclass class ConnectionDetails: """Represents the details of a connection.""" - conn_id: str + conn_id: str | None = None @dataclass class DagDetails: """Represents the details of a DAG.""" - id: str + id: str | None = None + + +@dataclass +class DatasetDetails: + """Represents the details of a dataset.""" + + uri: str | None = None + + +@dataclass +class PoolDetails: + """Represents the details of a pool.""" + + name: str | None = None + + +@dataclass +class VariableDetails: + """Represents the details of a variable.""" + + key: str | None = None class DagAccessEntity(Enum): @@ -40,9 +68,11 @@ class DagAccessEntity(Enum): AUDIT_LOG = "AUDIT_LOG" CODE = "CODE" - DATASET = "DATASET" DEPENDENCIES = "DEPENDENCIES" + IMPORT_ERRORS = "IMPORT_ERRORS" RUN = "RUN" + TASK = "TASK" TASK_INSTANCE = "TASK_INSTANCE" TASK_LOGS = "TASK_LOGS" + WARNING = "WARNING" XCOM = "XCOM" diff --git a/airflow/www/auth.py b/airflow/www/auth.py index ffd80a117c763..8fb6ffb435da4 100644 --- a/airflow/www/auth.py +++ b/airflow/www/auth.py @@ -35,7 +35,7 @@ if TYPE_CHECKING: from airflow.auth.managers.base_auth_manager import ResourceMethod - from airflow.models import Connection + from airflow.models.connection import Connection T = TypeVar("T", bound=Callable) @@ -75,7 +75,7 @@ def _has_access_no_details(is_authorized_callback: Callable[[], bool]) -> Callab This works only for resources with no details. This function is used in some ``has_access_`` functions below. - :param is_authorized_callback: callback to execute to figure whether the user authorized to access + :param is_authorized_callback: callback to execute to figure whether the user is authorized to access the resource? """ @@ -140,9 +140,7 @@ def decorated(*args, **kwargs): ] is_authorized = all( [ - get_auth_manager().is_authorized_connection( - method=method, connection_details=connection_details - ) + get_auth_manager().is_authorized_connection(method=method, details=connection_details) for connection_details in connections_details ] ) @@ -191,8 +189,8 @@ def decorated(*args, **kwargs): is_authorized = get_auth_manager().is_authorized_dag( method=method, - dag_access_entity=access_entity, - dag_details=None if not dag_id else DagDetails(id=dag_id), + access_entity=access_entity, + details=None if not dag_id else DagDetails(id=dag_id), ) return _has_access( diff --git a/airflow/www/extensions/init_jinja_globals.py b/airflow/www/extensions/init_jinja_globals.py index ff5481dd468f6..95cd9b8c26785 100644 --- a/airflow/www/extensions/init_jinja_globals.py +++ b/airflow/www/extensions/init_jinja_globals.py @@ -69,10 +69,12 @@ def prepare_jinja_globals(): "git_version": git_version, "k8s_or_k8scelery_executor": IS_K8S_OR_K8SCELERY_EXECUTOR, "rest_api_enabled": False, - "auth_manager": get_auth_manager(), "config_test_connection": conf.get("core", "test_connection", fallback="Disabled"), } + # Extra global specific to auth manager + extra_globals["auth_manager"] = get_auth_manager() + backends = conf.get("api", "auth_backends") if backends and backends[0] != "airflow.api.auth.backend.deny_all": extra_globals["rest_api_enabled"] = True diff --git a/airflow/www/security_manager.py b/airflow/www/security_manager.py index 580191a9cb5bf..065490f6687b6 100644 --- a/airflow/www/security_manager.py +++ b/airflow/www/security_manager.py @@ -18,13 +18,13 @@ import itertools import warnings -from typing import TYPE_CHECKING, Any, Collection, Container, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Collection, Iterable, Sequence from flask import g from sqlalchemy import or_, select from sqlalchemy.orm import joinedload -from airflow.auth.managers.fab.models import Permission, Resource, Role, User +from airflow.auth.managers.fab.models import Permission, Resource, Role from airflow.auth.managers.fab.views.permissions import ( ActionModelView, PermissionPairModelView, @@ -48,8 +48,6 @@ from airflow.models import DagBag, DagModel from airflow.security import permissions from airflow.utils.log.logging_mixin import LoggingMixin -from airflow.utils.session import NEW_SESSION, provide_session -from airflow.www.extensions.init_auth_manager import get_auth_manager from airflow.www.fab_security.sqla.manager import SecurityManager from airflow.www.utils import CustomSQLAInterface @@ -62,7 +60,8 @@ } if TYPE_CHECKING: - from sqlalchemy.orm import Session + + pass class AirflowSecurityManagerV2(SecurityManager, LoggingMixin): @@ -269,126 +268,6 @@ def get_user_roles(user=None): user = g.user return user.roles - def get_readable_dags(self, user) -> Iterable[DagModel]: - """Get the DAGs readable by authenticated user.""" - warnings.warn( - "`get_readable_dags` has been deprecated. Please use `get_readable_dag_ids` instead.", - RemovedInAirflow3Warning, - stacklevel=2, - ) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", RemovedInAirflow3Warning) - return self.get_accessible_dags([permissions.ACTION_CAN_READ], user) - - def get_editable_dags(self, user) -> Iterable[DagModel]: - """Get the DAGs editable by authenticated user.""" - warnings.warn( - "`get_editable_dags` has been deprecated. Please use `get_editable_dag_ids` instead.", - RemovedInAirflow3Warning, - stacklevel=2, - ) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", RemovedInAirflow3Warning) - return self.get_accessible_dags([permissions.ACTION_CAN_EDIT], user) - - @provide_session - def get_accessible_dags( - self, - user_actions: Container[str] | None, - user, - session: Session = NEW_SESSION, - ) -> Iterable[DagModel]: - warnings.warn( - "`get_accessible_dags` has been deprecated. Please use `get_accessible_dag_ids` instead.", - RemovedInAirflow3Warning, - stacklevel=3, - ) - dag_ids = self.get_accessible_dag_ids(user, user_actions, session) - return session.scalars(select(DagModel).where(DagModel.dag_id.in_(dag_ids))) - - def get_readable_dag_ids(self, user) -> set[str]: - """Get the DAG IDs readable by authenticated user.""" - return self.get_accessible_dag_ids(user, [permissions.ACTION_CAN_READ]) - - def get_editable_dag_ids(self, user) -> set[str]: - """Get the DAG IDs editable by authenticated user.""" - return self.get_accessible_dag_ids(user, [permissions.ACTION_CAN_EDIT]) - - @provide_session - def get_accessible_dag_ids( - self, - user, - user_actions: Container[str] | None = None, - session: Session = NEW_SESSION, - ) -> set[str]: - """Get readable or writable DAGs for user.""" - if not user_actions: - user_actions = [permissions.ACTION_CAN_EDIT, permissions.ACTION_CAN_READ] - - if not get_auth_manager().is_logged_in(): - roles = user.roles - else: - if (permissions.ACTION_CAN_EDIT in user_actions and self.can_edit_all_dags(user)) or ( - permissions.ACTION_CAN_READ in user_actions and self.can_read_all_dags(user) - ): - return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} - user_query = session.scalar( - select(User) - .options( - joinedload(User.roles) - .subqueryload(Role.permissions) - .options(joinedload(Permission.action), joinedload(Permission.resource)) - ) - .where(User.id == user.id) - ) - roles = user_query.roles - - resources = set() - for role in roles: - for permission in role.permissions: - action = permission.action.name - if action in user_actions: - resource = permission.resource.name - if resource == permissions.RESOURCE_DAG: - return {dag.dag_id for dag in session.execute(select(DagModel.dag_id))} - if resource.startswith(permissions.RESOURCE_DAG_PREFIX): - resources.add(resource[len(permissions.RESOURCE_DAG_PREFIX) :]) - else: - resources.add(resource) - return { - dag.dag_id - for dag in session.execute(select(DagModel.dag_id).where(DagModel.dag_id.in_(resources))) - } - - def can_access_some_dags(self, action: str, dag_id: str | None = None) -> bool: - """Check if user has read or write access to some dags.""" - if dag_id and dag_id != "~": - root_dag_id = self._get_root_dag_id(dag_id) - return self.has_access(action, permissions.resource_name_for_dag(root_dag_id)) - - user = g.user - if action == permissions.ACTION_CAN_READ: - return any(self.get_readable_dag_ids(user)) - return any(self.get_editable_dag_ids(user)) - - def can_read_dag(self, dag_id: str, user=None) -> bool: - """Determine whether a user has DAG read access.""" - root_dag_id = self._get_root_dag_id(dag_id) - dag_resource_name = permissions.resource_name_for_dag(root_dag_id) - return self.has_access(permissions.ACTION_CAN_READ, dag_resource_name, user=user) - - def can_edit_dag(self, dag_id: str, user=None) -> bool: - """Determine whether a user has DAG edit access.""" - root_dag_id = self._get_root_dag_id(dag_id) - dag_resource_name = permissions.resource_name_for_dag(root_dag_id) - return self.has_access(permissions.ACTION_CAN_EDIT, dag_resource_name, user=user) - - def can_delete_dag(self, dag_id: str, user=None) -> bool: - """Determine whether a user has DAG delete access.""" - root_dag_id = self._get_root_dag_id(dag_id) - dag_resource_name = permissions.resource_name_for_dag(root_dag_id) - return self.has_access(permissions.ACTION_CAN_DELETE, dag_resource_name, user=user) - def prefixed_dag_id(self, dag_id: str) -> str: """Return the permission name for a DAG id.""" warnings.warn( @@ -430,36 +309,6 @@ def has_access(self, action_name: str, resource_name: str, user=None) -> bool: return False - def _has_role(self, role_name_or_list: Container, user) -> bool: - """Whether the user has this role name.""" - if not isinstance(role_name_or_list, list): - role_name_or_list = [role_name_or_list] - return any(r.name in role_name_or_list for r in user.roles) - - def has_all_dags_access(self, user) -> bool: - """ - Has all the dag access in any of the 3 cases. - - 1. Role needs to be in (Admin, Viewer, User, Op). - 2. Has can_read action on dags resource. - 3. Has can_edit action on dags resource. - """ - if not user: - user = g.user - return ( - self._has_role(["Admin", "Viewer", "Op", "User"], user) - or self.can_read_all_dags(user) - or self.can_edit_all_dags(user) - ) - - def can_edit_all_dags(self, user=None) -> bool: - """Has can_edit action on DAG resource.""" - return self.has_access(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG, user) - - def can_read_all_dags(self, user=None) -> bool: - """Has can_read action on DAG resource.""" - return self.has_access(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG, user) - def clean_perms(self) -> None: """FAB leaves faulty permissions that need to be cleaned up.""" self.log.debug("Cleaning faulty perms") @@ -740,22 +589,6 @@ def check_authorization( perms: Sequence[tuple[str, str]] | None = None, dag_id: str | None = None, ) -> bool: - """Check that the logged in user has the specified permissions.""" - if not perms: - return True - - for perm in perms: - if perm in ( - (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG), - (permissions.ACTION_CAN_DELETE, permissions.RESOURCE_DAG), - ): - can_access_all_dags = self.has_access(*perm) - if not can_access_all_dags: - action = perm[0] - if not self.can_access_some_dags(action, dag_id): - return False - elif not self.has_access(*perm): - return False - - return True + raise NotImplementedError( + "The method 'check_authorization' is only available with the auth manager FabAuthManager" + ) diff --git a/airflow/www/templates/airflow/dag.html b/airflow/www/templates/airflow/dag.html index d324199d1fdef..40440d3fd6672 100644 --- a/airflow/www/templates/airflow/dag.html +++ b/airflow/www/templates/airflow/dag.html @@ -110,16 +110,15 @@