Skip to content

Commit

Permalink
Merge pull request scikit-learn#3090 from ogrisel/learning-curves-war…
Browse files Browse the repository at this point in the history
…nings

[MRG] FIX: remove deprecation warnings in learning curves under Python 3
  • Loading branch information
ogrisel committed Apr 19, 2014
2 parents ea91673 + 6c0d41a commit a03abd6
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
4 changes: 2 additions & 2 deletions sklearn/learning_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def learning_curve(estimator, X, y, train_sizes=np.linspace(0.1, 1.0, 5),
verbose, parameters=None, fit_params=None, return_train_score=True)
for train, test in cv for n_train_samples in train_sizes_abs)
out = np.array(out)[:, :2]
n_cv_folds = out.shape[0] / n_unique_ticks
n_cv_folds = out.shape[0] // n_unique_ticks
out = out.reshape(n_cv_folds, n_unique_ticks, 2)

out = np.asarray(out).transpose((2, 1, 0))
Expand Down Expand Up @@ -297,7 +297,7 @@ def validation_curve(estimator, X, y, param_name, param_range, cv=None,

out = np.asarray(out)[:, :2]
n_params = len(param_range)
n_cv_folds = out.shape[0] / n_params
n_cv_folds = out.shape[0] // n_params
out = out.reshape(n_cv_folds, n_params, 2).transpose((2, 1, 0))

return out[0], out[1]
18 changes: 13 additions & 5 deletions sklearn/tests/test_learning_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import sys
from sklearn.externals.six.moves import cStringIO as StringIO
import numpy as np
import warnings
from sklearn.base import BaseEstimator
from sklearn.learning_curve import learning_curve, validation_curve
from sklearn.utils.testing import assert_raises
Expand Down Expand Up @@ -84,8 +85,11 @@ def test_learning_curve():
n_redundant=0, n_classes=2,
n_clusters_per_class=1, random_state=0)
estimator = MockImprovingEstimator(20)
train_sizes, train_scores, test_scores = learning_curve(
estimator, X, y, cv=3, train_sizes=np.linspace(0.1, 1.0, 10))
with warnings.catch_warnings(record=True) as w:
train_sizes, train_scores, test_scores = learning_curve(
estimator, X, y, cv=3, train_sizes=np.linspace(0.1, 1.0, 10))
if len(w) > 0:
raise RuntimeError("Unexpected warning: %r" % w[0].message)
assert_equal(train_scores.shape, (10, 3))
assert_equal(test_scores.shape, (10, 3))
assert_array_equal(train_sizes, np.linspace(2, 20, 10))
Expand Down Expand Up @@ -239,8 +243,12 @@ def test_validation_curve():
n_redundant=0, n_classes=2,
n_clusters_per_class=1, random_state=0)
param_range = np.linspace(0, 1, 10)
train_scores, test_scores = validation_curve(MockEstimatorWithParameter(),
X, y, param_name="param",
param_range=param_range, cv=2)
with warnings.catch_warnings(record=True) as w:
train_scores, test_scores = validation_curve(
MockEstimatorWithParameter(), X, y, param_name="param",
param_range=param_range, cv=2)
if len(w) > 0:
raise RuntimeError("Unexpected warning: %r" % w[0].message)

assert_array_almost_equal(train_scores.mean(axis=1), param_range)
assert_array_almost_equal(test_scores.mean(axis=1), 1 - param_range)

0 comments on commit a03abd6

Please sign in to comment.