Skip to content

Commit

Permalink
Fix mknfold using new StratifiedKFold API (dmlc#1660)
Browse files: browse the repository at this point in the history
  • Loading branch information
terrytangyuan authored and tqchen committed Oct 12, 2016
1 parent b56c609 commit 63829d6
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 10 deletions.
4 changes: 2 additions & 2 deletions python-package/xgboost/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ class DataFrame(object):
try:
from sklearn.base import BaseEstimator
from sklearn.base import RegressorMixin, ClassifierMixin
- from sklearn.preprocessing import LabelEncoder # noqa
+ from sklearn.preprocessing import LabelEncoder
try:
from sklearn.model_selection import KFold, StratifiedKFold
except ImportError:
from sklearn.cross_validation import KFold, StratifiedKFold

SKLEARN_INSTALLED = True

XGBModelBase = BaseEstimator
Expand Down
8 changes: 3 additions & 5 deletions python-package/xgboost/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,14 +232,12 @@ def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False,
randidx = np.random.permutation(dall.num_row())
kstep = int(len(randidx) / nfold)
idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range(nfold)]
- elif folds is not None:
+ elif folds is not None and isinstance(folds, list):
idset = [x[1] for x in folds]
nfold = len(idset)
else:
- idset = [x[1] for x in XGBStratifiedKFold(dall.get_label(),
- n_folds=nfold,
- shuffle=True,
- random_state=seed)]
+ sfk = XGBStratifiedKFold(n_splits=nfold, shuffle=True, random_state=seed)
+ idset = [x[1] for x in sfk.split(X=dall.get_label(), y=dall.get_label())]

ret = []
for k in range(nfold):
Expand Down
6 changes: 3 additions & 3 deletions tests/python/test_with_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def test_sklearn_plotting():
def test_sklearn_nfolds_cv():
tm._skip_if_no_sklearn()
from sklearn.datasets import load_digits
- from sklearn.cross_validation import StratifiedKFold
+ from sklearn.model_selection import StratifiedKFold

digits = load_digits(3)
X = digits['data']
Expand All @@ -269,10 +269,10 @@ def test_sklearn_nfolds_cv():

seed = 2016
nfolds = 5
- skf = StratifiedKFold(y, n_folds=nfolds, shuffle=True, random_state=seed)
+ skf = StratifiedKFold(n_splits=nfolds, shuffle=True, random_state=seed)

cv1 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, seed=seed)
- cv2 = xgb.cv(params, dm, num_boost_round=10, folds=skf, seed=seed)
+ cv2 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, folds=skf, seed=seed)
cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, stratified=True, seed=seed)
assert cv1.shape[0] == cv2.shape[0] and cv2.shape[0] == cv3.shape[0]
assert cv2.iloc[-1, 0] == cv3.iloc[-1, 0]
Expand Down

0 comments on commit 63829d6

Please sign in to comment.