Skip to content

Commit

Permalink
TST Improve SelectFromModel tests (scikit-learn#9733)
Browse files Browse the repository at this point in the history
Should fix one of the issues in scikit-learn#9393
  • Loading branch information
jnothman authored and amueller committed Oct 21, 2017
1 parent 0e1d261 commit b661a9c
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions sklearn/feature_selection/tests/test_from_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def test_input_estimator_unchanged():
assert_true(transformer.estimator is est)


@skip_if_32bit
def test_feature_importances():
X, y = datasets.make_classification(
n_samples=1000, n_features=10, n_informative=3, n_redundant=0,
Expand All @@ -59,17 +58,33 @@ def test_feature_importances():
feature_mask = np.abs(importances) > func(importances)
assert_array_almost_equal(X_new, X[:, feature_mask])


def test_sample_weight():
# Ensure sample weights are passed to underlying estimator
X, y = datasets.make_classification(
n_samples=100, n_features=10, n_informative=3, n_redundant=0,
n_repeated=0, shuffle=False, random_state=0)

# Check with sample weights
sample_weight = np.ones(y.shape)
sample_weight[y == 1] *= 100

est = RandomForestClassifier(n_estimators=50, random_state=0)
est = LogisticRegression(random_state=0, fit_intercept=False)
transformer = SelectFromModel(estimator=est)
transformer.fit(X, y, sample_weight=None)
mask = transformer._get_support_mask()
transformer.fit(X, y, sample_weight=sample_weight)
importances = transformer.estimator_.feature_importances_
weighted_mask = transformer._get_support_mask()
assert not np.all(weighted_mask == mask)
transformer.fit(X, y, sample_weight=3 * sample_weight)
importances_bis = transformer.estimator_.feature_importances_
assert_almost_equal(importances, importances_bis)
reweighted_mask = transformer._get_support_mask()
assert np.all(weighted_mask == reweighted_mask)


def test_coef_default_threshold():
X, y = datasets.make_classification(
n_samples=100, n_features=10, n_informative=3, n_redundant=0,
n_repeated=0, shuffle=False, random_state=0)

# For the Lasso and related models, the threshold defaults to 1e-5
transformer = SelectFromModel(estimator=Lasso(alpha=0.1))
Expand All @@ -80,7 +95,7 @@ def test_feature_importances():


@skip_if_32bit
def test_feature_importances_2d_coef():
def test_2d_coef():
X, y = datasets.make_classification(
n_samples=1000, n_features=10, n_informative=3, n_redundant=0,
n_repeated=0, shuffle=False, random_state=0, n_classes=4)
Expand Down

0 comments on commit b661a9c

Please sign in to comment.