Skip to content

Commit

Permalink
Merge pull request dmlc#1017 from Far0n/hist
Browse files Browse the repository at this point in the history
[py] split value histograms
  • Loading branch information
terrytangyuan committed Apr 28, 2016
2 parents 6691d5c + cf607e2 commit 3434083
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 1 deletion.
45 changes: 44 additions & 1 deletion python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@
"""Core XGBoost Library."""
from __future__ import absolute_import

import sys
import os
import ctypes
import collections
import re

import numpy as np
import scipy.sparse

from .libpath import find_lib_path

from .compat import STRING_TYPES, PY3, DataFrame, py_str
from .compat import STRING_TYPES, PY3, DataFrame, py_str, PANDAS_INSTALLED


class XGBoostError(Exception):
Expand Down Expand Up @@ -1058,3 +1060,44 @@ def _validate_features(self, data):

raise ValueError(msg.format(self.feature_names,
data.feature_names))

def get_split_value_histogram(self, feature, fmap='', bins=None, as_pandas=True):
    """Get split value histogram of a feature.

    The split thresholds are collected from the text dump of every tree
    (``self.get_dump``) and binned with ``numpy.histogram``.

    Parameters
    ----------
    feature : str
        The name of the feature. It is matched literally, so names that
        contain regular-expression metacharacters are supported.
    fmap : str (optional)
        The name of feature map file.
    bins : int, default None
        The maximum number of bins.
        Number of bins equals number of unique split values n_unique,
        if bins == None or bins > n_unique. At least one bin is always used.
    as_pandas : bool, default True
        Return pd.DataFrame when pandas is installed.
        If False or pandas is not installed, return numpy ndarray.

    Returns
    -------
    a histogram of used splitting values for the specified feature,
    either as numpy array or pandas DataFrame. Each row holds the upper
    edge of a bin (``SplitValue``) and the number of splits falling into
    that bin (``Count``); empty bins are dropped.
    """
    # Raw string avoids invalid escape sequences; re.escape makes sure the
    # feature name is treated as literal text rather than as a pattern.
    split_pattern = re.compile(
        r"\[{0}<([\d.Ee+-]+)\]".format(re.escape(feature)))

    values = []
    for tree_dump in self.get_dump(fmap=fmap):
        values.extend(float(v) for v in split_pattern.findall(tree_dump))

    # Never use more bins than distinct split values, and never fewer than
    # one (np.histogram rejects a bin count of zero).
    n_unique = np.unique(values).shape[0]
    n_bins = n_unique if bins is None else min(n_unique, bins)
    n_bins = max(n_bins, 1)

    counts, edges = np.histogram(values, bins=n_bins)
    # Pair every bin's right edge with its count, then drop empty bins.
    nph = np.column_stack((edges[1:], counts))
    nph = nph[nph[:, 1] > 0]

    if not as_pandas:
        return nph
    if PANDAS_INSTALLED:
        return DataFrame(nph, columns=['SplitValue', 'Count'])

    sys.stderr.write("Returning histogram as ndarray (as_pandas == True, but pandas is not installed).")
    return nph
19 changes: 19 additions & 0 deletions tests/python/test_with_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,22 @@ def test_sklearn_nfolds_cv():
cv3 = xgb.cv(params, dm, num_boost_round=10, nfold=nfolds, stratified=True, seed=seed)
assert cv1.shape[0] == cv2.shape[0] and cv2.shape[0] == cv3.shape[0]
assert cv2.iloc[-1, 0] == cv3.iloc[-1, 0]


def test_split_value_histograms():
    """Check the row count of split value histograms for various ``bins``."""
    # Pass n_class by keyword: newer scikit-learn releases reject it as a
    # positional argument, while the keyword form works everywhere.
    digits_2class = load_digits(n_class=2)

    X = digits_2class['data']
    y = digits_2class['target']

    dm = xgb.DMatrix(X, label=y)
    params = {'max_depth': 6, 'eta': 0.01, 'silent': 1, 'objective': 'binary:logistic'}

    gbdt = xgb.train(params, dm, num_boost_round=10)

    # A feature that is never used for splitting yields an empty histogram,
    # whether it is returned as a DataFrame or as an ndarray.
    assert gbdt.get_split_value_histogram("not_there", as_pandas=True).shape[0] == 0
    assert gbdt.get_split_value_histogram("not_there", as_pandas=False).shape[0] == 0

    # (bins, expected rows): the assertions expect the bin count to be
    # clamped to between one and two rows for feature "f28" in this model.
    expected_rows = [(0, 1), (1, 1), (2, 2), (5, 2), (None, 2)]
    for bins, n_rows in expected_rows:
        assert gbdt.get_split_value_histogram("f28", bins=bins).shape[0] == n_rows

0 comments on commit 3434083

Please sign in to comment.