Skip to content

Commit

Permalink
adds new common functions for declarative intent modules (ansible#25210)
Browse files Browse the repository at this point in the history
* adds new common functions for declarative intent modules

* adds Entity and EntityCollection
* adds dict_diff and dict_combine

* update for CI  PEP8 compliance

* more CI PEP8 fixes

* more PEP8 CI clean up

* refactors the lambda assignments into top level classes

this is to be in compliant the PEP8 CI sanity checks

* one last pep8 ci fix
  • Loading branch information
privateip authored Jun 16, 2017
1 parent 43468b8 commit 3aa41ed
Show file tree
Hide file tree
Showing 2 changed files with 287 additions and 26 deletions.
184 changes: 158 additions & 26 deletions lib/ansible/module_utils/network_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from itertools import chain

from ansible.module_utils.six import iteritems
from ansible.module_utils.basic import AnsibleFallbackNotFound
from ansible.module_utils.six import iteritems

Expand All @@ -38,7 +41,13 @@ def to_list(val):
return list()


class ComplexDict(object):
def sort_list(val):
if isinstance(val, list):
return sorted(val)
return val


class Entity(object):
"""Transforms a dict to with an argument spec
This class will take a dict and apply an Ansible argument spec to the
Expand All @@ -52,7 +61,7 @@ class ComplexDict(object):
display=dict(default='text', choices=['text', 'json']),
validate=dict(type='bool')
)
transform = ComplexDict(argument_spec, module)
transform = Entity(module, argument_spec)
value = dict(command='foo')
result = transform(value)
print result
Expand All @@ -66,31 +75,42 @@ class ComplexDict(object):
* fallback - implements fallback function
* choices - set of valid options
* default - default value
"""

def __init__(self, attrs, module):
self._attributes = attrs
def __init__(self, module, attrs=None, args=[], keys=None, from_argspec=False):
self._attributes = attrs or {}
self._module = module

for arg in args:
self._attributes[arg] = dict()
if from_argspec:
self._attributes[arg]['read_from'] = arg
if keys and arg in keys:
self._attributes[arg]['key'] = True

self.attr_names = frozenset(self._attributes.keys())

self._has_key = False
_has_key = False

for name, attr in iteritems(self._attributes):
if attr.get('read_from'):
if attr['read_from'] not in self._module.argument_spec:
module.fail_json(msg='argument %s does not exist' % attr['read_from'])
spec = self._module.argument_spec.get(attr['read_from'])
if not spec:
raise ValueError('argument_spec %s does not exist' % attr['read_from'])
for key, value in iteritems(spec):
if key not in attr:
attr[key] = value

if attr.get('key'):
if self._has_key:
raise ValueError('only one key value can be specified')
self._has_key = True
if _has_key:
module.fail_json(msg='only one key value can be specified')
_has_key = True
attr['required'] = True

def _dict(self, value):
def serialize(self):
return self._attributes

def to_dict(self, value):
obj = {}
for name, attr in iteritems(self._attributes):
if attr.get('key'):
Expand All @@ -99,16 +119,17 @@ def _dict(self, value):
obj[name] = attr.get('default')
return obj

def __call__(self, value):
def __call__(self, value, strict=True):
if not isinstance(value, dict):
value = self._dict(value)
value = self.to_dict(value)

unknown = set(value).difference(self.attr_names)
if unknown:
raise ValueError('invalid keys: %s' % ','.join(unknown))
if strict:
unknown = set(value).difference(self.attr_names)
if unknown:
self._module.fail_json(msg='invalid keys: %s' % ','.join(unknown))

for name, attr in iteritems(self._attributes):
if not value.get(name):
if value.get(name) is None:
value[name] = attr.get('default')

if attr.get('fallback') and not value.get(name):
Expand All @@ -128,24 +149,135 @@ def __call__(self, value):
continue

if attr.get('required') and value.get(name) is None:
raise ValueError('missing required attribute %s' % name)
self._module.fail_json(msg='missing required attribute %s' % name)

if 'choices' in attr:
if value[name] not in attr['choices']:
raise ValueError('%s must be one of %s, got %s' % (name, ', '.join(attr['choices']), value[name]))
self._module.fail_json(msg='%s must be one of %s, got %s' % (name, ', '.join(attr['choices']), value[name]))

if value[name] is not None:
value_type = attr.get('type', 'str')
type_checker = self._module._CHECK_ARGUMENT_TYPES_DISPATCHER[value_type]
type_checker(value[name])
elif value.get(name):
value[name] = self._module.params[name]

return value


class ComplexList(ComplexDict):
"""Extends ```ComplexDict``` to handle a list of dicts """
class EntityCollection(Entity):
"""Extends ```Entity``` to handle a list of dicts """

def __call__(self, iterable, strict=True):
if iterable is None:
iterable = [super(EntityCollection, self).__call__(self._module.params, strict)]

if not isinstance(iterable, (list, tuple)):
module.fail_json(msg='value must be an iterable')

return [(super(EntityCollection, self).__call__(i, strict)) for i in iterable]


# these two are for backwards compatibility and can be removed once all of the
# modules that use them are updated
class ComplexDict(Entity):
def __init__(self, attrs, module, *args, **kwargs):
super(ComplexDict, self).__init__(module, attrs, *args, **kwargs)


class ComplexList(EntityCollection):
def __init__(self, attrs, module, *args, **kwargs):
super(ComplexList, self).__init__(module, attrs, *args, **kwargs)


def dict_diff(base, comparable):
""" Generate a dict object of differences
This function will compare two dict objects and return the difference
between them as a dict object. For scalar values, the key will reflect
the updated value. If the key does not exist in `comparable`, then then no
key will be returned. For lists, the value in comparable will wholly replace
the value in base for the key. For dicts, the returned value will only
return keys that are different.
:param base: dict object to base the diff on
:param comparable: dict object to compare against base
:returns: new dict object with differences
"""
assert isinstance(base, dict), "`base` must be of type <dict>"
assert isinstance(comparable, dict), "`comparable` must be of type <dict>"

updates = dict()

for key, value in iteritems(base):
if isinstance(value, dict):
item = comparable.get(key)
if item is not None:
updates[key] = dict_diff(value, comparable[key])
else:
comparable_value = comparable.get(key)
if comparable_value is not None:
if sort_list(base[key]) != sort_list(comparable_value):
updates[key] = comparable_value

for key in set(comparable.keys()).difference(base.keys()):
updates[key] = comparable.get(key)

return updates


def dict_combine(base, other):
""" Return a new dict object that combines base and other
This will create a new dict object that is a combination of the key/value
pairs from base and other. When both keys exist, the value will be
selected from other. If the value is a list object, the two lists will
be combined and duplicate entries removed.
:param base: dict object to serve as base
:param other: dict object to combine with base
:returns: new combined dict object
"""
assert isinstance(base, dict), "`base` must be of type <dict>"
assert isinstance(other, dict), "`other` must be of type <dict>"

combined = dict()

for key, value in iteritems(base):
if isinstance(value, dict):
if key in other:
item = other.get(key)
if item is not None:
combined[key] = dict_combine(value, other[key])
else:
combined[key] = item
else:
combined[key] = value
elif isinstance(value, list):
if key in other:
item = other.get(key)
if item is not None:
combined[key] = list(set(chain(value, item)))
else:
combined[key] = item
else:
combined[key] = value
else:
if key in other:
other_value = other.get(key)
if other_value is not None:
if sort_list(base[key]) != sort_list(other_value):
combined[key] = other_value
else:
combined[key] = value
else:
combined[key] = other_value
else:
combined[key] = value

for key in set(other.keys()).difference(base.keys()):
combined[key] = other.get(key)

def __call__(self, values):
if not isinstance(values, (list, tuple)):
raise TypeError('value must be an ordered iterable')
return [(super(ComplexList, self).__call__(v)) for v in values]
return combined
129 changes: 129 additions & 0 deletions test/units/module_utils/test_network_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-
#
# (c) 2017 Red Hat, Inc.
#
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible. If not, see <http://www.gnu.org/licenses/>.

# Make coding more python3-ish
from __future__ import (absolute_import, division)
__metaclass__ = type

from ansible.compat.tests import unittest

from ansible.module_utils.network_common import to_list, sort_list
from ansible.module_utils.network_common import dict_diff, dict_combine


class TestModuleUtilsNetworkCommon(unittest.TestCase):

def test_to_list(self):
for scalar in ('string', 1, True, False, None):
self.assertTrue(isinstance(to_list(scalar), list))

for container in ([1, 2, 3], {'one': 1}):
self.assertTrue(isinstance(to_list(container), list))

test_list = [1, 2, 3]
self.assertNotEqual(id(test_list), id(to_list(test_list)))

def test_sort(self):
data = [3, 1, 2]
self.assertEqual([1, 2, 3], sort_list(data))

string_data = '123'
self.assertEqual(string_data, sort_list(string_data))

def test_dict_diff(self):
base = dict(obj2=dict(), b1=True, b2=False, b3=False,
one=1, two=2, three=3, obj1=dict(key1=1, key2=2),
l1=[1, 3], l2=[1, 2, 3], l4=[4],
nested=dict(n1=dict(n2=2)))

other = dict(b1=True, b2=False, b3=True, b4=True,
one=1, three=4, four=4, obj1=dict(key1=2),
l1=[2, 1], l2=[3, 2, 1], l3=[1],
nested=dict(n1=dict(n2=2, n3=3)))

result = dict_diff(base, other)

# string assertions
self.assertNotIn('one', result)
self.assertNotIn('two', result)
self.assertEqual(result['three'], 4)
self.assertEqual(result['four'], 4)

# dict assertions
self.assertIn('obj1', result)
self.assertIn('key1', result['obj1'])
self.assertNotIn('key2', result['obj1'])

# list assertions
self.assertEqual(result['l1'], [2, 1])
self.assertNotIn('l2', result)
self.assertEqual(result['l3'], [1])
self.assertNotIn('l4', result)

# nested assertions
self.assertIn('obj1', result)
self.assertEqual(result['obj1']['key1'], 2)
self.assertNotIn('key2', result['obj1'])

# bool assertions
self.assertNotIn('b1', result)
self.assertNotIn('b2', result)
self.assertTrue(result['b3'])
self.assertTrue(result['b4'])

def test_dict_combine(self):
base = dict(obj2=dict(), b1=True, b2=False, b3=False,
one=1, two=2, three=3, obj1=dict(key1=1, key2=2),
l1=[1, 3], l2=[1, 2, 3], l4=[4],
nested=dict(n1=dict(n2=2)))

other = dict(b1=True, b2=False, b3=True, b4=True,
one=1, three=4, four=4, obj1=dict(key1=2),
l1=[2, 1], l2=[3, 2, 1], l3=[1],
nested=dict(n1=dict(n2=2, n3=3)))

result = dict_combine(base, other)

# string assertions
self.assertIn('one', result)
self.assertIn('two', result)
self.assertEqual(result['three'], 4)
self.assertEqual(result['four'], 4)

# dict assertions
self.assertIn('obj1', result)
self.assertIn('key1', result['obj1'])
self.assertIn('key2', result['obj1'])

# list assertions
self.assertEqual(result['l1'], [1, 2, 3])
self.assertIn('l2', result)
self.assertEqual(result['l3'], [1])
self.assertIn('l4', result)

# nested assertions
self.assertIn('obj1', result)
self.assertEqual(result['obj1']['key1'], 2)
self.assertIn('key2', result['obj1'])

# bool assertions
self.assertIn('b1', result)
self.assertIn('b2', result)
self.assertTrue(result['b3'])
self.assertTrue(result['b4'])

0 comments on commit 3aa41ed

Please sign in to comment.