Skip to content

Commit

Permalink
add sample weight test LinearModel (mne-tools#4800)
Browse files Browse the repository at this point in the history
  • Loading branch information
kingjr authored and agramfort committed Dec 1, 2017
1 parent fb0a8eb commit cb3f93d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
6 changes: 4 additions & 2 deletions mne/decoding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self, model=None): # noqa: D102
self.model = model
self._estimator_type = getattr(model, "_estimator_type", None)

def fit(self, X, y):
def fit(self, X, y, **fit_params):
"""Estimate the coefficients of the linear model.
Save the coefficients in the attribute ``filters_`` and
Expand All @@ -74,6 +74,8 @@ def fit(self, X, y):
The training input samples to estimate the linear coefficients.
y : array, shape (n_samples, [n_targets])
The target values.
**fit_params : dict of string -> object
Parameters to pass to the fit method of the estimator.
Returns
-------
Expand All @@ -89,7 +91,7 @@ def fit(self, X, y):
'got %s instead.' % (y.shape,))

# fit the Model
self.model.fit(X, y)
self.model.fit(X, y, **fit_params)

# Computes patterns using Haufe's trick: A = Cov_X . W . Precision_Y

Expand Down
3 changes: 3 additions & 0 deletions mne/decoding/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ def inverse_transform(self, X):
lm = LinearModel(Ridge(alpha=1)).fit(X, Y)
assert_array_almost_equal(A, lm.patterns_.T, decimal=2)

# Check can pass fitting parameters
lm.fit(X, Y, sample_weight=np.ones(len(Y)))


@requires_version('sklearn', '0.15')
def test_linearmodel():
Expand Down

0 comments on commit cb3f93d

Please sign in to comment.