diff --git a/Makefile b/Makefile index 7702f5162dab0b..b659e044f81b24 100644 --- a/Makefile +++ b/Makefile @@ -94,7 +94,7 @@ tests: PYTHONPATH=./lib $(NOSETESTS) -d -w test/units -v newtests: - PYTHONPATH=./v2 $(NOSETESTS) -d -w test/v2 -v + PYTHONPATH=./v2:./lib $(NOSETESTS) -d -w test/v2 -v authors: diff --git a/test/v2/playbook/test_task.py b/test/v2/playbook/test_task.py index c6215d56ffc067..a012dff4bf8206 100644 --- a/test/v2/playbook/test_task.py +++ b/test/v2/playbook/test_task.py @@ -16,24 +16,39 @@ def setUp(self): def tearDown(self): pass - def test_can_construct_empty_task(self): + def test_construct_empty_task(self): t = Task() - def test_can_construct_task_with_role(self): + def test_construct_task_with_role(self): pass - def test_can_construct_task_with_block(self): + def test_construct_task_with_block(self): pass - def test_can_construct_task_with_role_and_block(self): + def test_construct_task_with_role_and_block(self): pass - def test_can_load_simple_task(self): - t = Task.load(basic_shell_task) - assert t is not None - print "NAME=%s" % t.name - assert t.name == basic_shell_task['name'] - #assert t.module == 'shell' - #assert t.args == 'echo hi' + def test_load_simple_task(self): + t = Task.load(basic_shell_task) + assert t is not None + assert t.name == basic_shell_task['name'] + assert t.module == 'shell' + assert t.args == 'echo hi' + def test_can_load_action_kv_form(self): + pass + + def test_can_load_action_complex_form(self): + pass + + def test_can_load_module_complex_form(self): + pass + + def test_local_action_implies_delegate(self): + pass + def test_local_action_conflicts_with_delegate(self): + pass + + def test_delegate_to_parses(self): + pass diff --git a/v2/ansible/constants.py b/v2/ansible/constants.py new file mode 100644 index 00000000000000..861dd5325c16cf --- /dev/null +++ b/v2/ansible/constants.py @@ -0,0 +1,190 @@ +# (c) 2012-2014, Michael DeHaan +# +# This file is part of Ansible +# +# Ansible is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# Ansible is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with Ansible. If not, see . + +import os +import pwd +import sys +import ConfigParser +from string import ascii_letters, digits + +# copied from utils, avoid circular reference fun :) +def mk_boolean(value): + if value is None: + return False + val = str(value) + if val.lower() in [ "true", "t", "y", "1", "yes" ]: + return True + else: + return False + +def get_config(p, section, key, env_var, default, boolean=False, integer=False, floating=False, islist=False): + ''' return a configuration variable with casting ''' + value = _get_config(p, section, key, env_var, default) + if boolean: + return mk_boolean(value) + if value and integer: + return int(value) + if value and floating: + return float(value) + if value and islist: + return [x.strip() for x in value.split(',')] + return value + +def _get_config(p, section, key, env_var, default): + ''' helper function for get_config ''' + if env_var is not None: + value = os.environ.get(env_var, None) + if value is not None: + return value + if p is not None: + try: + return p.get(section, key, raw=True) + except: + return default + return default + +def load_config_file(): + ''' Load Config File order(first found is used): ENV, CWD, HOME, /etc/ansible ''' + + p = ConfigParser.ConfigParser() + + path0 = os.getenv("ANSIBLE_CONFIG", None) + if path0 is not None: + path0 = os.path.expanduser(path0) + path1 = os.getcwd() + "/ansible.cfg" + path2 = os.path.expanduser("~/.ansible.cfg") + path3 = "/etc/ansible/ansible.cfg" + + for path in [path0, path1, path2, path3]: + if path is not None and os.path.exists(path): + try: + p.read(path) + except ConfigParser.Error as e: + print "Error reading config file: \n%s" % e + sys.exit(1) + return p + return None + +def shell_expand_path(path): + ''' shell_expand_path is needed as os.path.expanduser does not work + when path is None, which is the default for ANSIBLE_PRIVATE_KEY_FILE ''' + if path: + path = os.path.expanduser(os.path.expandvars(path)) + return path + +p = load_config_file() + +active_user = pwd.getpwuid(os.geteuid())[0] + +# check all of these extensions when looking for yaml files for things like +# group variables -- really anything we can load +YAML_FILENAME_EXTENSIONS = [ "", ".yml", ".yaml", ".json" ] + +# sections in config file +DEFAULTS='defaults' + +# configurable things +DEFAULT_HOST_LIST = shell_expand_path(get_config(p, DEFAULTS, 'hostfile', 'ANSIBLE_HOSTS', '/etc/ansible/hosts')) +DEFAULT_MODULE_PATH = get_config(p, DEFAULTS, 'library', 'ANSIBLE_LIBRARY', None) +DEFAULT_ROLES_PATH = shell_expand_path(get_config(p, DEFAULTS, 'roles_path', 'ANSIBLE_ROLES_PATH', '/etc/ansible/roles')) +DEFAULT_REMOTE_TMP = get_config(p, DEFAULTS, 'remote_tmp', 'ANSIBLE_REMOTE_TEMP', '$HOME/.ansible/tmp') +DEFAULT_MODULE_NAME = get_config(p, DEFAULTS, 'module_name', None, 'command') +DEFAULT_PATTERN = get_config(p, DEFAULTS, 'pattern', None, '*') +DEFAULT_FORKS = get_config(p, DEFAULTS, 'forks', 'ANSIBLE_FORKS', 5, integer=True) +DEFAULT_MODULE_ARGS = get_config(p, DEFAULTS, 'module_args', 'ANSIBLE_MODULE_ARGS', '') +DEFAULT_MODULE_LANG = get_config(p, DEFAULTS, 'module_lang', 'ANSIBLE_MODULE_LANG', 'en_US.UTF-8') +DEFAULT_TIMEOUT = get_config(p, DEFAULTS, 'timeout', 'ANSIBLE_TIMEOUT', 10, integer=True) +DEFAULT_POLL_INTERVAL = get_config(p, DEFAULTS, 'poll_interval', 'ANSIBLE_POLL_INTERVAL', 15, integer=True) +DEFAULT_REMOTE_USER = get_config(p, DEFAULTS, 'remote_user', 'ANSIBLE_REMOTE_USER', active_user) +DEFAULT_ASK_PASS = get_config(p, DEFAULTS, 'ask_pass', 'ANSIBLE_ASK_PASS', False, boolean=True) +DEFAULT_PRIVATE_KEY_FILE = shell_expand_path(get_config(p, DEFAULTS, 'private_key_file', 'ANSIBLE_PRIVATE_KEY_FILE', None)) +DEFAULT_SUDO_USER = get_config(p, DEFAULTS, 'sudo_user', 'ANSIBLE_SUDO_USER', 'root') +DEFAULT_ASK_SUDO_PASS = get_config(p, DEFAULTS, 'ask_sudo_pass', 'ANSIBLE_ASK_SUDO_PASS', False, boolean=True) +DEFAULT_REMOTE_PORT = get_config(p, DEFAULTS, 'remote_port', 'ANSIBLE_REMOTE_PORT', None, integer=True) +DEFAULT_ASK_VAULT_PASS = get_config(p, DEFAULTS, 'ask_vault_pass', 'ANSIBLE_ASK_VAULT_PASS', False, boolean=True) +DEFAULT_VAULT_PASSWORD_FILE = shell_expand_path(get_config(p, DEFAULTS, 'vault_password_file', 'ANSIBLE_VAULT_PASSWORD_FILE', None)) +DEFAULT_TRANSPORT = get_config(p, DEFAULTS, 'transport', 'ANSIBLE_TRANSPORT', 'smart') +DEFAULT_SCP_IF_SSH = get_config(p, 'ssh_connection', 'scp_if_ssh', 'ANSIBLE_SCP_IF_SSH', False, boolean=True) +DEFAULT_MANAGED_STR = get_config(p, DEFAULTS, 'ansible_managed', None, 'Ansible managed: {file} modified on %Y-%m-%d %H:%M:%S by {uid} on {host}') +DEFAULT_SYSLOG_FACILITY = get_config(p, DEFAULTS, 'syslog_facility', 'ANSIBLE_SYSLOG_FACILITY', 'LOG_USER') +DEFAULT_KEEP_REMOTE_FILES = get_config(p, DEFAULTS, 'keep_remote_files', 'ANSIBLE_KEEP_REMOTE_FILES', False, boolean=True) +DEFAULT_SUDO = get_config(p, DEFAULTS, 'sudo', 'ANSIBLE_SUDO', False, boolean=True) +DEFAULT_SUDO_EXE = get_config(p, DEFAULTS, 'sudo_exe', 'ANSIBLE_SUDO_EXE', 'sudo') +DEFAULT_SUDO_FLAGS = get_config(p, DEFAULTS, 'sudo_flags', 'ANSIBLE_SUDO_FLAGS', '-H') +DEFAULT_HASH_BEHAVIOUR = get_config(p, DEFAULTS, 'hash_behaviour', 'ANSIBLE_HASH_BEHAVIOUR', 'replace') +DEFAULT_JINJA2_EXTENSIONS = get_config(p, DEFAULTS, 'jinja2_extensions', 'ANSIBLE_JINJA2_EXTENSIONS', None) +DEFAULT_EXECUTABLE = get_config(p, DEFAULTS, 'executable', 'ANSIBLE_EXECUTABLE', '/bin/sh') +DEFAULT_SU_EXE = get_config(p, DEFAULTS, 'su_exe', 'ANSIBLE_SU_EXE', 'su') +DEFAULT_SU = get_config(p, DEFAULTS, 'su', 'ANSIBLE_SU', False, boolean=True) +DEFAULT_SU_FLAGS = get_config(p, DEFAULTS, 'su_flags', 'ANSIBLE_SU_FLAGS', '') +DEFAULT_SU_USER = get_config(p, DEFAULTS, 'su_user', 'ANSIBLE_SU_USER', 'root') +DEFAULT_ASK_SU_PASS = get_config(p, DEFAULTS, 'ask_su_pass', 'ANSIBLE_ASK_SU_PASS', False, boolean=True) +DEFAULT_GATHERING = get_config(p, DEFAULTS, 'gathering', 'ANSIBLE_GATHERING', 'implicit').lower() + +DEFAULT_ACTION_PLUGIN_PATH = get_config(p, DEFAULTS, 'action_plugins', 'ANSIBLE_ACTION_PLUGINS', '/usr/share/ansible_plugins/action_plugins') +DEFAULT_CACHE_PLUGIN_PATH = get_config(p, DEFAULTS, 'cache_plugins', 'ANSIBLE_CACHE_PLUGINS', '/usr/share/ansible_plugins/cache_plugins') +DEFAULT_CALLBACK_PLUGIN_PATH = get_config(p, DEFAULTS, 'callback_plugins', 'ANSIBLE_CALLBACK_PLUGINS', '/usr/share/ansible_plugins/callback_plugins') +DEFAULT_CONNECTION_PLUGIN_PATH = get_config(p, DEFAULTS, 'connection_plugins', 'ANSIBLE_CONNECTION_PLUGINS', '/usr/share/ansible_plugins/connection_plugins') +DEFAULT_LOOKUP_PLUGIN_PATH = get_config(p, DEFAULTS, 'lookup_plugins', 'ANSIBLE_LOOKUP_PLUGINS', '/usr/share/ansible_plugins/lookup_plugins') +DEFAULT_VARS_PLUGIN_PATH = get_config(p, DEFAULTS, 'vars_plugins', 'ANSIBLE_VARS_PLUGINS', '/usr/share/ansible_plugins/vars_plugins') +DEFAULT_FILTER_PLUGIN_PATH = get_config(p, DEFAULTS, 'filter_plugins', 'ANSIBLE_FILTER_PLUGINS', '/usr/share/ansible_plugins/filter_plugins') +DEFAULT_LOG_PATH = shell_expand_path(get_config(p, DEFAULTS, 'log_path', 'ANSIBLE_LOG_PATH', '')) + +CACHE_PLUGIN = get_config(p, DEFAULTS, 'fact_caching', 'ANSIBLE_CACHE_PLUGIN', 'memory') +CACHE_PLUGIN_CONNECTION = get_config(p, DEFAULTS, 'fact_caching_connection', 'ANSIBLE_CACHE_PLUGIN_CONNECTION', None) +CACHE_PLUGIN_PREFIX = get_config(p, DEFAULTS, 'fact_caching_prefix', 'ANSIBLE_CACHE_PLUGIN_PREFIX', 'ansible_facts') +CACHE_PLUGIN_TIMEOUT = get_config(p, DEFAULTS, 'fact_caching_timeout', 'ANSIBLE_CACHE_PLUGIN_TIMEOUT', 24 * 60 * 60, integer=True) + +ANSIBLE_FORCE_COLOR = get_config(p, DEFAULTS, 'force_color', 'ANSIBLE_FORCE_COLOR', None, boolean=True) +ANSIBLE_NOCOLOR = get_config(p, DEFAULTS, 'nocolor', 'ANSIBLE_NOCOLOR', None, boolean=True) +ANSIBLE_NOCOWS = get_config(p, DEFAULTS, 'nocows', 'ANSIBLE_NOCOWS', None, boolean=True) +DISPLAY_SKIPPED_HOSTS = get_config(p, DEFAULTS, 'display_skipped_hosts', 'DISPLAY_SKIPPED_HOSTS', True, boolean=True) +DEFAULT_UNDEFINED_VAR_BEHAVIOR = get_config(p, DEFAULTS, 'error_on_undefined_vars', 'ANSIBLE_ERROR_ON_UNDEFINED_VARS', True, boolean=True) +HOST_KEY_CHECKING = get_config(p, DEFAULTS, 'host_key_checking', 'ANSIBLE_HOST_KEY_CHECKING', True, boolean=True) +SYSTEM_WARNINGS = get_config(p, DEFAULTS, 'system_warnings', 'ANSIBLE_SYSTEM_WARNINGS', True, boolean=True) +DEPRECATION_WARNINGS = get_config(p, DEFAULTS, 'deprecation_warnings', 'ANSIBLE_DEPRECATION_WARNINGS', True, boolean=True) +DEFAULT_CALLABLE_WHITELIST = get_config(p, DEFAULTS, 'callable_whitelist', 'ANSIBLE_CALLABLE_WHITELIST', [], islist=True) +COMMAND_WARNINGS = get_config(p, DEFAULTS, 'command_warnings', 'ANSIBLE_COMMAND_WARNINGS', False, boolean=True) +DEFAULT_LOAD_CALLBACK_PLUGINS = get_config(p, DEFAULTS, 'bin_ansible_callbacks', 'ANSIBLE_LOAD_CALLBACK_PLUGINS', False, boolean=True) + +# CONNECTION RELATED +ANSIBLE_SSH_ARGS = get_config(p, 'ssh_connection', 'ssh_args', 'ANSIBLE_SSH_ARGS', None) +ANSIBLE_SSH_CONTROL_PATH = get_config(p, 'ssh_connection', 'control_path', 'ANSIBLE_SSH_CONTROL_PATH', "%(directory)s/ansible-ssh-%%h-%%p-%%r") +ANSIBLE_SSH_PIPELINING = get_config(p, 'ssh_connection', 'pipelining', 'ANSIBLE_SSH_PIPELINING', False, boolean=True) +PARAMIKO_RECORD_HOST_KEYS = get_config(p, 'paramiko_connection', 'record_host_keys', 'ANSIBLE_PARAMIKO_RECORD_HOST_KEYS', True, boolean=True) +# obsolete -- will be formally removed in 1.6 +ZEROMQ_PORT = get_config(p, 'fireball_connection', 'zeromq_port', 'ANSIBLE_ZEROMQ_PORT', 5099, integer=True) +ACCELERATE_PORT = get_config(p, 'accelerate', 'accelerate_port', 'ACCELERATE_PORT', 5099, integer=True) +ACCELERATE_TIMEOUT = get_config(p, 'accelerate', 'accelerate_timeout', 'ACCELERATE_TIMEOUT', 30, integer=True) +ACCELERATE_CONNECT_TIMEOUT = get_config(p, 'accelerate', 'accelerate_connect_timeout', 'ACCELERATE_CONNECT_TIMEOUT', 1.0, floating=True) +ACCELERATE_DAEMON_TIMEOUT = get_config(p, 'accelerate', 'accelerate_daemon_timeout', 'ACCELERATE_DAEMON_TIMEOUT', 30, integer=True) +ACCELERATE_KEYS_DIR = get_config(p, 'accelerate', 'accelerate_keys_dir', 'ACCELERATE_KEYS_DIR', '~/.fireball.keys') +ACCELERATE_KEYS_DIR_PERMS = get_config(p, 'accelerate', 'accelerate_keys_dir_perms', 'ACCELERATE_KEYS_DIR_PERMS', '700') +ACCELERATE_KEYS_FILE_PERMS = get_config(p, 'accelerate', 'accelerate_keys_file_perms', 'ACCELERATE_KEYS_FILE_PERMS', '600') +ACCELERATE_MULTI_KEY = get_config(p, 'accelerate', 'accelerate_multi_key', 'ACCELERATE_MULTI_KEY', False, boolean=True) +PARAMIKO_PTY = get_config(p, 'paramiko_connection', 'pty', 'ANSIBLE_PARAMIKO_PTY', True, boolean=True) + +# characters included in auto-generated passwords +DEFAULT_PASSWORD_CHARS = ascii_letters + digits + ".,:-_" + +# non-configurable things +DEFAULT_SUDO_PASS = None +DEFAULT_REMOTE_PASS = None +DEFAULT_SUBSET = None +DEFAULT_SU_PASS = None +VAULT_VERSION_MIN = 1.0 +VAULT_VERSION_MAX = 1.0 diff --git a/v2/ansible/playbook/base.py b/v2/ansible/playbook/base.py index 1223eafefe619f..68dc2d6ffe3bf9 100644 --- a/v2/ansible/playbook/base.py +++ b/v2/ansible/playbook/base.py @@ -15,44 +15,39 @@ # You should have received a copy of the GNU General Public License # along with Ansible. If not, see . -#from ansible.cmmon.errors import AnsibleError -#from playbook.tag import Tag from ansible.playbook.attribute import Attribute, FieldAttribute - -# general concept -# FooObject.load(datastructure) -> Foo -# FooObject._load_field # optional -# FooObject._validate_field # optional -# FooObject._post_validate_field # optional -# FooObject.evaluate(host_context) -> FooObject ? (calls post_validators, templates all members) -# question - are there some things that need to be evaluated *before* host context, i.e. globally? -# most things should be templated but want to provide as much early checking as possible -# TODO: also check for fields in datastructure that are not valid -# TODO: PluginAttribute(type) allows all the valid plugins as valid types of names -# lookupPlugins start with "with_", ModulePluginAttribute allows any key - class Base(object): def __init__(self): - self._data = dict() + + # each class knows attributes set upon it, see Task.py for example self._attributes = dict() - for name in self.__class__.__dict__: aname = name[1:] if isinstance(aname, Attribute) and not isinstance(aname, FieldAttribute): self._attributes[aname] = None + def munge(self, ds): + ''' infrequently used method to do some pre-processing of legacy terms ''' + + return ds + def load_data(self, ds): ''' walk the input datastructure and assign any values ''' assert ds is not None + ds = self.munge(ds) + # walk all attributes in the class for (name, attribute) in self.__class__.__dict__.iteritems(): aname = name[1:] - # process Fields + # process Field attributes which get loaded from the YAML + if isinstance(attribute, FieldAttribute): + + # copy the value over unless a _load_field method is defined method = getattr(self, '_load_%s' % aname, None) if method: self._attributes[aname] = method(self, attribute) @@ -60,38 +55,45 @@ def load_data(self, ds): if aname in ds: self._attributes[aname] = ds[aname] - # TODO: implement PluginAtrribute which allows "with_" and "action" aliases. - + # return the constructed object + self.validate() return self def validate(self): - # TODO: finish - for name in self.__dict__: - aname = name[1:] - attribute = self.__dict__[aname] - if instanceof(attribute, FieldAttribute): + ''' validation that is done at parse time, not load time ''' + + # walk all fields in the object + for (name, attribute) in self.__dict__: + + # find any field attributes + if isinstance(attribute, FieldAttribute): + + if not name.startswith("_"): + raise AnsibleError("FieldAttribute %s must start with _" % name) + + aname = name[1:] + + # run validator only if present method = getattr(self, '_validate_%s' % (prefix, aname), None) if method: method(self, attribute) def post_validate(self, runner_context): - # TODO: finish + ''' + we can't tell that everything is of the right type until we have + all the variables. Run basic types (from isa) as well as + any _post_validate_ functions. + ''' + raise exception.NotImplementedError def __getattr__(self, needle): + + # return any attribute names as if they were real. + # access them like obj.attrname() if needle in self._attributes: return self._attributes[needle] - if needle in self.__dict__: - return self.__dict__[needle] - raise AttributeError - - #def __setattr__(self, needle, value): - # if needle in self._attributes: - # self._attributes[needle] = value - # if needle in self.__dict__: - # super(Base, self).__setattr__(needle, value) - # # self.__dict__[needle] = value - # raise AttributeError + raise AttributeError diff --git a/v2/ansible/playbook/task.py b/v2/ansible/playbook/task.py index 80d04fcb95fa90..ad708f167cbd05 100644 --- a/v2/ansible/playbook/task.py +++ b/v2/ansible/playbook/task.py @@ -17,13 +17,15 @@ from ansible.playbook.base import Base from ansible.playbook.attribute import Attribute, FieldAttribute -from ansible.playbook.conditional import Conditional -#from ansible.common.errors import AnsibleError -#from ansible import utils + +# from ansible.playbook.conditional import Conditional +# from ansible.common.errors import AnsibleError # TODO: it would be fantastic (if possible) if a task new where in the YAML it was defined for describing # it in error conditions +from ansible.plugins import module_finder, lookup_finder + class Task(Base): """ @@ -44,6 +46,7 @@ class Task(Base): # might be possible to define others _action = FieldAttribute(isa='string') + _always_run = FieldAttribute(isa='bool') _any_errors_fatal = FieldAttribute(isa='bool') _async = FieldAttribute(isa='int') @@ -55,12 +58,14 @@ class Task(Base): _ignore_errors = FieldAttribute(isa='bool') # FIXME: this should not be a Task - # include = FieldAttribute(isa='string') + # include = FieldAttribute(isa='string') + _loop = Attribute() _local_action = FieldAttribute(isa='string') # FIXME: this should not be a Task - _meta = FieldAttribute(isa='string') + _module_args = Attribute(isa='dict') + _meta = FieldAttribute(isa='string') _name = FieldAttribute(isa='string') @@ -106,6 +111,44 @@ def __repr__(self): ''' returns a human readable representation of the task ''' return "TASK: %s" % self.get_name() + def munge(self, ds): + ''' + tasks are especially complex arguments so need pre-processing. + keep it short. + ''' + + + assert isinstance(ds, dict) + + new_ds = dict() + for (k,v) in ds.iteritems(): + + # if any attributes of the datastructure match a module name + # convert it to "module + args" + + if k in module_finder: + if _module.value is not None or 'action' in ds or 'local_action' in ds: + raise AnsibleError("duplicate action in task: %s" % k) + _module.value = k + _module_args.value = v + + # handle any loops, there can be only one kind of loop + + elif "with_%s" % k in lookup_finder: + if _loop.value is not None: + raise AnsibleError("duplicate loop in task: %s" % k) + _loop.value = k + _loop_args.value = v + + # otherwise send it through straight + + else: + # nothing we need to filter + new_ds[k] = v + + return new_ds + + # ================================================================================== # BELOW THIS LINE # info below this line is "old" and is before the attempt to build Attributes @@ -119,7 +162,7 @@ def _load_action(self, ds, k, v): results = dict() module_name, params = v.strip().split(' ', 1) - if module_name not in utils.plugins.module_finder: + if module_name not in module_finder: raise AnsibleError("the specified module '%s' could not be found, check your module path" % module_name) results['_module_name'] = module_name results['_parameters'] = utils.parse_kv(params) diff --git a/v2/ansible/plugins/__init__.py b/v2/ansible/plugins/__init__.py index d6c11ffa74293b..4bb6c393120b9d 100644 --- a/v2/ansible/plugins/__init__.py +++ b/v2/ansible/plugins/__init__.py @@ -1,4 +1,5 @@ -# (c) 2012-2014, Michael DeHaan +# (c) 2012, Daniel Hokka Zakrisson +# (c) 2012-2014, Michael DeHaan and others # # This file is part of Ansible # @@ -15,3 +16,271 @@ # You should have received a copy of the GNU General Public License # along with Ansible. If not, see . +import os +import os.path +import sys +import glob +import imp +from ansible import constants as C +from ansible import errors + +MODULE_CACHE = {} +PATH_CACHE = {} +PLUGIN_PATH_CACHE = {} +_basedirs = [] + +def push_basedir(basedir): + # avoid pushing the same absolute dir more than once + basedir = os.path.realpath(basedir) + if basedir not in _basedirs: + _basedirs.insert(0, basedir) + +class PluginLoader(object): + + ''' + PluginLoader loads plugins from the configured plugin directories. + + It searches for plugins by iterating through the combined list of + play basedirs, configured paths, and the python path. + The first match is used. + ''' + + def __init__(self, class_name, package, config, subdir, aliases={}): + + self.class_name = class_name + self.package = package + self.config = config + self.subdir = subdir + self.aliases = aliases + + if not class_name in MODULE_CACHE: + MODULE_CACHE[class_name] = {} + if not class_name in PATH_CACHE: + PATH_CACHE[class_name] = None + if not class_name in PLUGIN_PATH_CACHE: + PLUGIN_PATH_CACHE[class_name] = {} + + self._module_cache = MODULE_CACHE[class_name] + self._paths = PATH_CACHE[class_name] + self._plugin_path_cache = PLUGIN_PATH_CACHE[class_name] + + self._extra_dirs = [] + + def print_paths(self): + ''' Returns a string suitable for printing of the search path ''' + + # Uses a list to get the order right + ret = [] + for i in self._get_paths(): + if i not in ret: + ret.append(i) + return os.pathsep.join(ret) + + def _all_directories(self, dir): + results = [] + results.append(dir) + for root, subdirs, files in os.walk(dir): + if '__init__.py' in files: + for x in subdirs: + results.append(os.path.join(root,x)) + return results + + def _get_package_paths(self): + ''' Gets the path of a Python package ''' + + paths = [] + if not self.package: + return [] + if not hasattr(self, 'package_path'): + m = __import__(self.package) + parts = self.package.split('.')[1:] + self.package_path = os.path.join(os.path.dirname(m.__file__), *parts) + paths.extend(self._all_directories(self.package_path)) + return paths + + def _get_paths(self): + ''' Return a list of paths to search for plugins in ''' + + if self._paths is not None: + return self._paths + + ret = self._extra_dirs[:] + for basedir in _basedirs: + fullpath = os.path.realpath(os.path.join(basedir, self.subdir)) + if os.path.isdir(fullpath): + + files = glob.glob("%s/*" % fullpath) + + # allow directories to be two levels deep + files2 = glob.glob("%s/*/*" % fullpath) + + if files2 is not None: + files.extend(files2) + + for file in files: + if os.path.isdir(file) and file not in ret: + ret.append(file) + if fullpath not in ret: + ret.append(fullpath) + + # look in any configured plugin paths, allow one level deep for subcategories + if self.config is not None: + configured_paths = self.config.split(os.pathsep) + for path in configured_paths: + path = os.path.realpath(os.path.expanduser(path)) + contents = glob.glob("%s/*" % path) + for c in contents: + if os.path.isdir(c) and c not in ret: + ret.append(c) + if path not in ret: + ret.append(path) + + # look for any plugins installed in the package subtree + ret.extend(self._get_package_paths()) + + # cache and return the result + self._paths = ret + return ret + + + def add_directory(self, directory, with_subdir=False): + ''' Adds an additional directory to the search path ''' + + directory = os.path.realpath(directory) + + if directory is not None: + if with_subdir: + directory = os.path.join(directory, self.subdir) + if directory not in self._extra_dirs: + # append the directory and invalidate the path cache + self._extra_dirs.append(directory) + self._paths = None + + def find_plugin(self, name, suffixes=None, transport=''): + ''' Find a plugin named name ''' + + if not suffixes: + if self.class_name: + suffixes = ['.py'] + else: + if transport == 'winrm': + suffixes = ['.ps1', ''] + else: + suffixes = ['.py', ''] + + for suffix in suffixes: + full_name = '%s%s' % (name, suffix) + if full_name in self._plugin_path_cache: + return self._plugin_path_cache[full_name] + + for i in self._get_paths(): + path = os.path.join(i, full_name) + if os.path.isfile(path): + self._plugin_path_cache[full_name] = path + return path + + return None + + def has_plugin(self, name): + ''' Checks if a plugin named name exists ''' + + return self.find_plugin(name) is not None + + __contains__ = has_plugin + + def get(self, name, *args, **kwargs): + ''' instantiates a plugin of the given name using arguments ''' + + if name in self.aliases: + name = self.aliases[name] + path = self.find_plugin(name) + if path is None: + return None + if path not in self._module_cache: + self._module_cache[path] = imp.load_source('.'.join([self.package, name]), path) + return getattr(self._module_cache[path], self.class_name)(*args, **kwargs) + + def all(self, *args, **kwargs): + ''' instantiates all plugins with the same arguments ''' + + for i in self._get_paths(): + matches = glob.glob(os.path.join(i, "*.py")) + matches.sort() + for path in matches: + name, ext = os.path.splitext(os.path.basename(path)) + if name.startswith("_"): + continue + if path not in self._module_cache: + self._module_cache[path] = imp.load_source('.'.join([self.package, name]), path) + yield getattr(self._module_cache[path], self.class_name)(*args, **kwargs) + +action_loader = PluginLoader( + 'ActionModule', + 'ansible.runner.action_plugins', + C.DEFAULT_ACTION_PLUGIN_PATH, + 'action_plugins' +) + +cache_loader = PluginLoader( + 'CacheModule', + 'ansible.cache', + C.DEFAULT_CACHE_PLUGIN_PATH, + 'cache_plugins' +) + +callback_loader = PluginLoader( + 'CallbackModule', + 'ansible.callback_plugins', + C.DEFAULT_CALLBACK_PLUGIN_PATH, + 'callback_plugins' +) + +connection_loader = PluginLoader( + 'Connection', + 'ansible.runner.connection_plugins', + C.DEFAULT_CONNECTION_PLUGIN_PATH, + 'connection_plugins', + aliases={'paramiko': 'paramiko_ssh'} +) + +shell_loader = PluginLoader( + 'ShellModule', + 'ansible.runner.shell_plugins', + 'shell_plugins', + 'shell_plugins', +) + +module_finder = PluginLoader( + '', + 'ansible.modules', + C.DEFAULT_MODULE_PATH, + 'library' +) + +lookup_finder = PluginLoader( + 'LookupModule', + 'ansible.runner.lookup_plugins', + C.DEFAULT_LOOKUP_PLUGIN_PATH, + 'lookup_plugins' +) + +vars_finder = PluginLoader( + 'VarsModule', + 'ansible.inventory.vars_plugins', + C.DEFAULT_VARS_PLUGIN_PATH, + 'vars_plugins' +) + +filter_finder = PluginLoader( + 'FilterModule', + 'ansible.runner.filter_plugins', + C.DEFAULT_FILTER_PLUGIN_PATH, + 'filter_plugins' +) + +fragment_finder = PluginLoader( + 'ModuleDocFragment', + 'ansible.utils.module_docs_fragments', + os.path.join(os.path.dirname(__file__), 'module_docs_fragments'), + '', +)