Skip to content

Commit

Permalink
Faster default role syncing during webserver start (apache#15017)
Browse files Browse the repository at this point in the history
This makes a handful of bigger queries instead of many queries when
syncing the default Airflow roles. On my machine with 5k DAGs, this led
to a reduction of 1 second in startup time (bonus, makes tests faster
too).
  • Loading branch information
jedcunningham authored Mar 29, 2021
1 parent a7a558e commit 1627323
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 17 deletions.
67 changes: 56 additions & 11 deletions airflow/www/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
# under the License.
#

from typing import Optional, Sequence, Set, Tuple
import warnings
from typing import Dict, Optional, Sequence, Set, Tuple

from flask import current_app, g
from flask_appbuilder.security.sqla import models as sqla_models
Expand Down Expand Up @@ -174,16 +175,34 @@ def __init__(self, appbuilder):
def init_role(self, role_name, perms):
"""
Initialize the role with the permissions and related view-menus.
:param role_name:
:param perms:
:return:
"""
role = self.find_role(role_name)
if not role:
role = self.add_role(role_name)
warnings.warn(
"`init_role` has been deprecated. Please use `bulk_sync_roles` instead.",
DeprecationWarning,
stacklevel=2,
)
self.bulk_sync_roles([{'role': role_name, 'perms': perms}])

def bulk_sync_roles(self, roles):
"""Sync the provided roles and permissions."""
existing_roles = self._get_all_roles_with_permissions()
pvs = self._get_all_non_dag_permissionviews()

for config in roles:
role_name = config['role']
perms = config['perms']
role = existing_roles.get(role_name) or self.add_role(role_name)

for perm_name, view_name in perms:
perm_view = pvs.get((perm_name, view_name)) or self.add_permission_view_menu(
perm_name, view_name
)

self.add_permissions(role, set(perms))
if perm_view not in role.permissions:
self.add_permission_role(role, perm_view)

def add_permissions(self, role, perms):
"""Adds resource permissions to a given role."""
Expand Down Expand Up @@ -467,6 +486,34 @@ def get_all_permissions(self) -> Set[Tuple[str, str]]:
.all()
)

def _get_all_non_dag_permissionviews(self) -> Dict[Tuple[str, str], PermissionView]:
"""
Returns a dict with a key of (perm name, view menu name) and value of perm view
with all perm views except those that are for specific DAGs.
"""
return {
(perm_name, viewmodel_name): viewmodel
for perm_name, viewmodel_name, viewmodel in (
self.get_session.query(self.permissionview_model)
.join(self.permission_model)
.join(self.viewmenu_model)
.filter(~self.viewmenu_model.name.like(f"{permissions.RESOURCE_DAG_PREFIX}%"))
.with_entities(
self.permission_model.name, self.viewmenu_model.name, self.permissionview_model
)
.all()
)
}

def _get_all_roles_with_permissions(self) -> Dict[str, Role]:
"""Returns a dict with a key of role name and value of role with eagrly loaded permissions"""
return {
r.name: r
for r in (
self.get_session.query(self.role_model).options(joinedload(self.role_model.permissions)).all()
)
}

