Skip to content

Commit

Permalink
FIX OrthogonalMatchingPursuit normalized twice
Browse files Browse the repository at this point in the history
  • Loading branch information
vene committed Jul 26, 2013
1 parent a4e53a6 commit 86d7afb
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
7 changes: 0 additions & 7 deletions sklearn/linear_model/omp.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,13 +679,6 @@ def fit(self, X, y, Gram=None, Xy=None):
self.tol, norms_sq,
copy_Gram, True).T

if self.normalize:
nonzeros = np.flatnonzero(X_std)
scaling = X_std[nonzeros]
if self.coef_.ndim == 2:
scaling = scaling[np.newaxis, :]
self.coef_[:, nonzeros] /= scaling

self._set_intercept(X_mean, y_mean, X_std)
return self

Expand Down
19 changes: 17 additions & 2 deletions sklearn/linear_model/tests/test_omp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@

from sklearn.linear_model import (orthogonal_mp, orthogonal_mp_gram,
OrthogonalMatchingPursuit,
OrthogonalMatchingPursuitCV)
OrthogonalMatchingPursuitCV,
LinearRegression)
from sklearn.utils.fixes import count_nonzero
from sklearn.utils import check_random_state
from sklearn.datasets import make_sparse_coded_signal

# Module-level problem dimensions: number of samples, number of features,
# number of non-zero coefficients and number of targets.
n_samples, n_features, n_nonzero_coefs, n_targets = 20, 30, 5, 3
Expand Down Expand Up @@ -93,7 +95,6 @@ def test_bad_input():


def test_perfect_signal_recovery():
# XXX: use signal generator
idx, = gamma[:, 0].nonzero()
gamma_rec = orthogonal_mp(X, y[:, 0], 5)
gamma_gram = orthogonal_mp_gram(G, Xy[:, 0], 5)
Expand Down Expand Up @@ -218,3 +219,17 @@ def test_omp_cv():
n_nonzero_coefs=ompcv.n_nonzero_coefs_)
omp.fit(X, y_)
assert_array_almost_equal(ompcv.coef_, omp.coef_)


def test_omp_reaches_least_squares():
    """Check that OMP with every feature allowed matches least squares.

    An ``OrthogonalMatchingPursuit`` whose ``n_nonzero_coefs`` equals the
    number of features and a ``LinearRegression`` are fitted on the same
    random multi-target data; their ``coef_`` arrays must be almost equal.
    """
    # Keep the data small and simple: this is only a sanity check, and OMP
    # can stop early.
    random_state = check_random_state(0)
    rows, cols, outputs = 10, 8, 3

    # Draw X before Y so the random stream is consumed in a fixed order.
    X = random_state.randn(rows, cols)
    Y = random_state.randn(rows, outputs)

    greedy = OrthogonalMatchingPursuit(n_nonzero_coefs=cols)
    reference = LinearRegression()
    for estimator in (greedy, reference):
        estimator.fit(X, Y)

    assert_array_almost_equal(greedy.coef_, reference.coef_)

0 comments on commit 86d7afb

Please sign in to comment.