forked from mne-tools/mne-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconftest.py
218 lines (186 loc) · 7.9 KB
/
conftest.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
# -*- coding: utf-8 -*-
# Author: Eric Larson <[email protected]>
#
# License: BSD (3-clause)
import os
import os.path as op
import shutil
import sys
import warnings
import pytest
# For some unknown reason, on Travis-xenial there are segfaults caused on
# the line pytest -> pdb.Pdb.__init__ -> "import readline". Forcing an
# import here seems to prevent them (!?). This suggests a potential problem
# with some other library stepping on memory where it shouldn't. It only
# seems to happen on the Linux runs that install Mayavi. Anecdotally,
# @larsoner has had problems a couple of years ago where a mayavi import
# seemed to corrupt SciPy linalg function results (!), likely due to the
# associated VTK import, so this could be another manifestation of that.
try:
import readline # noqa
except Exception:
pass
import numpy as np
import mne
from mne.datasets import testing
from mne.fixes import _fn35
# Locations inside the MNE testing dataset; ``download=False`` is passed so
# that merely importing this conftest does not request a download.
test_path = testing.data_path(download=False)
subjects_dir = op.join(test_path, 'subjects')
s_path = op.join(test_path, 'MEG', 'sample')
# Evoked, covariance and forward-solution files, all living in ``s_path``
fname_evoked, fname_cov, fname_fwd = (
    op.join(s_path, basename) for basename in (
        'sample_audvis_trunc-ave.fif',
        'sample_audvis_trunc-cov.fif',
        'sample_audvis_trunc-meg-eeg-oct-4-fwd.fif',
    )
)
def pytest_configure(config):
    """Configure pytest options.

    Appends the custom markers, the fixtures that should apply to every
    test, and the warning filters to the matching ini options.

    Parameters
    ----------
    config : object
        The pytest configuration object. Only its ``addinivalue_line``
        method is used here.
    """
    # Markers first, then the fixtures used by every test
    ini_values = (
        ('markers', ('slowtest', 'ultraslowtest')),
        ('usefixtures', ('matplotlib_config', 'fix_pytest_tmpdir_35')),
    )
    for ini_name, values in ini_values:
        for value in values:
            config.addinivalue_line(ini_name, value)

    # Warnings
    # - Once SciPy updates not to have non-integer and non-tuple errors (1.2.0)
    #   we should remove them from here.
    # - This list should also be considered alongside reset_warnings in
    #   doc/conf.py.
    # Each line below is stripped before use; blank lines and lines that
    # start with "#" are not registered as filters.
    warning_lines = r"""
error::
ignore::ImportWarning
ignore:the matrix subclass:PendingDeprecationWarning
ignore:numpy.dtype size changed:RuntimeWarning
ignore:.*HasTraits.trait_.*:DeprecationWarning
ignore:.*takes no parameters:DeprecationWarning
ignore:joblib not installed:RuntimeWarning
ignore:Using a non-tuple sequence for multidimensional indexing:FutureWarning
ignore:using a non-integer number instead of an integer will result in an error:DeprecationWarning
ignore:Importing from numpy.testing.decorators is deprecated:DeprecationWarning
ignore:np.loads is deprecated, use pickle.loads instead:DeprecationWarning
ignore:The oldnumeric module will be dropped:DeprecationWarning
ignore:Collection picker None could not be converted to float:UserWarning
ignore:covariance is not positive-semidefinite:RuntimeWarning
ignore:Can only plot ICA components:RuntimeWarning
ignore:Matplotlib is building the font cache using fc-list:UserWarning
ignore:Using or importing the ABCs from 'collections':DeprecationWarning
ignore:`formatargspec` is deprecated:DeprecationWarning
# This is only necessary until sklearn updates their wheels for NumPy 1.16
ignore:numpy.ufunc size changed:RuntimeWarning
ignore:.*mne-realtime.*:DeprecationWarning
ignore:.*imp.*:DeprecationWarning
ignore:Exception creating Regex for oneOf.*:SyntaxWarning
ignore:scipy\.gradient is deprecated.*:DeprecationWarning
always:.*get_data.* is deprecated in favor of.*:DeprecationWarning
"""  # noqa: E501
    for raw_line in warning_lines.split('\n'):
        spec = raw_line.strip()
        if not spec or spec.startswith('#'):
            continue
        config.addinivalue_line('filterwarnings', spec)
