forked from mne-tools/mne-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
baseline.py
161 lines (142 loc) · 5.57 KB
/
baseline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
"""Utility functions to baseline-correct data."""
# Authors: Alexandre Gramfort <[email protected]>
#
# License: BSD (3-clause)
import numpy as np
from .utils import logger, verbose, _check_option
def _check_baseline(baseline, tmin, tmax, sfreq):
"""Check for a valid baseline."""
if baseline is not None:
if not isinstance(baseline, tuple) or len(baseline) != 2:
raise ValueError('`baseline=%s` is an invalid argument, must be '
'a tuple of length 2 or None' % str(baseline))
# check default value of baseline and `tmin=0`
if baseline == (None, 0) and tmin == 0:
raise ValueError('Baseline interval is only one sample. Use '
'`baseline=(0, 0)` if this is desired.')
baseline_tmin, baseline_tmax = baseline
tstep = 1. / float(sfreq)
if baseline_tmin is None:
baseline_tmin = tmin
baseline_tmin = float(baseline_tmin)
if baseline_tmax is None:
baseline_tmax = tmax
baseline_tmax = float(baseline_tmax)
if baseline_tmin < tmin - tstep:
raise ValueError(
"Baseline interval (tmin = %s) is outside of "
"data range (tmin = %s)" % (baseline_tmin, tmin))
if baseline_tmax > tmax + tstep:
raise ValueError(
"Baseline interval (tmax = %s) is outside of "
"data range (tmax = %s)" % (baseline_tmax, tmax))
if baseline_tmin > baseline_tmax:
raise ValueError(
"Baseline min (%s) must be less than baseline max (%s)"
% (baseline_tmin, baseline_tmax))
def _log_rescale(baseline, mode='mean'):
"""Log the rescaling method."""
if baseline is not None:
_check_option('mode', mode, ['logratio', 'ratio', 'zscore', 'mean',
'percent', 'zlogratio'])
bmin, bmax = baseline
bmin = None if bmin is None else f'{round(bmin, 3):.3f}'
bmax = None if bmax is None else f'{round(bmax, 3):.3f}'
unit = '' if bmin is None and bmax is None else ' sec'
msg = (f'Applying baseline correction '
f'(baseline: [{bmin}, {bmax}]{unit}, mode: {mode})')
else:
msg = 'No baseline correction applied'
return msg
@verbose
def rescale(data, times, baseline, mode='mean', copy=True, picks=None,
            verbose=None):
    """Rescale (baseline correct) data.

    Parameters
    ----------
    data : array
        It can be of any shape. The only constraint is that the last
        dimension should be time.
    times : 1D array
        Time instants in seconds.
    %(baseline_array)s
    mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio'
        Perform baseline correction by

        - subtracting the mean of baseline values ('mean')
        - dividing by the mean of baseline values ('ratio')
        - dividing by the mean of baseline values and taking the base-10
          log ('logratio')
        - subtracting the mean of baseline values followed by dividing by
          the mean of baseline values ('percent')
        - subtracting the mean of baseline values and dividing by the
          standard deviation of baseline values ('zscore')
        - dividing by the mean of baseline values, taking the base-10 log,
          and dividing by the standard deviation of log baseline values
          ('zlogratio')

    copy : bool
        Whether to return a new instance or modify in place.
    picks : list of int | None
        Data to process along the axis=-2 (None, default, processes all).
    %(verbose)s

    Returns
    -------
    data_scaled : array
        Array of same shape as data after rescaling. If ``copy`` is False,
        this is ``data`` itself, modified in place.

    Raises
    ------
    ValueError
        If the baseline start is after the last time value, the baseline
        stop is before the first time value, or the resulting sample slice
        is empty.
    """
    data = data.copy() if copy else data
    # _log_rescale also validates ``mode`` when a baseline is given.
    msg = _log_rescale(baseline, mode)
    logger.info(msg)
    if baseline is None or data.shape[-1] == 0:
        return data

    bmin, bmax = baseline
    # Convert the (bmin, bmax) time interval into a half-open sample slice
    # [imin, imax); a None edge means the corresponding edge of ``times``.
    if bmin is None:
        imin = 0
    else:
        # first sample at or after bmin
        imin = np.where(times >= bmin)[0]
        if len(imin) == 0:
            raise ValueError('bmin is too large (%s), it exceeds the largest '
                             'time value' % (bmin,))
        imin = int(imin[0])
    if bmax is None:
        imax = len(times)
    else:
        # one past the last sample at or before bmax
        imax = np.where(times <= bmax)[0]
        if len(imax) == 0:
            raise ValueError('bmax is too small (%s), it is smaller than the '
                             'smallest time value' % (bmax,))
        imax = int(imax[-1]) + 1
    if imin >= imax:
        raise ValueError('Bad rescaling slice (%s:%s) from time values %s, %s'
                         % (imin, imax, bmin, bmax))

    # technically this is inefficient when `picks` is given, but assuming
    # that we generally pick most channels for rescaling, it's not so bad
    mean = np.mean(data[..., imin:imax], axis=-1, keepdims=True)

    # Each ``fun`` modifies its array argument in place, so it works both on
    # ``data`` and on the per-pick views taken below.
    if mode == 'mean':
        def fun(d, m):
            d -= m
    elif mode == 'ratio':
        def fun(d, m):
            d /= m
    elif mode == 'logratio':
        def fun(d, m):
            d /= m
            np.log10(d, out=d)
    elif mode == 'percent':
        def fun(d, m):
            d -= m
            d /= m
    elif mode == 'zscore':
        def fun(d, m):
            d -= m
            # std of the (now centered) baseline equals the std of the
            # original baseline values.
            d /= np.std(d[..., imin:imax], axis=-1, keepdims=True)
    elif mode == 'zlogratio':
        def fun(d, m):
            d /= m
            np.log10(d, out=d)
            # std is taken over the log-ratio values within the baseline.
            d /= np.std(d[..., imin:imax], axis=-1, keepdims=True)

    if picks is None:
        fun(data, mean)
    else:
        # data[..., pi, :] is a view, so in-place ops update ``data``.
        for pi in picks:
            fun(data[..., pi, :], mean[..., pi, :])
    return data