Skip to content

Commit

Permalink
Merge pull request scikit-learn#926 from agramfort/fix_X_list_grid_se…
Browse files Browse the repository at this point in the history
…arch

FIX: fix grid search when X is list scikit-learn#925
  • Loading branch information
ogrisel committed Jul 3, 2012
2 parents 071cdb0 + 6e9c271 commit 5ec4a0a
Showing 2 changed files with 22 additions and 2 deletions.
10 changes: 8 additions & 2 deletions sklearn/grid_search.py
Original file line number Diff line number Diff line change
@@ -85,8 +85,14 @@ def fit_grid_point(X, y, base_clf, clf_params, train, test, loss_func,
clf.set_params(**clf_params)

if isinstance(X, list) or isinstance(X, tuple):
X_train = [X[i] for i, cond in enumerate(train) if cond]
X_test = [X[i] for i, cond in enumerate(test) if cond]
# train and test can be boolean mask but for list
# they should be indices so conversion is done if needed.
if isinstance(train, np.ndarray) and train.dtype == np.bool:
train = np.where(train)[0]
if isinstance(test, np.ndarray) and test.dtype == np.bool:
test = np.where(test)[0]
X_train = [X[i] for i in train]
X_test = [X[i] for i in test]
else:
if sp.issparse(X):
# For sparse matrices, slicing only works with indices
14 changes: 14 additions & 0 deletions sklearn/tests/test_grid_search.py
Original file line number Diff line number Diff line change
@@ -13,6 +13,7 @@
from sklearn.datasets.samples_generator import make_classification
from sklearn.svm import LinearSVC, SVC
from sklearn.metrics import f1_score, precision_score
from sklearn.cross_validation import KFold


class MockClassifier(BaseEstimator):
@@ -21,6 +22,7 @@ def __init__(self, foo_param=0):
self.foo_param = foo_param

def fit(self, X, Y):
assert_true(len(X) == len(Y))
return self

def predict(self, T):
@@ -211,3 +213,15 @@ def test_refit():
clf = GridSearchCV(BrokenClassifier(), [{'parameter': [0, 1]}],
score_func=precision_score, refit=True)
clf.fit(X, y)


def test_X_as_list():
"""Pass X as list in GridSearchCV
"""
X = np.arange(100).reshape(10, 10)
y = np.array([0] * 5 + [1] * 5)

clf = MockClassifier()
cv = KFold(n=len(X), k=3)
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, cv=cv)
grid_search.fit(X.tolist(), y).score(X, y)

0 comments on commit 5ec4a0a

Please sign in to comment.