From c8714f587a7f2627082b4898e6d079d5a4cbac94 Mon Sep 17 00:00:00 2001
From: Alexis Mignon
Date: Mon, 15 Feb 2016 17:13:13 +0100
Subject: [PATCH 1/3] Added the possibility to use a custom objective function
 in the sklearn API
---
 python-package/xgboost/sklearn.py | 27 ++++++++++++++++++++-------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 763551abf775..53c787f5aa26 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -25,7 +25,7 @@ class XGBModel(XGBModelBase):
         Number of boosted trees to fit.
     silent : boolean
         Whether to print messages while running boosting.
-    objective : string
+    objective : string or callable
        Specify the learning task and the corresponding learning objective.

    nthread : int
@@ -174,6 +174,12 @@ def fit(self, X, y, eval_set=None, eval_metric=None,

        params = self.get_xgb_params()

+        if callable(self.objective):
+            obj = self.objective
+            params["objective"] = "reg:linear"
+        else:
+            obj = None
+
        feval = eval_metric if callable(eval_metric) else None
        if eval_metric is not None:
            if callable(eval_metric):
@@ -184,7 +190,7 @@ def fit(self, X, y, eval_set=None, eval_metric=None,
        self._Booster = train(params, trainDmatrix,
                              self.n_estimators, evals=evals,
                              early_stopping_rounds=early_stopping_rounds,
-                              evals_result=evals_result, feval=feval,
+                              evals_result=evals_result, obj=obj, feval=feval,
                              verbose_eval=verbose)

        if evals_result:
@@ -302,13 +308,20 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
        evals_result = {}
        self.classes_ = list(np.unique(y))
        self.n_classes_ = len(self.classes_)
+
+
+        xgb_options = self.get_xgb_params()
+
+        if callable(self.objective):
+            obj = self.objective
+            xgb_options["objective"] = "binary:logistic"
+        else:
+            obj = None
+

        if self.n_classes_ > 2:
            # Switch to using a multiclass objective in the underlying XGB instance
-            self.objective = "multi:softprob"
-            xgb_options = self.get_xgb_params()
+            xgb_options["objective"] = "multi:softprob"
            xgb_options['num_class'] = self.n_classes_
-        else:
-            xgb_options = self.get_xgb_params()

        feval = eval_metric if callable(eval_metric) else None
@@ -339,7 +352,7 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
        self._Booster = train(xgb_options, train_dmatrix, self.n_estimators,
                              evals=evals,
                              early_stopping_rounds=early_stopping_rounds,
-                              evals_result=evals_result, feval=feval,
+                              evals_result=evals_result, obj=obj, feval=feval,
                              verbose_eval=verbose)

        if evals_result:
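
For context between the two patches: after PATCH 1/3 the callable stored in ``objective`` is handed to ``xgboost.training.train`` unchanged (``obj = self.objective``), so at this intermediate stage it must follow the native ``obj(preds, dtrain) -> (grad, hess)`` convention; PATCH 2/3 below switches the expected signature to the sklearn-style ``(y_true, y_pred)``. The snippet is an illustrative sketch only, not part of the patch; ``raw_squared_error`` is a hypothetical user function.

    # Illustrative sketch, assuming the state of the tree after PATCH 1/3 only.
    # The callable is forwarded as-is to xgboost.train(), which calls it as
    # obj(preds, dtrain) and expects the per-sample gradient and hessian.
    import numpy as np
    import xgboost as xgb
    from sklearn.datasets import load_boston

    def raw_squared_error(preds, dtrain):
        labels = dtrain.get_label()            # targets are read from the DMatrix
        return preds - labels, np.ones(len(labels))

    boston = load_boston()
    xgb.XGBRegressor(objective=raw_squared_error).fit(boston['data'], boston['target'])
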
From 07bd149b68474be357be3f751dee74e9ad08c58c Mon Sep 17 00:00:00 2001
From: Alexis Mignon
Date: Tue, 16 Feb 2016 10:58:22 +0100
Subject: [PATCH 2/3] Created a decorator function so that custom objective
 functions passed to the constructor are more consistent with the sklearn
 conventions. Added comments in the docstring.
---
 python-package/xgboost/sklearn.py | 58 ++++++++++++++++++++++++++++---
 1 file changed, 54 insertions(+), 4 deletions(-)

diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 53c787f5aa26..1598f63c5bc9 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -11,6 +11,39 @@
                      XGBClassifierBase, XGBRegressorBase, LabelEncoder)

+def _objective_decorator(func):
+    """Decorate an objective function
+
+    Converts an objective function using the typical sklearn metrics
+    signature so that it is usable with ``xgboost.training.train``
+
+    Parameters
+    ----------
+    func: callable
+        Expects a callable with signature ``func(y_true, y_pred)``:
+
+        y_true: array_like of shape [n_samples]
+            The target values
+        y_pred: array_like of shape [n_samples]
+            The predicted values
+
+    Returns
+    -------
+    new_func: callable
+        The new objective function as expected by ``xgboost.training.train``.
+        The signature is ``new_func(preds, dmatrix)``:
+
+        preds: array_like, shape [n_samples]
+            The predicted values
+        dmatrix: ``DMatrix``
+            The training set from which the labels will be extracted using
+            ``dmatrix.get_label()``
+    """
+    def inner(preds, dmatrix):
+        labels = dmatrix.get_label()
+        return func(labels, preds)
+    return inner
+

 class XGBModel(XGBModelBase):
     # pylint: disable=too-many-arguments, too-many-instance-attributes, invalid-name
     """Implementation of the Scikit-Learn API for XGBoost.
@@ -26,8 +59,8 @@ class XGBModel(XGBModelBase):
     silent : boolean
         Whether to print messages while running boosting.
     objective : string or callable
-        Specify the learning task and the corresponding learning objective.
-
+        Specify the learning task and the corresponding learning objective or
+        a custom objective function to be used (see note below).
     nthread : int
         Number of parallel threads used to run xgboost.
     gamma : float
@@ -56,6 +89,22 @@ class XGBModel(XGBModelBase):
     missing : float, optional
         Value in the data which needs to be present as a missing value. If
         None, defaults to np.nan.
+
+    Note
+    ----
+    A custom objective function can be provided for the ``objective``
+    parameter. In this case, it should have the signature
+    ``objective(y_true, y_pred) -> grad, hess``:
+
+    y_true: array_like of shape [n_samples]
+        The target values
+    y_pred: array_like of shape [n_samples]
+        The predicted values
+
+    grad: array_like of shape [n_samples]
+        The value of the gradient for each sample point.
+    hess: array_like of shape [n_samples]
+        The value of the second derivative for each sample point.
     """
     def __init__(self, max_depth=3, learning_rate=0.1, n_estimators=100,
                  silent=True, objective="reg:linear",
@@ -175,7 +224,7 @@ def fit(self, X, y, eval_set=None, eval_metric=None,
        params = self.get_xgb_params()

        if callable(self.objective):
-            obj = self.objective
+            obj = _objective_decorator(self.objective)
            params["objective"] = "reg:linear"
        else:
            obj = None
@@ -313,7 +362,8 @@ def fit(self, X, y, sample_weight=None, eval_set=None, eval_metric=None,
        xgb_options = self.get_xgb_params()

        if callable(self.objective):
-            obj = self.objective
+            obj = _objective_decorator(self.objective)
+            # Use default value. Is it really not used?
            xgb_options["objective"] = "binary:logistic"
        else:
            obj = None
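
With PATCH 2/3 applied, ``_objective_decorator`` is the whole bridge between the two calling conventions: the user supplies ``func(y_true, y_pred)`` and the wrapper hands ``xgboost.training.train`` a closure with the ``(preds, dmatrix)`` signature, reading the labels out of the ``DMatrix``. The sketch below is illustrative only and not part of the patch; ``squared_error_objective`` and ``wrapped`` are hypothetical names showing roughly what the decorator arranges.

    # Sketch of the conversion performed by _objective_decorator.
    import numpy as np
    import xgboost as xgb
    from sklearn.datasets import load_boston

    def squared_error_objective(y_true, y_pred):
        # gradient and (constant) hessian of 0.5 * (y_pred - y_true)**2
        return y_pred - y_true, np.ones(len(y_true))

    boston = load_boston()
    X, y = boston['data'], boston['target']

    # High-level path added by this series: sklearn-style (y_true, y_pred) objective.
    xgb.XGBRegressor(objective=squared_error_objective).fit(X, y)

    # Roughly the equivalent low-level call, wrapping the function by hand.
    def wrapped(preds, dmatrix):
        return squared_error_objective(dmatrix.get_label(), preds)

    xgb.train({'objective': 'reg:linear'}, xgb.DMatrix(X, label=y),
              num_boost_round=100, obj=wrapped)
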
xgb_options["objective"] = "binary:logistic" else: obj = None From 6e27d7539f57bfe9f26054221208a56516d71771 Mon Sep 17 00:00:00 2001 From: Alexis Mignon Date: Tue, 16 Feb 2016 10:59:25 +0100 Subject: [PATCH 3/3] - Added test cases for the use of custom objective functions - Made the indentation more consistent with pep8 --- tests/python/test_with_sklearn.py | 164 +++++++++++++++++++++--------- 1 file changed, 116 insertions(+), 48 deletions(-) diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 3e31ddb65c7d..5cfe40891f05 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -1,6 +1,6 @@ import xgboost as xgb import numpy as np -from sklearn.cross_validation import KFold, train_test_split +from sklearn.cross_validation import KFold from sklearn.metrics import mean_squared_error from sklearn.grid_search import GridSearchCV from sklearn.datasets import load_iris, load_digits, load_boston @@ -8,57 +8,125 @@ rng = np.random.RandomState(1994) def test_binary_classification(): - digits = load_digits(2) - y = digits['target'] - X = digits['data'] - kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng) - for train_index, test_index in kf: - xgb_model = xgb.XGBClassifier().fit(X[train_index],y[train_index]) - preds = xgb_model.predict(X[test_index]) - labels = y[test_index] - err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds)) - assert err < 0.1 + digits = load_digits(2) + y = digits['target'] + X = digits['data'] + kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng) + for train_index, test_index in kf: + xgb_model = xgb.XGBClassifier().fit(X[train_index],y[train_index]) + preds = xgb_model.predict(X[test_index]) + labels = y[test_index] + err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds)) + assert err < 0.1 def test_multiclass_classification(): - iris = load_iris() - y = iris['target'] - X = iris['data'] - kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng) - for train_index, test_index in kf: - xgb_model = xgb.XGBClassifier().fit(X[train_index],y[train_index]) - preds = xgb_model.predict(X[test_index]) - # test other params in XGBClassifier().fit - preds2 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=3) - preds3 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=0) - preds4 = xgb_model.predict(X[test_index], output_margin=False, ntree_limit=3) - labels = y[test_index] - err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds)) - assert err < 0.4 + iris = load_iris() + y = iris['target'] + X = iris['data'] + kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng) + for train_index, test_index in kf: + xgb_model = xgb.XGBClassifier().fit(X[train_index],y[train_index]) + preds = xgb_model.predict(X[test_index]) + # test other params in XGBClassifier().fit + preds2 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=3) + preds3 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=0) + preds4 = xgb_model.predict(X[test_index], output_margin=False, ntree_limit=3) + labels = y[test_index] + err = sum(1 for i in range(len(preds)) if int(preds[i]>0.5)!=labels[i]) / float(len(preds)) + assert err < 0.4 def test_boston_housing_regression(): - boston = load_boston() - y = boston['target'] - X = boston['data'] - kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng) - for train_index, test_index in 
-  for train_index, test_index in kf:
-    xgb_model = xgb.XGBRegressor().fit(X[train_index],y[train_index])
-    preds = xgb_model.predict(X[test_index])
-    # test other params in XGBRegressor().fit
-    preds2 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=3)
-    preds3 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=0)
-    preds4 = xgb_model.predict(X[test_index], output_margin=False, ntree_limit=3)
-    labels = y[test_index]
-    assert mean_squared_error(preds, labels) < 25
+    boston = load_boston()
+    y = boston['target']
+    X = boston['data']
+    kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
+    for train_index, test_index in kf:
+        xgb_model = xgb.XGBRegressor().fit(X[train_index],y[train_index])
+        preds = xgb_model.predict(X[test_index])
+        # test other params in XGBRegressor().fit
+        preds2 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=3)
+        preds3 = xgb_model.predict(X[test_index], output_margin=True, ntree_limit=0)
+        preds4 = xgb_model.predict(X[test_index], output_margin=False, ntree_limit=3)
+        labels = y[test_index]
+        assert mean_squared_error(preds, labels) < 25

 def test_parameter_tuning():
-  boston = load_boston()
-  y = boston['target']
-  X = boston['data']
-  xgb_model = xgb.XGBRegressor()
-  clf = GridSearchCV(xgb_model,
-                     {'max_depth': [2,4,6],
-                      'n_estimators': [50,100,200]}, verbose=1)
-  clf.fit(X,y)
-  assert clf.best_score_ < 0.7
-  assert clf.best_params_ == {'n_estimators': 100, 'max_depth': 4}
+    boston = load_boston()
+    y = boston['target']
+    X = boston['data']
+    xgb_model = xgb.XGBRegressor()
+    clf = GridSearchCV(xgb_model,
+                       {'max_depth': [2,4,6],
+                        'n_estimators': [50,100,200]}, verbose=1)
+    clf.fit(X,y)
+    assert clf.best_score_ < 0.7
+    assert clf.best_params_ == {'n_estimators': 100, 'max_depth': 4}
+
+def test_regression_with_custom_objective():
+    def objective_ls(y_true, y_pred):
+        grad = (y_pred - y_true)
+        hess = np.ones(len(y_true))
+        return grad, hess
+
+    boston = load_boston()
+    y = boston['target']
+    X = boston['data']
+    kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
+    for train_index, test_index in kf:
+        xgb_model = xgb.XGBRegressor(objective=objective_ls).fit(
+            X[train_index], y[train_index]
+        )
+        preds = xgb_model.predict(X[test_index])
+        labels = y[test_index]
+        assert mean_squared_error(preds, labels) < 25
+
+    # Test that the custom objective function is actually used
+    class XGBCustomObjectiveException(Exception):
+        pass
+
+    def dummy_objective(y_true, y_pred):
+        raise XGBCustomObjectiveException()
+
+    xgb_model = xgb.XGBRegressor(objective=dummy_objective)
+    np.testing.assert_raises(
+        XGBCustomObjectiveException,
+        xgb_model.fit,
+        X, y
+    )
+
+def test_classification_with_custom_objective():
+    def logregobj(y_true, y_pred):
+        y_pred = 1.0 / (1.0 + np.exp(-y_pred))
+        grad = y_pred - y_true
+        hess = y_pred * (1.0-y_pred)
+        return grad, hess
+
+    digits = load_digits(2)
+    y = digits['target']
+    X = digits['data']
+    kf = KFold(y.shape[0], n_folds=2, shuffle=True, random_state=rng)
+    for train_index, test_index in kf:
+        xgb_model = xgb.XGBClassifier(objective=logregobj).fit(
+            X[train_index],y[train_index]
+        )
+        preds = xgb_model.predict(X[test_index])
+        labels = y[test_index]
+        err = sum(1 for i in range(len(preds))
+                  if int(preds[i]>0.5)!=labels[i]) / float(len(preds))
+        assert err < 0.1
+
+
+    # Test that the custom objective function is actually used
+    class XGBCustomObjectiveException(Exception):
+        pass
+
+    def dummy_objective(y_true, y_preds):
+        raise XGBCustomObjectiveException()
+
+    xgb_model = xgb.XGBClassifier(objective=dummy_objective)
+    np.testing.assert_raises(
+        XGBCustomObjectiveException,
+        xgb_model.fit,
+        X, y
+    )
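
Taken together, the series makes the custom objective an ordinary constructor parameter, so the estimators stay compatible with the scikit-learn tooling already exercised above (for example ``GridSearchCV`` in ``test_parameter_tuning``). The sketch below is illustrative usage only, not part of the patch or its tests; ``squared_error_objective`` is a hypothetical user function.

    # Pairing the new objective parameter with GridSearchCV (illustrative sketch).
    import numpy as np
    import xgboost as xgb
    from sklearn.datasets import load_boston
    from sklearn.grid_search import GridSearchCV

    def squared_error_objective(y_true, y_pred):
        # gradient and constant hessian of the squared error
        return y_pred - y_true, np.ones(len(y_true))

    boston = load_boston()
    clf = GridSearchCV(xgb.XGBRegressor(objective=squared_error_objective),
                       {'max_depth': [2, 4], 'n_estimators': [50, 100]},
                       verbose=1)
    clf.fit(boston['data'], boston['target'])
    print(clf.best_params_)
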