Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexisMignon committed Feb 16, 2016
2 parents c8714f5 + 29c7cfc commit 5c29eea
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 10 deletions.
3 changes: 3 additions & 0 deletions python-package/xgboost/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,11 @@ class DataFrame(object):
from sklearn.base import BaseEstimator
from sklearn.base import RegressorMixin, ClassifierMixin
from sklearn.preprocessing import LabelEncoder
from sklearn.cross_validation import KFold, StratifiedKFold
SKLEARN_INSTALLED = True

XGBKFold = KFold
XGBStratifiedKFold = StratifiedKFold
XGBModelBase = BaseEstimator
XGBRegressorBase = RegressorMixin
XGBClassifierBase = ClassifierMixin
Expand Down
40 changes: 30 additions & 10 deletions python-package/xgboost/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import re
import numpy as np
from .core import Booster, STRING_TYPES
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold, XGBKFold)

def train(params, dtrain, num_boost_round=10, evals=(), obj=None, feval=None,
maximize=False, early_stopping_rounds=None, evals_result=None,
Expand Down Expand Up @@ -261,15 +262,26 @@ def eval(self, iteration, feval):
return self.bst.eval_set(self.watchlist, iteration, feval)


def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None):
def mknfold(dall, nfold, param, seed, evals=(), fpreproc=None, stratified=False, folds=None):
"""
Make an n-fold list of CVPack from random indices.
"""
evals = list(evals)
np.random.seed(seed)
randidx = np.random.permutation(dall.num_row())
kstep = len(randidx) / nfold
idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range(nfold)]

if stratified is False and folds is None:
randidx = np.random.permutation(dall.num_row())
kstep = len(randidx) / nfold
idset = [randidx[(i * kstep): min(len(randidx), (i + 1) * kstep)] for i in range(nfold)]
elif folds is not None:
idset = [x[1] for x in folds]
nfold = len(idset)
else:
idset = [x[1] for x in XGBStratifiedKFold(dall.get_label(),
n_folds=nfold,
shuffle=True,
random_state=seed)]

ret = []
for k in range(nfold):
dtrain = dall.slice(np.concatenate([idset[i] for i in range(nfold) if k != i]))
Expand Down Expand Up @@ -345,8 +357,8 @@ def aggcv(rlist, show_stdv=True, show_progress=None, as_pandas=True, trial=0):
return results


def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
obj=None, feval=None, maximize=False, early_stopping_rounds=None,
def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None,
metrics=(), obj=None, feval=None, maximize=False, early_stopping_rounds=None,
fpreproc=None, as_pandas=True, show_progress=None, show_stdv=True, seed=0):
# pylint: disable = invalid-name
"""Cross-validation with given paramaters.
Expand All @@ -361,6 +373,10 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
Number of boosting iterations.
nfold : int
Number of folds in CV.
stratified : bool
Perform stratified sampling.
folds : KFold or StratifiedKFold
Sklearn KFolds or StratifiedKFolds.
metrics : string or list of strings
Evaluation metrics to be watched in CV.
obj : function
Expand All @@ -381,9 +397,9 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
If False or pandas is not installed, return np.ndarray
show_progress : bool, int, or None, default None
Whether to display the progress. If None, progress will be displayed
when np.ndarray is returned. If True, progress will be displayed at
boosting stage. If an integer is given, progress will be displayed
at every given `show_progress` boosting stage.
when np.ndarray is returned. If True, progress will be displayed at
boosting stage. If an integer is given, progress will be displayed
at every given `show_progress` boosting stage.
show_stdv : bool, default True
Whether to display the standard deviation in progress.
Results are not affected, and always contains std.
Expand All @@ -394,6 +410,9 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
-------
evaluation history : list(string)
"""
if stratified == True and not SKLEARN_INSTALLED:
raise XGBoostError('sklearn needs to be installed in order to use stratified cv')

if isinstance(metrics, str):
metrics = [metrics]

Expand Down Expand Up @@ -436,7 +455,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),

best_score_i = 0
results = []
cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc)
cvfolds = mknfold(dtrain, nfold, params, seed, metrics, fpreproc, stratified, folds)
for i in range(num_boost_round):
for fold in cvfolds:
fold.update(i, obj)
Expand Down Expand Up @@ -466,3 +485,4 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, metrics=(),
results = np.array(results)

return results

37 changes: 37 additions & 0 deletions tests/python/test_cv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import xgboost as xgb
import numpy as np
from sklearn.datasets import load_digits
from sklearn.cross_validation import KFold, StratifiedKFold, train_test_split
from sklearn.metrics import mean_squared_error
import unittest

# Module-level NumPy random state with a fixed seed (1994).
# NOTE(review): nothing in the visible part of this test module reads `rng`;
# confirm whether it is still needed.
rng = np.random.RandomState(1994)


class TestCrossValidation(unittest.TestCase):
    """Exercise ``xgb.cv`` with its three ways of choosing folds."""

    def test_cv(self):
        """Compare random, user-supplied and built-in stratified folds.

        Three runs are made on the 3-class digits data set:

        * ``cv1`` -- ``nfold`` only (plain random folds),
        * ``cv2`` -- an explicit sklearn ``StratifiedKFold`` passed as ``folds``,
        * ``cv3`` -- ``stratified=True`` with the same ``nfold`` and ``seed``.

        All three must report the same number of rows, and ``cv2``/``cv3``
        must agree on the first column of their last row, since both are
        expected to be built from the same shuffled stratified split.
        """
        # Keyword form keeps the call unambiguous and does not rely on the
        # positional order of load_digits' parameters.
        digits = load_digits(n_class=3)
        X = digits['data']
        y = digits['target']
        dm = xgb.DMatrix(X, label=y)

        params = {
            'max_depth': 2,
            'eta': 1,
            'silent': 1,
            'objective': 'multi:softprob',
            'num_class': 3,
        }

        seed = 2016
        nfolds = 5
        skf = StratifiedKFold(y, n_folds=nfolds, shuffle=True, random_state=seed)

        # xgb.cv only returns a DataFrame when pandas is installed, and the
        # `.iloc` checks below need one; importing here fails fast with a
        # clear ImportError instead of an AttributeError on an ndarray.
        import pandas  # noqa: F401  pylint: disable=unused-variable

        cv1 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, seed=seed)
        cv2 = xgb.cv(params, dm, num_boost_round=10, folds=skf, seed=seed)
        cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds,
                     stratified=True, seed=seed)

        # unittest assertions survive `python -O` and print both operands
        # when they fail, unlike bare `assert` statements.
        self.assertEqual(cv1.shape[0], cv2.shape[0])
        self.assertEqual(cv2.shape[0], cv3.shape[0])
        self.assertEqual(cv2.iloc[-1, 0], cv3.iloc[-1, 0])

0 comments on commit 5c29eea

Please sign in to comment.