Skip to content

Commit

Permalink
Merge pull request conda#7826 from teake/expand_env_vars
Browse files Browse the repository at this point in the history
Allow expansion of environment variables
  • Loading branch information
msarahan authored Dec 7, 2018
2 parents db6af4f + 28adfb7 commit 7dbe8c2
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 26 deletions.
34 changes: 22 additions & 12 deletions conda/base/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,10 @@ class Context(Configuration):

_root_prefix = PrimitiveParameter("", aliases=('root_dir', 'root_prefix'))
_envs_dirs = SequenceParameter(string_types, aliases=('envs_dirs', 'envs_path'),
string_delimiter=os.pathsep)
_pkgs_dirs = SequenceParameter(string_types, aliases=('pkgs_dirs',))
string_delimiter=os.pathsep,
expandvars=True)
_pkgs_dirs = SequenceParameter(string_types, aliases=('pkgs_dirs',),
expandvars=True)
_subdir = PrimitiveParameter('', aliases=('subdir',))
_subdirs = SequenceParameter(string_types, aliases=('subdirs',))

Expand All @@ -151,12 +153,15 @@ class Context(Configuration):
# remote connection details
ssl_verify = PrimitiveParameter(True, element_type=string_types + (bool,),
aliases=('verify_ssl',),
validation=ssl_verify_validation)
validation=ssl_verify_validation,
expandvars=True)
client_ssl_cert = PrimitiveParameter(None, aliases=('client_cert',),
element_type=string_types + (NoneType,))
element_type=string_types + (NoneType,),
expandvars=True)
client_ssl_cert_key = PrimitiveParameter(None, aliases=('client_cert_key',),
element_type=string_types + (NoneType,))
proxy_servers = MapParameter(string_types + (NoneType,))
element_type=string_types + (NoneType,),
expandvars=True)
proxy_servers = MapParameter(string_types + (NoneType,), expandvars=True)
remote_connect_timeout_secs = PrimitiveParameter(9.15)
remote_read_timeout_secs = PrimitiveParameter(60.)
remote_max_retries = PrimitiveParameter(3)
Expand All @@ -172,19 +177,24 @@ class Context(Configuration):
validation=channel_alias_validation)
channel_priority = PrimitiveParameter(ChannelPriority.FLEXIBLE)
_channels = SequenceParameter(string_types, default=(DEFAULTS_CHANNEL_NAME,),
aliases=('channels', 'channel',)) # channel for args.channel
aliases=('channels', 'channel',),
expandvars=True) # channel for args.channel
_custom_channels = MapParameter(string_types, DEFAULT_CUSTOM_CHANNELS,
aliases=('custom_channels',))
_custom_multichannels = MapParameter(list, aliases=('custom_multichannels',))
aliases=('custom_channels',),
expandvars=True)
_custom_multichannels = MapParameter(list, aliases=('custom_multichannels',),
expandvars=True)
_default_channels = SequenceParameter(string_types, DEFAULT_CHANNELS,
aliases=('default_channels',))
aliases=('default_channels',),
expandvars=True)
_migrated_channel_aliases = SequenceParameter(string_types,
aliases=('migrated_channel_aliases',))
migrated_custom_channels = MapParameter(string_types) # TODO: also take a list of strings
migrated_custom_channels = MapParameter(string_types,
expandvars=True) # TODO: also take a list of strings
override_channels_enabled = PrimitiveParameter(True)
show_channel_urls = PrimitiveParameter(None, element_type=(bool, NoneType))
use_local = PrimitiveParameter(False)
whitelist_channels = SequenceParameter(string_types)
whitelist_channels = SequenceParameter(string_types, expandvars=True)

always_softlink = PrimitiveParameter(False, aliases=('softlink',))
always_copy = PrimitiveParameter(False, aliases=('copy',))
Expand Down
47 changes: 35 additions & 12 deletions conda/common/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@
from itertools import chain
from logging import getLogger
from os import environ, stat
from os.path import basename, join
from os.path import basename, join, expandvars
from stat import S_IFDIR, S_IFMT, S_IFREG

from enum import Enum, EnumMeta

