diff --git a/conda_env/env.py b/conda_env/env.py index 10d6d3064fd..ab800aaa6a7 100644 --- a/conda_env/env.py +++ b/conda_env/env.py @@ -23,6 +23,32 @@ from conda._vendor.toolz.itertoolz import concatv, groupby # NOQA +VALID_KEYS = ('name', 'dependencies', 'prefix', 'channels') + + +def validate_keys(data, kwargs): + """Check for unknown keys, remove them and print a warning.""" + invalid_keys = [] + new_data = data.copy() + for key in data.keys(): + if key not in VALID_KEYS: + invalid_keys.append(key) + new_data.pop(key) + + if invalid_keys: + filename = kwargs.get('filename') + verb = 'are' if len(invalid_keys) != 1 else 'is' + plural = 's' if len(invalid_keys) != 1 else '' + print("\nEnvironmentSectionNotValid: The following section{plural} on " + "'{filename}' {verb} invalid and will be ignored:" + "".format(filename=filename, plural=plural, verb=verb)) + for key in invalid_keys: + print(' - {}'.format(key)) + print('') + + return new_data + + def load_from_directory(directory): """Load and return an ``Environment`` from a given ``directory``""" files = ['environment.yml', 'environment.yaml'] @@ -86,9 +112,12 @@ def from_environment(name, prefix, no_builds=False, ignore_channels=False): def from_yaml(yamlstr, **kwargs): """Load and return a ``Environment`` from a given ``yaml string``""" data = yaml_load_standard(yamlstr) + data = validate_keys(data, kwargs) + if kwargs is not None: for key, value in kwargs.items(): data[key] = value + return Environment(**data) diff --git a/conda_env/exceptions.py b/conda_env/exceptions.py index 1b6e95e1684..d8bc2c0c607 100644 --- a/conda_env/exceptions.py +++ b/conda_env/exceptions.py @@ -6,17 +6,24 @@ class CondaEnvException(CondaError): def __init__(self, message, *args, **kwargs): - msg = "Conda Env Exception: %s" % message + msg = "%s" % message super(CondaEnvException, self).__init__(msg, *args, **kwargs) class EnvironmentFileNotFound(CondaEnvException): def __init__(self, filename, *args, **kwargs): - msg = '{} file not found'.format(filename) + msg = "'{}' file not found".format(filename) self.filename = filename super(EnvironmentFileNotFound, self).__init__(msg, *args, **kwargs) +class EnvironmentFileExtensionNotValid(CondaEnvException): + def __init__(self, filename, *args, **kwargs): + msg = "'{}' file extension must be one of '.txt', '.yaml' or '.yml'".format(filename) + self.filename = filename + super(EnvironmentFileExtensionNotValid, self).__init__(msg, *args, **kwargs) + + class NoBinstar(CondaError): def __init__(self): msg = 'The anaconda-client cli must be installed to perform this action' diff --git a/conda_env/specs/__init__.py b/conda_env/specs/__init__.py index dbf55c19bf0..4407e050e1d 100644 --- a/conda_env/specs/__init__.py +++ b/conda_env/specs/__init__.py @@ -8,28 +8,31 @@ from .notebook import NotebookSpec from .requirements import RequirementsSpec from .yaml_file import YamlFileSpec -from ..exceptions import EnvironmentFileNotFound, SpecNotFound - - -all_specs = [ - BinstarSpec, - NotebookSpec, - YamlFileSpec, - RequirementsSpec -] +from ..exceptions import (EnvironmentFileExtensionNotValid, EnvironmentFileNotFound, + SpecNotFound) def detect(**kwargs): - # Check file existence if --file was provided + # Check file existence filename = kwargs.get('filename') if filename and not os.path.isfile(filename): raise EnvironmentFileNotFound(filename=filename) + # Check extensions + all_valid_exts = YamlFileSpec.extensions.union(RequirementsSpec.extensions) + fname, ext = os.path.splitext(filename) + if ext == '' or ext not in all_valid_exts: + raise EnvironmentFileExtensionNotValid(filename) + elif ext in YamlFileSpec.extensions: + specs = [YamlFileSpec] + elif ext in RequirementsSpec.extensions: + specs = [RequirementsSpec] + else: + specs = [NotebookSpec, BinstarSpec] + # Check specifications - specs = [] - for SpecClass in all_specs: + for SpecClass in specs: spec = SpecClass(**kwargs) - specs.append(spec) if spec.can_handle(): return spec diff --git a/conda_env/specs/requirements.py b/conda_env/specs/requirements.py index 864325b8579..cf56aacd0fa 100644 --- a/conda_env/specs/requirements.py +++ b/conda_env/specs/requirements.py @@ -12,6 +12,7 @@ class RequirementsSpec(object): and returns an Environment object from it. ''' msg = None + extensions = set(['.txt', ]) def __init__(self, filename=None, name=None, **kwargs): self.filename = filename diff --git a/conda_env/specs/yaml_file.py b/conda_env/specs/yaml_file.py index aaabf976173..dd3ff855576 100644 --- a/conda_env/specs/yaml_file.py +++ b/conda_env/specs/yaml_file.py @@ -7,6 +7,7 @@ class YamlFileSpec(object): _environment = None + extensions = set(('.yaml', '.yml')) def __init__(self, filename=None, **kwargs): self.filename = filename diff --git a/tests/conda_env/support/invalid_keys.yml b/tests/conda_env/support/invalid_keys.yml new file mode 100644 index 00000000000..8ac9ec04a04 --- /dev/null +++ b/tests/conda_env/support/invalid_keys.yml @@ -0,0 +1,6 @@ +names: nlp +chanels: + - anaconda +dependecies: + - nltk +prefis: /something/miniconda/envs/test diff --git a/tests/conda_env/support/valid_keys.yml b/tests/conda_env/support/valid_keys.yml new file mode 100644 index 00000000000..664a1f82d66 --- /dev/null +++ b/tests/conda_env/support/valid_keys.yml @@ -0,0 +1,6 @@ +name: nlp +channels: + - anaconda +dependencies: + - nltk +prefix: /something/miniconda/envs/test diff --git a/tests/conda_env/test_cli.py b/tests/conda_env/test_cli.py index f7928bde663..b12d7be069e 100644 --- a/tests/conda_env/test_cli.py +++ b/tests/conda_env/test_cli.py @@ -16,9 +16,11 @@ from conda.exceptions import EnvironmentLocationNotFound from conda.install import rm_rf from conda_env.cli.main import create_parser, do_call as do_call_conda_env -from conda_env.exceptions import EnvironmentFileNotFound +from conda_env.exceptions import EnvironmentFileExtensionNotValid, EnvironmentFileNotFound from conda_env.yaml import load as yaml_load +from . import support_file + environment_1 = ''' name: env-1 dependencies: @@ -36,6 +38,16 @@ - malev ''' +environment_3_invalid = ''' +name: env-1 +dependecies: + - python + - flask +channels: + - malev +foo: bar +''' + test_env_name_1 = "env-1" test_env_name_2 = "snowflakes" test_env_name_3 = "env_foo" @@ -264,7 +276,7 @@ def test_env_export(self): snowflake, e, = run_env_command(Commands.ENV_EXPORT, test_env_name_2) - with tempfile.NamedTemporaryFile(mode="w", suffix="yml", delete=False) as env_yaml: + with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as env_yaml: env_yaml.write(snowflake) env_yaml.flush() env_yaml.close() @@ -292,7 +304,7 @@ def test_list(self): snowflake, e = run_conda_command(Commands.LIST, test_env_name_2, "-e") - with tempfile.NamedTemporaryFile(mode="w", suffix="txt", delete=False) as env_txt: + with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as env_txt: env_txt.write(snowflake) env_txt.flush() env_txt.close() @@ -320,7 +332,7 @@ def test_export_muti_channel(self): check1, e = run_conda_command(Commands.LIST, test_env_name_2, "--explicit") - with tempfile.NamedTemporaryFile(mode="w", suffix="yml", delete=False) as env_yaml: + with tempfile.NamedTemporaryFile(mode="w", suffix=".yml", delete=False) as env_yaml: env_yaml.write(snowflake) env_yaml.flush() env_yaml.close() @@ -333,6 +345,16 @@ def test_export_muti_channel(self): check2, e = run_conda_command(Commands.LIST, test_env_name_2, "--explicit") self.assertEqual(check1, check2) + def test_non_existent_file(self): + with self.assertRaises(EnvironmentFileNotFound): + run_env_command(Commands.ENV_CREATE, 'i_do_not_exist.yml') + + def test_invalid_extensions(self): + with tempfile.NamedTemporaryFile(mode="w", suffix=".ymla", delete=False) as env_yaml: + with self.assertRaises(EnvironmentFileExtensionNotValid): + run_env_command(Commands.ENV_CREATE, env_yaml.name) + + if __name__ == '__main__': unittest.main() diff --git a/tests/conda_env/test_env.py b/tests/conda_env/test_env.py index af5cd728215..b0619af18a1 100644 --- a/tests/conda_env/test_env.py +++ b/tests/conda_env/test_env.py @@ -25,8 +25,20 @@ def write(self, chunk): self.output += chunk.decode('utf-8') +def get_environment(filename): + return env.from_file(support_file(filename)) + + def get_simple_environment(): - return env.from_file(support_file('simple.yml')) + return get_environment('simple.yml') + + +def get_valid_keys_environment(): + return get_environment('valid_keys.yml') + + +def get_invalid_keys_environment(): + return get_environment('invalid_keys.yml') class from_file_TestCase(unittest.TestCase): @@ -215,6 +227,18 @@ def test_dependencies_update_after_adding(self): e.dependencies.add('bar') assert 'bar' in e.dependencies['conda'] + def test_valid_keys(self): + e = get_valid_keys_environment() + e_dict = e.to_dict() + for key in env.VALID_KEYS: + assert key in e_dict + + def test_invalid_keys(self): + e = get_invalid_keys_environment() + e_dict = e.to_dict() + assert 'name' in e_dict + assert len(e_dict) == 1 + class DirectoryTestCase(unittest.TestCase): directory = support_file('example')