def create_dag_specific_permissions(self) -> None:
"""
Creates 'can_read' and 'can_edit' permissions for all active and paused DAGs.
Expand Down Expand Up @@ -526,11 +573,9 @@ def sync_roles(self):
self.create_perm_vm_for_all_dag()
self.create_dag_specific_permissions()

# Create default user role.
for config in self.ROLE_CONFIGS:
role = config['role']
perms = config['perms']
self.init_role(role, perms)
# Sync the default roles (Admin, Viewer, User, Op, public) with related permissions
self.bulk_sync_roles(self.ROLE_CONFIGS)

self.add_homepage_access_to_custom_roles()
# init existing roles, the rest role could be created through UI.
self.update_admin_perm_view()
Expand Down
58 changes: 52 additions & 6 deletions tests/www/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def delete_roles(cls):
fab_utils.delete_role(cls.app, role_name)

def expect_user_is_in_role(self, user, rolename):
self.security_manager.init_role(rolename, [])
self.security_manager.bulk_sync_roles([{'role': rolename, 'perms': []}])
role = self.security_manager.find_role(rolename)
if not role:
self.security_manager.add_role(rolename)
Expand Down Expand Up @@ -141,14 +141,28 @@ def tearDown(self):
log.debug("Complete teardown!")

def test_init_role_baseview(self):
role_name = 'MyRole7'
role_perms = [('can_some_other_action', 'AnotherBaseView')]
with pytest.warns(
DeprecationWarning,
match="`init_role` has been deprecated\\. Please use `bulk_sync_roles` instead\\.",
):
self.security_manager.init_role(role_name, role_perms)

role = self.appbuilder.sm.find_role(role_name)
assert role is not None
assert len(role_perms) == len(role.permissions)

def test_bulk_sync_roles_baseview(self):
role_name = 'MyRole3'
role_perms = [('can_some_action', 'SomeBaseView')]
self.security_manager.init_role(role_name, perms=role_perms)
self.security_manager.bulk_sync_roles([{'role': role_name, 'perms': role_perms}])

role = self.appbuilder.sm.find_role(role_name)
assert role is not None
assert len(role_perms) == len(role.permissions)

def test_init_role_modelview(self):
def test_bulk_sync_roles_modelview(self):
role_name = 'MyRole2'
role_perms = [
('can_list', 'SomeModelView'),
Expand All @@ -157,24 +171,33 @@ def test_init_role_modelview(self):
(permissions.ACTION_CAN_EDIT, 'SomeModelView'),
(permissions.ACTION_CAN_DELETE, 'SomeModelView'),
]
self.security_manager.init_role(role_name, role_perms)
mock_roles = [{'role': role_name, 'perms': role_perms}]
self.security_manager.bulk_sync_roles(mock_roles)

role = self.appbuilder.sm.find_role(role_name)
assert role is not None
assert len(role_perms) == len(role.permissions)

# Check short circuit works
with assert_queries_count(2): # One for permissionview, one for roles
self.security_manager.bulk_sync_roles(mock_roles)

def test_update_and_verify_permission_role(self):
role_name = 'Test_Role'
self.security_manager.init_role(role_name, [])
role_perms = []
mock_roles = [{'role': role_name, 'perms': role_perms}]
self.security_manager.bulk_sync_roles(mock_roles)
role = self.security_manager.find_role(role_name)

perm = self.security_manager.find_permission_view_menu(permissions.ACTION_CAN_EDIT, 'RoleModelView')
self.security_manager.add_permission_role(role, perm)
role_perms_len = len(role.permissions)

self.security_manager.init_role(role_name, [])
self.security_manager.bulk_sync_roles(mock_roles)
new_role_perms_len = len(role.permissions)

assert role_perms_len == new_role_perms_len
assert new_role_perms_len == 1

def test_verify_public_role_has_no_permissions(self):
public = self.appbuilder.sm.find_role("Public")
Expand Down Expand Up @@ -574,3 +597,26 @@ def test_get_all_permissions(self):
assert len(perm) == 2

assert ('can_read', 'Connections') in perms

def test_get_all_non_dag_permissionviews(self):
with assert_queries_count(1):
pvs = self.security_manager._get_all_non_dag_permissionviews()

assert isinstance(pvs, dict)
for (perm_name, viewmodel_name), perm_view in pvs.items():
assert isinstance(perm_name, str)
assert isinstance(viewmodel_name, str)
assert isinstance(perm_view, self.security_manager.permissionview_model)

assert ('can_read', 'Connections') in pvs

def test_get_all_roles_with_permissions(self):
with assert_queries_count(1):
roles = self.security_manager._get_all_roles_with_permissions()

assert isinstance(roles, dict)
for role_name, role in roles.items():
assert isinstance(role_name, str)
assert isinstance(role, self.security_manager.role_model)

assert 'Admin' in roles

0 comments on commit 1627323

Please sign in to comment.