Skip to content

Commit

Permalink
ENH: Make config.get_subjects_dir() work with pathlib.Path (mne-tools…
Browse files Browse the repository at this point in the history
…#9465)

* Make config.get_subjects_dir() work with Path

* special-case None

Co-authored-by: Eric Larson <[email protected]>

* Add regression test

* Don't alter global state in test

* Remove tmpdir fixture

Co-authored-by: Eric Larson <[email protected]>
  • Loading branch information
hoechenberger and larsoner authored Jun 11, 2021
1 parent 51f25d3 commit 6d7dadd
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mne/utils/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def _check_event_id(event_id, events):

def _check_fname(fname, overwrite=False, must_exist=False, name='File',
need_dir=False):
"""Check for file existence."""
"""Check for file existence, and return string of its absolute path."""
_validate_type(fname, 'path-like', name)
if op.exists(fname):
if not overwrite:
Expand Down
5 changes: 5 additions & 0 deletions mne/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,13 @@ def get_subjects_dir(subjects_dir=None, raise_error=False):
value : str | None
The SUBJECTS_DIR value.
"""
_validate_type(item=subjects_dir, types=('path-like', None),
item_name='subjects_dir', type_name='str or path-like')

if subjects_dir is None:
subjects_dir = get_config('SUBJECTS_DIR', raise_error=raise_error)
if subjects_dir is not None:
subjects_dir = str(subjects_dir)
return subjects_dir


Expand Down
17 changes: 16 additions & 1 deletion mne/utils/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from mne.utils import (set_config, get_config, get_config_path,
set_memmap_min_size, _get_stim_channel, sys_info,
ClosingStringIO)
ClosingStringIO, get_subjects_dir)


def test_config(tmpdir):
Expand Down Expand Up @@ -87,3 +87,18 @@ def test_sys_info():

if platform.system() == 'Darwin':
assert 'Platform: macOS-' in out


def test_get_subjects_dir(monkeypatch):
"""Test get_subjects_dir()."""
# String
subjects_dir = '/foo'
assert get_subjects_dir(subjects_dir) == subjects_dir

# Path
subjects_dir = Path('/foo')
assert get_subjects_dir(subjects_dir) == str(subjects_dir)

# `None`
monkeypatch.delenv('SUBJECTS_DIR', raising=False)
assert get_subjects_dir() is None

0 comments on commit 6d7dadd

Please sign in to comment.