Skip to content

Commit

Permalink
MINOR, WIP remove warning for custom solvers in rERP (mne-tools#5437)
Browse files Browse the repository at this point in the history
* init

* specify test
  • Loading branch information
jona-sassenhagen authored Aug 12, 2018
1 parent f21b843 commit 986e8a3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 14 deletions.
6 changes: 1 addition & 5 deletions mne/stats/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,7 @@ def solver(X, y):
return linalg.solve(a, X.T * y, sym_pos=True,
overwrite_a=True, overwrite_b=True).T
elif callable(solver):
warn("When using a custom solver, note that since MNE 0.15, this "
"function will pass the transposed data (n_channels, n_times) "
"to the solver. If you are using a solver that expects a "
"different format, it will give wrong results and might in "
"extreme cases crash your session.")
pass
else:
raise TypeError("The solver must be a str or a callable.")

Expand Down
15 changes: 6 additions & 9 deletions mne/stats/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,18 +140,15 @@ def test_continuous_regression_with_overlap():
from sklearn.linear_model.ridge import ridge_regression

def solver(X, y):
return ridge_regression(X, y, alpha=0.)

with pytest.warns(RuntimeWarning, match='transposed'):
assert_allclose(effect, linear_regression_raw(
raw, events, tmin=0, solver=solver)['1'].data.flatten())
return ridge_regression(X, y, alpha=0., solver="cholesky")
assert_allclose(effect, linear_regression_raw(
raw, events, tmin=0, solver=solver)['1'].data.flatten())

# test bad solvers
def solT(X, y):
return ridge_regression(X, y, alpha=0.).T
with pytest.warns(RuntimeWarning, match='transposed'):
pytest.raises(ValueError, linear_regression_raw, raw, events,
solver=solT)
return ridge_regression(X, y, alpha=0., solver="cholesky").T
pytest.raises(ValueError, linear_regression_raw, raw, events,
solver=solT)
pytest.raises(ValueError, linear_regression_raw, raw, events, solver='err')
pytest.raises(TypeError, linear_regression_raw, raw, events, solver=0)

Expand Down

0 comments on commit 986e8a3

Please sign in to comment.