from .compat import (isiterable, iteritems, itervalues, odict, primitive_types, string_types,
text_type, with_metaclass)
from .compat import (binary_type, isiterable, iteritems, itervalues, odict, primitive_types,
string_types, text_type, with_metaclass)
from .constants import NULL
from .path import expand
from .serialize import yaml_load
Expand Down Expand Up @@ -69,6 +69,12 @@ def pretty_map(dictionary, padding=' '):
return '\n'.join("%s%s: %s" % (padding, key, value) for key, value in iteritems(dictionary))


def expand_environment_variables(unexpanded):
if isinstance(unexpanded, string_types) or isinstance(unexpanded, binary_type):
return expandvars(unexpanded)
else:
return unexpanded

class ConfigurationError(CondaError):
pass

Expand Down Expand Up @@ -383,12 +389,13 @@ class Parameter(object):
_type = None
_element_type = None

def __init__(self, default, aliases=(), validation=None):
def __init__(self, default, aliases=(), validation=None, expandvars=False):
self._name = None
self._names = None
self.default = default
self.aliases = aliases
self._validation = validation
self._expandvars = expandvars

def _set_name(self, name):
# this is an explicit method, and not a descriptor/setter
Expand Down Expand Up @@ -445,6 +452,19 @@ def _get_all_matches(self, instance):
def _merge(self, matches):
raise NotImplementedError()

def _expand(self, data):
if self._expandvars:
# This is similar to conda._vendor.auxlib.type_coercion.typify_data_structure
# It could be DRY-er but that would break SRP.
if isinstance(data, Mapping):
return type(data)((k, expand_environment_variables(v)) for k, v in iteritems(data))
elif isiterable(data):
return type(data)(expand_environment_variables(v) for v in data)
else:
return expand_environment_variables(data)
else:
return data

def __get__(self, instance, instance_type):
# strategy is "extract and merge," which is actually just map and reduce
# extract matches from each source in SEARCH_PATH
Expand All @@ -453,9 +473,12 @@ def __get__(self, instance, instance_type):
return instance._cache_[self.name]

