Skip to content

Commit

Permalink
FieldAttributes as descriptors (ansible#73908)
Browse files Browse the repository at this point in the history
  • Loading branch information
mkrizek authored Jun 29, 2022
1 parent 4c9385d commit 43153c5
Show file tree
Hide file tree
Showing 25 changed files with 333 additions and 409 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def extract_keywords(keyword_definitions):

# Maintain order of the actual class names for our output
# Build up a mapping of playbook classes to the attributes that they hold
pb_keywords[pb_class_name] = {k: v for (k, v) in playbook_class._valid_attrs.items()
pb_keywords[pb_class_name] = {k: v for (k, v) in playbook_class.fattributes.items()
# Filter private attributes as they're not usable in playbooks
if not v.private}

Expand All @@ -60,7 +60,7 @@ def extract_keywords(keyword_definitions):
pb_keywords[pb_class_name][keyword] = keyword_definitions[keyword]
else:
# check if there is an alias, otherwise undocumented
alias = getattr(getattr(playbook_class, '_%s' % keyword), 'alias', None)
alias = getattr(playbook_class.fattributes.get(keyword), 'alias', None)
if alias and alias in keyword_definitions:
pb_keywords[pb_class_name][alias] = keyword_definitions[alias]
del pb_keywords[pb_class_name][keyword]
Expand Down
4 changes: 2 additions & 2 deletions lib/ansible/cli/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,13 +590,13 @@ def _get_keywords_docs(keys):
loaded_class = importlib.import_module(obj_class)
PB_LOADED[pobj] = getattr(loaded_class, pobj, None)

if keyword in PB_LOADED[pobj]._valid_attrs:
if keyword in PB_LOADED[pobj].fattributes:
kdata['applies_to'].append(pobj)

# we should only need these once
if 'type' not in kdata:

fa = getattr(PB_LOADED[pobj], '_%s' % keyword)
fa = PB_LOADED[pobj].fattributes.get(keyword)
if getattr(fa, 'private'):
kdata = {}
raise KeyError
Expand Down
4 changes: 2 additions & 2 deletions lib/ansible/parsing/mod_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def __init__(self, task_ds=None, collection_list=None):
from ansible.playbook.task import Task
from ansible.playbook.handler import Handler
# store the valid Task/Handler attrs for quick access
self._task_attrs = set(Task._valid_attrs.keys())
self._task_attrs.update(set(Handler._valid_attrs.keys()))
self._task_attrs = set(Task.fattributes)
self._task_attrs.update(set(Handler.fattributes))
# HACK: why are these not FieldAttributes on task with a post-validate to check usage?
self._task_attrs.update(['local_action', 'static'])
self._task_attrs = frozenset(self._task_attrs)
Expand Down
103 changes: 93 additions & 10 deletions lib/ansible/playbook/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from copy import copy, deepcopy

from ansible.utils.sentinel import Sentinel

_CONTAINERS = frozenset(('list', 'dict', 'set'))

Expand All @@ -37,10 +38,7 @@ def __init__(
priority=0,
class_type=None,
always_post_validate=False,
inherit=True,
alias=None,
extend=False,
prepend=False,
static=False,
):

Expand Down Expand Up @@ -70,9 +68,6 @@ def __init__(
the field will be an instance of that class.
:kwarg always_post_validate: Controls whether a field should be post
validated or not (default: False).
:kwarg inherit: A boolean value, which controls whether the object
containing this field should attempt to inherit the value from its
parent object if the local value is None.
:kwarg alias: An alias to use for the attribute name, for situations where
the attribute name may conflict with a Python reserved word.
"""
Expand All @@ -85,15 +80,15 @@ def __init__(
self.priority = priority
self.class_type = class_type
self.always_post_validate = always_post_validate
self.inherit = inherit
self.alias = alias
self.extend = extend
self.prepend = prepend
self.static = static

if default is not None and self.isa in _CONTAINERS and not callable(default):
raise TypeError('defaults for FieldAttribute may not be mutable, please provide a callable instead')

def __set_name__(self, owner, name):
self.name = name

def __eq__(self, other):
return other.priority == self.priority

Expand All @@ -114,6 +109,94 @@ def __le__(self, other):
def __ge__(self, other):
return other.priority >= self.priority

def __get__(self, obj, obj_type=None):
method = f'_get_attr_{self.name}'
if hasattr(obj, method):
# NOTE this appears to be not used in the codebase,
# _get_attr_connection has been replaced by ConnectionFieldAttribute.
# Leaving it here for test_attr_method from
# test/units/playbook/test_base.py to pass and for backwards compat.
if getattr(obj, '_squashed', False):
value = getattr(obj, f'_{self.name}', Sentinel)
else:
value = getattr(obj, method)()
else:
value = getattr(obj, f'_{self.name}', Sentinel)

if value is Sentinel:
value = self.default
if callable(value):
value = value()
setattr(obj, f'_{self.name}', value)

return value

def __set__(self, obj, value):
setattr(obj, f'_{self.name}', value)
if self.alias is not None:
setattr(obj, f'_{self.alias}', value)

# NOTE this appears to be not needed in the codebase,
# leaving it here for test_attr_int_del from
# test/units/playbook/test_base.py to pass.
def __delete__(self, obj):
delattr(obj, f'_{self.name}')


class NonInheritableFieldAttribute(Attribute):
...


class FieldAttribute(Attribute):
pass
def __init__(self, extend=False, prepend=False, **kwargs):
super().__init__(**kwargs)

self.extend = extend
self.prepend = prepend

def __get__(self, obj, obj_type=None):
if getattr(obj, '_squashed', False) or getattr(obj, '_finalized', False):
value = getattr(obj, f'_{self.name}', Sentinel)
else:
try:
value = obj._get_parent_attribute(self.name)
except AttributeError:
method = f'_get_attr_{self.name}'
if hasattr(obj, method):
# NOTE this appears to be not needed in the codebase,
# _get_attr_connection has been replaced by ConnectionFieldAttribute.
# Leaving it here for test_attr_method from
# test/units/playbook/test_base.py to pass and for backwards compat.
if getattr(obj, '_squashed', False):
value = getattr(obj, f'_{self.name}', Sentinel)
else:
value = getattr(obj, method)()
else:
value = getattr(obj, f'_{self.name}', Sentinel)

if value is Sentinel:
value = self.default
if callable(value):
value = value()
setattr(obj, f'_{self.name}', value)

return value


class ConnectionFieldAttribute(FieldAttribute):
def __get__(self, obj, obj_type=None):
from ansible.module_utils.compat.paramiko import paramiko
from ansible.utils.ssh_functions import check_for_controlpersist
value = super().__get__(obj, obj_type)

if value == 'smart':
value = 'ssh'
# see if SSH can support ControlPersist if not use paramiko
if not check_for_controlpersist('ssh') and paramiko is not None:
value = "paramiko"

# if someone did `connection: persistent`, default it to using a persistent paramiko connection to avoid problems
elif value == 'persistent' and paramiko is not None:
value = 'paramiko'

return value
Loading

0 comments on commit 43153c5

Please sign in to comment.