@pytest.fixture(scope='session')
def matplotlib_config():
    """Configure matplotlib for viz tests.

    Forces the ``agg`` backend, turns interactive mode off and fixes the
    figure DPI. If ``traits`` and ``mayavi`` can be imported, their
    toolkit and backend options are set as well; import failures for
    these two are ignored.
    """
    import matplotlib
    # "force" should not really be necessary but should not hurt
    with warnings.catch_warnings(record=True):  # ignore warning
        warnings.filterwarnings('ignore')
        matplotlib.use('agg', force=True)  # don't pop up windows
    import matplotlib.pyplot as plt
    assert plt.get_backend() == 'agg'
    # overwrite some params that can horribly slow down tests that
    # users might have changed locally (but should not otherwise affect
    # functionality)
    plt.ioff()
    plt.rcParams['figure.dpi'] = 100

    # traits is optional: only touch its config when the import works
    try:
        from traits.etsconfig.api import ETSConfig
    except Exception:
        ETSConfig = None
    if ETSConfig is not None:
        ETSConfig.toolkit = 'qt4'

    # mayavi is optional too; warnings raised while importing it are recorded
    # by the context manager instead of being emitted
    try:
        with warnings.catch_warnings(record=True):  # traits
            from mayavi import mlab
    except Exception:
        mlab = None
    if mlab is not None:
        mlab.options.backend = 'test'
def _replace(mod, key):
    """Replace ``mod.<key>`` by a wrapper converting its first argument.

    The wrapper passes its first positional argument through ``_fn35``
    and forwards everything else unchanged to the original callable.

    Parameters
    ----------
    mod : object
        The object (here a module) whose attribute is replaced in place.
    key : str
        Name of the attribute to wrap.
    """
    wrapped = getattr(mod, key)

    def func(x, *args, **kwargs):
        converted = _fn35(x)
        return wrapped(converted, *args, **kwargs)

    setattr(mod, key, func)
@pytest.fixture(scope='session')
def fix_pytest_tmpdir_35():
    """Deal with tmpdir being a LocalPath, which bombs on 3.5.

    On Python older than 3.6, the listed ``os`` and ``os.path`` functions
    are wrapped via ``_replace``; on newer versions nothing is done.
    """
    if sys.version_info < (3, 6):
        to_patch = (
            (os, ('stat', 'mkdir', 'makedirs')),
            (op, ('split', 'splitext', 'realpath', 'join')),
        )
        for module, names in to_patch:
            for name in names:
                _replace(module, name)
@pytest.fixture(scope='function', params=[testing._pytest_param()])
def evoked():
    """Get evoked data.

    Reads the 'Left Auditory' condition from ``fname_evoked`` with a
    ``(None, 0)`` baseline and crops it with ``crop(0, 0.2)``.
    """
    read_kwargs = dict(condition='Left Auditory', baseline=(None, 0))
    inst = mne.read_evokeds(fname_evoked, **read_kwargs)
    inst.crop(0, 0.2)
    return inst
@pytest.fixture(scope='function', params=[testing._pytest_param()])
def noise_cov():
    """Get a noise cov from the testing dataset (read from ``fname_cov``)."""
    cov = mne.read_cov(fname_cov)
    return cov
@pytest.fixture(scope='function')
def bias_params_free(evoked, noise_cov):
    """Provide inputs for free bias functions.

    The forward solution is read from ``fname_fwd`` and handed, together
    with the two fixtures, to ``_bias_params``, whose result is returned.
    """
    forward = mne.read_forward_solution(fname_fwd)
    params = _bias_params(evoked, noise_cov, forward)
    return params
@pytest.fixture(scope='function')
def bias_params_fixed(evoked, noise_cov):
    """Provide inputs for fixed bias functions.

    The forward solution is read from ``fname_fwd``, converted with
    ``force_fixed=True, surf_ori=True`` and handed, together with the two
    fixtures, to ``_bias_params``, whose result is returned.
    """
    free_fwd = mne.read_forward_solution(fname_fwd)
    fixed_fwd = mne.convert_forward_solution(
        free_fwd, force_fixed=True, surf_ori=True)
    params = _bias_params(evoked, noise_cov, fixed_fwd)
    return params