matches, errors = self._get_all_matches(instance)
merged = self._merge(matches) if matches else self.default
# We need to expand any environment variables before type casting.
# Otherwise e.g. `my_bool_var: $BOOL` with BOOL=True would raise a TypeCoercionError.
expanded = self._expand(merged)
try:
result = typify_data_structure(self._merge(matches) if matches else self.default,
self._element_type)
result = typify_data_structure(expanded, self._element_type)
except TypeCoercionError as e:
errors.append(CustomValidationError(self.name, e.value, "<<merged>>", text_type(e)))
else:
Expand Down Expand Up @@ -515,7 +538,7 @@ class PrimitiveParameter(Parameter):
python 2 has long and unicode types.
"""

def __init__(self, default, aliases=(), validation=None, element_type=None):
def __init__(self, default, aliases=(), validation=None, element_type=None, expandvars=False):
"""
Args:
default (Any): The parameter's default value.
Expand All @@ -529,7 +552,7 @@ def __init__(self, default, aliases=(), validation=None, element_type=None):
"""
self._type = type(default) if element_type is None else element_type
self._element_type = self._type
super(PrimitiveParameter, self).__init__(default, aliases, validation)
super(PrimitiveParameter, self).__init__(default, aliases, validation, expandvars)

def _merge(self, matches):
important_match = first(matches, self._match_key_is_important, default=None)
Expand All @@ -554,7 +577,7 @@ class SequenceParameter(Parameter):
_type = tuple

def __init__(self, element_type, default=(), aliases=(), validation=None,
string_delimiter=','):
string_delimiter=',', expandvars=False):
"""
Args:
element_type (type or Iterable[type]): The generic type of each element in
Expand All @@ -567,7 +590,7 @@ def __init__(self, element_type, default=(), aliases=(), validation=None,
"""
self._element_type = element_type
self.string_delimiter = string_delimiter
super(SequenceParameter, self).__init__(default, aliases, validation)
super(SequenceParameter, self).__init__(default, aliases, validation, expandvars)

def collect_errors(self, instance, value, source="<<merged>>"):
errors = super(SequenceParameter, self).collect_errors(instance, value)
Expand Down Expand Up @@ -643,7 +666,7 @@ class MapParameter(Parameter):
"""
_type = frozendict

def __init__(self, element_type, default=None, aliases=(), validation=None):
def __init__(self, element_type, default=None, aliases=(), validation=None, expandvars=False):
"""
Args:
element_type (type or Iterable[type]): The generic type of each element.
Expand All @@ -655,7 +678,7 @@ def __init__(self, element_type, default=None, aliases=(), validation=None):
"""
self._element_type = element_type
default = default and frozendict(default) or frozendict()
super(MapParameter, self).__init__(default, aliases, validation)
super(MapParameter, self).__init__(default, aliases, validation, expandvars)

def collect_errors(self, instance, value, source="<<merged>>"):
errors = super(MapParameter, self).collect_errors(instance, value)
Expand Down
33 changes: 31 additions & 2 deletions tests/common/test_configuration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, unicode_literals

from conda.common.io import env_var
from conda.common.io import env_var, env_vars

from conda._vendor.auxlib.ish import dals
from conda.common.compat import odict, string_types
Expand All @@ -18,7 +18,6 @@
from tempfile import mkdtemp
from unittest import TestCase


test_yaml_raw = {
'file1': dals("""
always_yes: no
Expand Down Expand Up @@ -135,6 +134,20 @@
# comment
value
"""),
'env_vars': dals("""
env_var_map:
expanded: $EXPANDED_VAR
unexpanded: $UNEXPANDED_VAR
env_var_str: $EXPANDED_VAR
env_var_bool: $BOOL_VAR
normal_str: $EXPANDED_VAR
env_var_list:
- $EXPANDED_VAR
- $UNEXPANDED_VAR
- regular_var
"""),

}

Expand All @@ -150,6 +163,12 @@ class SampleConfiguration(Configuration):
boolean_map = MapParameter(bool)
commented_map = MapParameter(string_types)

env_var_map = MapParameter(string_types, expandvars=True)
env_var_str = PrimitiveParameter('', expandvars=True)
env_var_bool = PrimitiveParameter(False, element_type=bool, expandvars=True)
normal_str = PrimitiveParameter('', expandvars=False)
env_var_list = SequenceParameter(string_types, expandvars=True)


def load_from_string_data(*seq):
return odict((f, YamlRawParameter.make_raw_parameters(f, yaml_load(test_yaml_raw[f])))
Expand Down Expand Up @@ -463,3 +482,13 @@ def test_invalid_seq_parameter(self):
config = SampleConfiguration()._set_raw_data(data)
with raises(InvalidTypeError):
config.channels

def test_expanded_variables(self):
with env_vars({'EXPANDED_VAR': 'itsexpanded', 'BOOL_VAR': 'True'}):
config = SampleConfiguration()._set_raw_data(load_from_string_data('env_vars'))
assert config.env_var_map['expanded'] == 'itsexpanded'
assert config.env_var_map['unexpanded'] == '$UNEXPANDED_VAR'
assert config.env_var_str == 'itsexpanded'
assert config.env_var_bool is True
assert config.normal_str == '$EXPANDED_VAR'
assert config.env_var_list == ('itsexpanded', '$UNEXPANDED_VAR', 'regular_var')
35 changes: 35 additions & 0 deletions tests/models/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,41 @@ def test_old_channel_alias(self):
"ftp://new.url:8082/conda-forge/label/dev/noarch",
]

class ChannelEnvironmentVarExpansionTest(TestCase):

@classmethod
def setUpClass(cls):
channels_config = dals("""
channels:
- http://user22:[email protected]:8080
whitelist_channels:
- http://user22:[email protected]:8080
custom_channels:
unexpanded: http://user1:[email protected]:8080/with/path/t/tk-1234
expanded: http://user33:[email protected]:8080/with/path/t/tk-1234
""")
reset_context()
rd = odict(testdata=YamlRawParameter.make_raw_parameters('testdata', yaml_load(channels_config)))
context._set_raw_data(rd)

@classmethod
def tearDownClass(cls):
reset_context()

def test_unexpanded_variables(self):
with env_var('EXPANDED_PWD', 'pass44'):
channel = Channel('unexpanded')
assert channel.auth == 'user1:$UNEXPANDED_PWD'

def test_expanded_variables(self):
with env_var('EXPANDED_PWD', 'pass44'):
channel = Channel('expanded')
assert channel.auth == 'user33:pass44'
assert context.channels[0] == 'http://user22:[email protected]:8080'
assert context.whitelist_channels[0] == 'http://user22:[email protected]:8080'


class ChannelAuthTokenPriorityTests(TestCase):

Expand Down

0 comments on commit 7dbe8c2

Please sign in to comment.