[Breaking] Require format to be specified in input URI. (dmlc#9077)
Previously, we used `libsvm` as the default when the format was not specified. However, the
dmlc data parser is not particularly robust against errors, and the most common type of
error is an unspecified format.

In addition, we now recommend that users use other data loaders instead. We will
continue to maintain the parsers, as they are currently used for many internal tests,
including federated learning.
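
A minimal sketch of the user-facing change (Python API; the file path is illustrative):

    import xgboost as xgb

    # Before this change, the format could be omitted and libsvm was assumed:
    #   dtrain = xgb.DMatrix("train.txt")          # now raises an error
    # After this change, the format must be given as a URI query parameter:
    dtrain = xgb.DMatrix("train.txt?format=libsvm")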
trivialfis authored Apr 28, 2023
1 parent e922004 commit 1f9a57d
Showing 58 changed files with 328 additions and 269 deletions.
2 changes: 1 addition & 1 deletion R-package/tests/testthat/test_dmatrix.R
@@ -72,7 +72,7 @@ test_that("xgb.DMatrix: saving, loading", {
tmp <- c("0 1:1 2:1", "1 3:1", "0 1:1")
tmp_file <- tempfile(fileext = ".libsvm")
writeLines(tmp, tmp_file)
-dtest4 <- xgb.DMatrix(tmp_file, silent = TRUE)
+dtest4 <- xgb.DMatrix(paste(tmp_file, "?format=libsvm", sep = ""), silent = TRUE)
expect_equal(dim(dtest4), c(3, 4))
expect_equal(getinfo(dtest4, 'label'), c(0, 1, 0))

6 changes: 3 additions & 3 deletions demo/CLI/binary_classification/mushroom.conf
@@ -20,10 +20,10 @@ num_round = 2
# 0 means do not save any model except the final round model
save_period = 2
# The path of training data
-data = "agaricus.txt.train"
+data = "agaricus.txt.train?format=libsvm"
# The path of validation data, used to monitor training process, here [test] sets name of the validation set
-eval[test] = "agaricus.txt.test"
+eval[test] = "agaricus.txt.test?format=libsvm"
# evaluate on training data as well each round
eval_train = 1
# The path of test data
-test:data = "agaricus.txt.test"
+test:data = "agaricus.txt.test?format=libsvm"
6 changes: 3 additions & 3 deletions demo/CLI/regression/machine.conf
@@ -21,8 +21,8 @@ num_round = 2
# 0 means do not save any model except the final round model
save_period = 0
# The path of training data
-data = "machine.txt.train"
+data = "machine.txt.train?format=libsvm"
# The path of validation data, used to monitor training process, here [test] sets name of the validation set
-eval[test] = "machine.txt.test"
+eval[test] = "machine.txt.test?format=libsvm"
# The path of test data
-test:data = "machine.txt.test"
+test:data = "machine.txt.test?format=libsvm"
4 changes: 2 additions & 2 deletions demo/c-api/basic/c-api-demo.c
@@ -42,8 +42,8 @@ int main() {

// load the data
DMatrixHandle dtrain, dtest;
-safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.train", silent, &dtrain));
-safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.test", silent, &dtest));
+safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.train?format=libsvm", silent, &dtrain));
+safe_xgboost(XGDMatrixCreateFromFile("../../data/agaricus.txt.test?format=libsvm", silent, &dtest));

// create the booster
BoosterHandle booster;
16 changes: 10 additions & 6 deletions demo/guide-python/boost_from_prediction.py
@@ -7,15 +7,19 @@
import xgboost as xgb

CURRENT_DIR = os.path.dirname(__file__)
-dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
-dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
-watchlist = [(dtest, 'eval'), (dtrain, 'train')]
+dtrain = xgb.DMatrix(
+    os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
+)
+dtest = xgb.DMatrix(
+    os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
+)
+watchlist = [(dtest, "eval"), (dtrain, "train")]
###
# advanced: start from a initial base prediction
#
-print('start running example to start from a initial prediction')
+print("start running example to start from a initial prediction")
# specify parameters via map, definition are same as c++ version
-param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
+param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
# train xgboost for 1 round
bst = xgb.train(param, dtrain, 1, watchlist)
# Note: we need the margin value instead of transformed prediction in
@@ -27,5 +31,5 @@
dtrain.set_base_margin(ptrain)
dtest.set_base_margin(ptest)

-print('this is result of running from initial prediction')
+print("this is result of running from initial prediction")
bst = xgb.train(param, dtrain, 1, watchlist)
62 changes: 42 additions & 20 deletions demo/guide-python/cross_validation.py
@@ -10,60 +10,82 @@

# load data in do training
CURRENT_DIR = os.path.dirname(__file__)
-dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
-param = {'max_depth':2, 'eta':1, 'objective':'binary:logistic'}
+dtrain = xgb.DMatrix(
+    os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
+)
+param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
num_round = 2

-print('running cross validation')
+print("running cross validation")
# do cross validation, this will print result out as
# [iteration] metric_name:mean_value+std_value
# std_value is standard deviation of the metric
-xgb.cv(param, dtrain, num_round, nfold=5,
-       metrics={'error'}, seed=0,
-       callbacks=[xgb.callback.EvaluationMonitor(show_stdv=True)])
+xgb.cv(
+    param,
+    dtrain,
+    num_round,
+    nfold=5,
+    metrics={"error"},
+    seed=0,
+    callbacks=[xgb.callback.EvaluationMonitor(show_stdv=True)],
+)

-print('running cross validation, disable standard deviation display')
+print("running cross validation, disable standard deviation display")
# do cross validation, this will print result out as
# [iteration] metric_name:mean_value
-res = xgb.cv(param, dtrain, num_boost_round=10, nfold=5,
-             metrics={'error'}, seed=0,
-             callbacks=[xgb.callback.EvaluationMonitor(show_stdv=False),
-                        xgb.callback.EarlyStopping(3)])
+res = xgb.cv(
+    param,
+    dtrain,
+    num_boost_round=10,
+    nfold=5,
+    metrics={"error"},
+    seed=0,
+    callbacks=[
+        xgb.callback.EvaluationMonitor(show_stdv=False),
+        xgb.callback.EarlyStopping(3),
+    ],
+)
print(res)
-print('running cross validation, with preprocessing function')
+print("running cross validation, with preprocessing function")


# define the preprocessing function
# used to return the preprocessed training, test data, and parameter
# we can use this to do weight rescale, etc.
# as a example, we try to set scale_pos_weight
def fpreproc(dtrain, dtest, param):
label = dtrain.get_label()
ratio = float(np.sum(label == 0)) / np.sum(label == 1)
-    param['scale_pos_weight'] = ratio
+    param["scale_pos_weight"] = ratio
return (dtrain, dtest, param)


# do cross validation, for each fold
# the dtrain, dtest, param will be passed into fpreproc
# then the return value of fpreproc will be used to generate
# results of that fold
-xgb.cv(param, dtrain, num_round, nfold=5,
-       metrics={'auc'}, seed=0, fpreproc=fpreproc)
+xgb.cv(param, dtrain, num_round, nfold=5, metrics={"auc"}, seed=0, fpreproc=fpreproc)

###
# you can also do cross validation with customized loss function
# See custom_objective.py
##
-print('running cross validation, with customized loss function')
+print("running cross validation, with customized loss function")


def logregobj(preds, dtrain):
labels = dtrain.get_label()
preds = 1.0 / (1.0 + np.exp(-preds))
grad = preds - labels
hess = preds * (1.0 - preds)
return grad, hess


def evalerror(preds, dtrain):
labels = dtrain.get_label()
-    return 'error', float(sum(labels != (preds > 0.0))) / len(labels)
+    return "error", float(sum(labels != (preds > 0.0))) / len(labels)


-param = {'max_depth':2, 'eta':1}
+param = {"max_depth": 2, "eta": 1}
# train with customized objective
-xgb.cv(param, dtrain, num_round, nfold=5, seed=0,
-       obj=logregobj, feval=evalerror)
+xgb.cv(param, dtrain, num_round, nfold=5, seed=0, obj=logregobj, feval=evalerror)
37 changes: 23 additions & 14 deletions demo/guide-python/evals_result.py
@@ -7,28 +7,37 @@
import xgboost as xgb

CURRENT_DIR = os.path.dirname(__file__)
-dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
-dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
+dtrain = xgb.DMatrix(
+    os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
+)
+dtest = xgb.DMatrix(
+    os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
+)

-param = [('max_depth', 2), ('objective', 'binary:logistic'), ('eval_metric', 'logloss'), ('eval_metric', 'error')]
+param = [
+    ("max_depth", 2),
+    ("objective", "binary:logistic"),
+    ("eval_metric", "logloss"),
+    ("eval_metric", "error"),
+]

num_round = 2
-watchlist = [(dtest,'eval'), (dtrain,'train')]
+watchlist = [(dtest, "eval"), (dtrain, "train")]

evals_result = {}
bst = xgb.train(param, dtrain, num_round, watchlist, evals_result=evals_result)

-print('Access logloss metric directly from evals_result:')
-print(evals_result['eval']['logloss'])
+print("Access logloss metric directly from evals_result:")
+print(evals_result["eval"]["logloss"])

-print('')
-print('Access metrics through a loop:')
+print("")
+print("Access metrics through a loop:")
for e_name, e_mtrs in evals_result.items():
-    print('- {}'.format(e_name))
+    print("- {}".format(e_name))
for e_mtr_name, e_mtr_vals in e_mtrs.items():
-        print(' - {}'.format(e_mtr_name))
-        print(' - {}'.format(e_mtr_vals))
+        print(" - {}".format(e_mtr_name))
+        print(" - {}".format(e_mtr_vals))

-print('')
-print('Access complete dictionary:')
+print("")
+print("Access complete dictionary:")
print(evals_result)
26 changes: 20 additions & 6 deletions demo/guide-python/generalized_linear_model.py
@@ -11,14 +11,22 @@
# basically, we are using linear model, instead of tree for our boosters
##
CURRENT_DIR = os.path.dirname(__file__)
-dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
-dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
+dtrain = xgb.DMatrix(
+    os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
+)
+dtest = xgb.DMatrix(
+    os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
+)
# change booster to gblinear, so that we are fitting a linear model
# alpha is the L1 regularizer
# lambda is the L2 regularizer
# you can also set lambda_bias which is L2 regularizer on the bias term
-param = {'objective':'binary:logistic', 'booster':'gblinear',
-         'alpha': 0.0001, 'lambda': 1}
+param = {
+    "objective": "binary:logistic",
+    "booster": "gblinear",
+    "alpha": 0.0001,
+    "lambda": 1,
+}

# normally, you do not need to set eta (step_size)
# XGBoost uses a parallel coordinate descent algorithm (shotgun),
@@ -29,9 +37,15 @@
##
# the rest of settings are the same
##
-watchlist = [(dtest, 'eval'), (dtrain, 'train')]
+watchlist = [(dtest, "eval"), (dtrain, "train")]
num_round = 4
bst = xgb.train(param, dtrain, num_round, watchlist)
preds = bst.predict(dtest)
labels = dtest.get_label()
-print('error=%f' % (sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i]) / float(len(preds))))
+print(
+    "error=%f"
+    % (
+        sum(1 for i in range(len(preds)) if int(preds[i] > 0.5) != labels[i])
+        / float(len(preds))
+    )
+)
4 changes: 2 additions & 2 deletions demo/guide-python/predict_first_ntree.py
@@ -16,8 +16,8 @@

def native_interface():
# load data in do training
-dtrain = xgb.DMatrix(train)
-dtest = xgb.DMatrix(test)
+dtrain = xgb.DMatrix(train + "?format=libsvm")
+dtest = xgb.DMatrix(test + "?format=libsvm")
param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
watchlist = [(dtest, "eval"), (dtrain, "train")]
num_round = 3
14 changes: 9 additions & 5 deletions demo/guide-python/predict_leaf_indices.py
@@ -8,14 +8,18 @@

# load data in do training
CURRENT_DIR = os.path.dirname(__file__)
-dtrain = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.train'))
-dtest = xgb.DMatrix(os.path.join(CURRENT_DIR, '../data/agaricus.txt.test'))
-param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
-watchlist = [(dtest, 'eval'), (dtrain, 'train')]
+dtrain = xgb.DMatrix(
+    os.path.join(CURRENT_DIR, "../data/agaricus.txt.train?format=libsvm")
+)
+dtest = xgb.DMatrix(
+    os.path.join(CURRENT_DIR, "../data/agaricus.txt.test?format=libsvm")
+)
+param = {"max_depth": 2, "eta": 1, "objective": "binary:logistic"}
+watchlist = [(dtest, "eval"), (dtrain, "train")]
num_round = 3
bst = xgb.train(param, dtrain, num_round, watchlist)

-print('start testing predict the leaf indices')
+print("start testing predict the leaf indices")
# predict using first 2 tree
leafindex = bst.predict(
dtest, iteration_range=(0, 2), pred_leaf=True, strict_shape=True
6 changes: 3 additions & 3 deletions doc/tutorials/external_memory.rst
@@ -77,7 +77,7 @@

.. code-block:: none
-   filename#cacheprefix
+   filename?format=libsvm#cacheprefix
The ``filename`` is the normal path to LIBSVM format file you want to load in, and
``cacheprefix`` is a path to a cache file that XGBoost will use for caching preprocessed
Expand All @@ -97,13 +97,13 @@ you have a dataset stored in a file similar to ``agaricus.txt.train`` with LIBSV

.. code-block:: python
-   dtrain = DMatrix('../data/agaricus.txt.train#dtrain.cache')
+   dtrain = DMatrix('../data/agaricus.txt.train?format=libsvm#dtrain.cache')
XGBoost will first load ``agaricus.txt.train`` in, preprocess it, then write to a new file named
``dtrain.cache`` as an on disk cache for storing preprocessed data in an internal binary format. For
more notes about text input formats, see :doc:`/tutorials/input_format`.

-For CLI version, simply add the cache suffix, e.g. ``"../data/agaricus.txt.train#dtrain.cache"``.
+For CLI version, simply add the cache suffix, e.g. ``"../data/agaricus.txt.train?format=libsvm#dtrain.cache"``.


**********************************
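
A short usage sketch combining the two URI fragments described above (paths are illustrative):

    import xgboost as xgb

    # "?format=libsvm" selects the text parser; "#dtrain.cache" is the prefix
    # XGBoost uses for the on-disk cache of preprocessed data.
    dtrain = xgb.DMatrix("../data/agaricus.txt.train?format=libsvm#dtrain.cache")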
7 changes: 6 additions & 1 deletion doc/tutorials/input_format.rst
@@ -2,10 +2,15 @@
Text Input Format of DMatrix
############################

+.. _basic_input_format:
+
+Here we will briefly describe the text input formats for XGBoost. However, for users with access to a supported language environment like Python or R, it's recommended to use data parsers from that ecosystem instead. For instance, :py:func:`sklearn.datasets.load_svmlight_file`.

******************
Basic Input Format
******************
-XGBoost currently supports two text formats for ingesting data: LIBSVM and CSV. The rest of this document will describe the LIBSVM format. (See `this Wikipedia article <https://en.wikipedia.org/wiki/Comma-separated_values>`_ for a description of the CSV format.). Please be careful that, XGBoost does **not** understand file extensions, nor try to guess the file format, as there is no universal agreement upon file extension of LIBSVM or CSV. Instead it employs `URI <https://en.wikipedia.org/wiki/Uniform_Resource_Identifier>`_ format for specifying the precise input file type. For example if you provide a `csv` file ``./data.train.csv`` as input, XGBoost will blindly use the default LIBSVM parser to digest it and generate a parser error. Instead, users need to provide an URI in the form of ``train.csv?format=csv``. For external memory input, the URI should of a form similar to ``train.csv?format=csv#dtrain.cache``. See :ref:`python_data_interface` and :doc:`/tutorials/external_memory` also.
+XGBoost currently supports two text formats for ingesting data: LIBSVM and CSV. The rest of this document will describe the LIBSVM format. (See `this Wikipedia article <https://en.wikipedia.org/wiki/Comma-separated_values>`_ for a description of the CSV format.). Please be careful that, XGBoost does **not** understand file extensions, nor try to guess the file format, as there is no universal agreement upon file extension of LIBSVM or CSV. Instead it employs `URI <https://en.wikipedia.org/wiki/Uniform_Resource_Identifier>`_ format for specifying the precise input file type. For example if you provide a `csv` file ``./data.train.csv`` as input, XGBoost will blindly use the default LIBSVM parser to digest it and generate a parser error. Instead, users need to provide an URI in the form of ``train.csv?format=csv`` or ``train.csv?format=libsvm``. For external memory input, the URI should of a form similar to ``train.csv?format=csv#dtrain.cache``. See :ref:`python_data_interface` and :doc:`/tutorials/external_memory` also.

For training or predicting, XGBoost takes an instance file with the format as below:

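
A sketch of the recommended alternative mentioned in the diff above, parsing with scikit-learn instead of the built-in text parser (assumes scikit-learn is installed; the file name is illustrative):

    import xgboost as xgb
    from sklearn.datasets import load_svmlight_file

    # load_svmlight_file returns a sparse feature matrix and a label array,
    # both of which DMatrix accepts directly.
    X, y = load_svmlight_file("agaricus.txt.train")
    dtrain = xgb.DMatrix(X, label=y)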
6 changes: 5 additions & 1 deletion include/xgboost/c_api.h
@@ -138,7 +138,11 @@
/*!
* \brief load a data matrix
* \param config JSON encoded parameters for DMatrix construction. Accepted fields are:
- *   - uri: The URI of the input file.
+ *   - uri: The URI of the input file. The URI parameter `format` is required when loading text data.
+ *     \verbatim embed:rst:leading-asterisk
+ *     See :doc:`/tutorials/input_format` for more info.
+ *     \endverbatim
* - silent (optional): Whether to print message during loading. Default to true.
* - data_split_mode (optional): Whether to split by row or column. In distributed mode, the
* file is split accordingly; otherwise this is only an indicator on how the file was split
12 changes: 4 additions & 8 deletions include/xgboost/data.h
@@ -566,21 +566,17 @@ class DMatrix {
return Info().num_nonzero_ == Info().num_row_ * Info().num_col_;
}

-/*!
+/**
 * \brief Load DMatrix from URI.
+ *
* \param uri The URI of input.
* \param silent Whether print information during loading.
* \param data_split_mode In distributed mode, split the input according this mode; otherwise,
* it's just an indicator on how the input was split beforehand.
- * \param file_format The format type of the file, used for dmlc::Parser::Create.
- *                    By default "auto" will be able to load in both local binary file.
- * \param page_size Page size for external memory.
* \return The created DMatrix.
*/
-static DMatrix* Load(const std::string& uri,
-                     bool silent = true,
-                     DataSplitMode data_split_mode = DataSplitMode::kRow,
-                     const std::string& file_format = "auto");
+static DMatrix* Load(const std::string& uri, bool silent = true,
+                     DataSplitMode data_split_mode = DataSplitMode::kRow);

/**
* \brief Creates a new DMatrix from an external data adapter.