def _bias_params(evoked, noise_cov, fwd):
    """Assemble the tuple returned by the bias fixtures.

    Parameters
    ----------
    evoked : object
        Evoked data; its channels are picked in place (MEG + EEG, nothing
        excluded) and its ``info`` is reused for the returned evoked.
    noise_cov : object
        Noise covariance; returned unchanged and copied to build the data
        covariance.
    fwd : object
        Forward solution; restricted to the vertices of its first source
        space before use.

    Returns
    -------
    evoked : object
        ``mne.EvokedArray`` whose data is a copy of the restricted gain
        matrix.
    fwd : object
        The restricted forward solution.
    noise_cov : object
        The input noise covariance.
    data_cov : object
        Copy of ``noise_cov`` whose ``'data'`` is the gain matrix times its
        transpose.
    want : ndarray
        ``np.arange`` over the gain matrix columns, floor-divided by 3
        when the forward solution is not fixed-orientation.
    """
    evoked.pick_types(meg=True, eeg=True, exclude=())
    # restrict to limited set of verts (small src here) and one hemi for speed
    first_src_verts = fwd['src'][0]['vertno'].copy()
    vertices = [first_src_verts, []]
    n_verts = sum(map(len, vertices))
    stc = mne.SourceEstimate(np.zeros((n_verts, 1)), vertices, 0., 1.)
    fwd = mne.forward.restrict_forward_to_stc(fwd, stc)

    # Channel order must agree between forward, covariance and evoked
    cov_names = noise_cov['names']
    assert fwd['sol']['row_names'] == cov_names
    assert cov_names == evoked.ch_names

    gain = fwd['sol']['data']
    evoked = mne.EvokedArray(gain.copy(), evoked.info)
    data_cov = noise_cov.copy()
    data_cov['data'] = np.dot(gain, gain.T)
    assert data_cov['data'].shape[0] == len(noise_cov['names'])

    want = np.arange(gain.shape[1])
    if not mne.forward.is_fixed_orient(fwd):
        want //= 3
    return evoked, fwd, noise_cov, data_cov, want
@pytest.fixture(scope="module", params=["mayavi", "pyvista"])
def backend_name(request):
    """Get the backend name (one parametrization per listed backend)."""
    name = request.param
    yield name
# NOTE: plain ``pytest.fixture`` is used instead of the deprecated
# ``pytest.yield_fixture`` alias; regular fixtures support ``yield``, and the
# alias emits a deprecation warning on newer pytest, which the ``error::``
# warning filter registered in ``pytest_configure`` would turn into a failure.
@pytest.fixture
def renderer(backend_name):
    """Yield the 3D backends.

    The test is skipped when the requested backend is not available. The
    renderer module is yielded while the test 3D backend context is
    active, and ``_close_all`` is called on it once the test is done.

    Parameters
    ----------
    backend_name : str
        Name of the backend, as provided by the ``backend_name`` fixture.

    Yields
    ------
    renderer : module
        The ``mne.viz.backends.renderer`` module.
    """
    from mne.viz.backends.renderer import _use_test_3d_backend
    from mne.viz.backends.tests._utils import has_mayavi, has_pyvista
    if backend_name == 'mayavi':
        if not has_mayavi():
            pytest.skip("Test skipped, requires mayavi.")
    elif backend_name == 'pyvista':
        if not has_pyvista():
            pytest.skip("Test skipped, requires pyvista.")
    with _use_test_3d_backend(backend_name):
        # Imported only after the availability checks above have passed
        from mne.viz.backends import renderer
        yield renderer
        # Teardown: runs after the test, still inside the backend context
        renderer._close_all()
@pytest.fixture(scope='function', params=[testing._pytest_param()])
def subjects_dir_tmp(tmpdir):
    """Copy MNE-testing-data subjects_dir to a temp dir for manipulation.

    Only the 'sample' and 'fsaverage' subject folders are copied. The
    temporary directory is returned as a string.
    """
    for subject in ('sample', 'fsaverage'):
        source = op.join(subjects_dir, subject)
        destination = str(tmpdir.join(subject))
        shutil.copytree(source, destination)
    return str(tmpdir)