Skip to content

Commit

Permalink
Add "array-like" to _validate_type() (mne-tools#11713)
Browse files Browse the repository at this point in the history
  • Loading branch information
cbrnr authored May 31, 2023
1 parent 4914d23 commit 917b000
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 20 deletions.
2 changes: 1 addition & 1 deletion doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ Bugs
API changes
~~~~~~~~~~~
- None yet
- The ``baseline`` argument can now be array-like (e.g. ``list``, ``tuple``, ``np.ndarray``, ...) instead of only a ``tuple`` (:gh:`11713` by `Clemens Brunner`_)
30 changes: 15 additions & 15 deletions mne/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from .utils import logger, verbose, _check_option
from .utils import logger, verbose, _check_option, _validate_type


def _log_rescale(baseline, mode="mean"):
Expand Down Expand Up @@ -143,42 +143,42 @@ def fun(d, m):


def _check_baseline(baseline, times, sfreq, on_baseline_outside_data="raise"):
"""Check if the baseline is valid, and adjust it if requested.
"""Check if the baseline is valid and adjust it if requested.
``None`` values inside the baseline parameter will be replaced with
``times[0]`` and ``times[-1]``.
``None`` values inside ``baseline`` will be replaced with ``times[0]`` and
``times[-1]``.
Parameters
----------
baseline : tuple | None
baseline : array-like, shape (2,) | None
Beginning and end of the baseline period, in seconds. If ``None``,
assume no baseline and return immediately.
times : array
The time points.
sfreq : float
The sampling rate.
on_baseline_outside_data : 'raise' | 'info' | 'adjust'
What do do if the baseline period exceeds the data.
What to do if the baseline period exceeds the data.
If ``'raise'``, raise an exception (default).
If ``'info'``, log an info message.
If ``'adjust'``, adjust the baseline such that it's within the data
range again.
If ``'adjust'``, adjust the baseline such that it is within the data range.
Returns
-------
(baseline_tmin, baseline_tmax) | None
The baseline with ``None`` values replaced with times, and with
adjusted times if ``on_baseline_outside_data='adjust'``; or ``None``
if the ``baseline`` parameter is ``None``.
The baseline with ``None`` values replaced with times, and with adjusted times
if ``on_baseline_outside_data='adjust'``; or ``None``, if ``baseline`` is
``None``.
"""
if baseline is None:
return None

if not isinstance(baseline, tuple) or len(baseline) != 2:
_validate_type(baseline, "array-like")
baseline = tuple(baseline)

if len(baseline) != 2:
raise ValueError(
f"`baseline={baseline}` is an invalid argument, must "
f"be a tuple of length 2 or None"
f"baseline must have exactly two elements (got {len(baseline)})."
)

tmin, tmax = times[0], times[-1]
Expand Down
2 changes: 1 addition & 1 deletion mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,7 +1349,7 @@ def test_epochs_io_preload(tmp_path, preload):
epochs_no_bl.save(temp_fname_no_bl, overwrite=True)
epochs_read = read_epochs(temp_fname)
epochs_no_bl_read = read_epochs(temp_fname_no_bl)
with pytest.raises(ValueError, match="invalid"):
with pytest.raises(ValueError, match="exactly two elements"):
epochs.apply_baseline(baseline=[1, 2, 3])
epochs_with_bl = epochs_no_bl_read.copy().apply_baseline(baseline)
assert isinstance(epochs_with_bl, BaseEpochs)
Expand Down
7 changes: 4 additions & 3 deletions mne/utils/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# License: BSD-3-Clause

from builtins import input # no-op here but facilitates testing
from collections.abc import Sequence
from difflib import get_close_matches
from importlib import import_module
import operator
Expand Down Expand Up @@ -525,6 +526,7 @@ def __instancecheck__(cls, other):
"path-like": path_like,
"int-like": (int_like,),
"callable": (_Callable(),),
"array-like": (Sequence, np.ndarray),
}


Expand All @@ -538,9 +540,8 @@ def _validate_type(item, types=None, item_name=None, type_name=None, *, extra=""
types : type | str | tuple of types | tuple of str
The types to be checked against.
If str, must be one of {'int', 'int-like', 'str', 'numeric', 'info',
'path-like', 'callable'}.
If a tuple of str is passed, use 'int-like' and not 'int' for
integers.
'path-like', 'callable', 'array-like'}.
If a tuple of str is passed, use 'int-like' and not 'int' for integers.
item_name : str | None
Name of the item to show inside the error message.
type_name : str | None
Expand Down

0 comments on commit 917b000

Please sign in to comment.