Skip to content

Commit

Permalink
[EM] Support ExtMemQdm in the GPU predictor. (dmlc#10694)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Aug 13, 2024
1 parent 4370454 commit 2ecc85f
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 129 deletions.
2 changes: 1 addition & 1 deletion include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
* - missing: Which value to represent missing value
* - nthread (optional): Number of threads used for initializing DMatrix.
* - max_bin (optional): Maximum number of bins for building histogram.
* \param out The created Device Quantile DMatrix
* \param out The created Quantile DMatrix.
*
* \return 0 when success, -1 when failure happens
*/
Expand Down
2 changes: 1 addition & 1 deletion include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class MetaInfo {
* if specified, xgboost will start from this init margin
* can be used to specify initial prediction to boost from.
*/
linalg::Tensor<float, 2> base_margin_; // NOLINT
linalg::Matrix<float> base_margin_; // NOLINT
/*!
* \brief lower bound of the label, to be used for survival analysis (censored regression)
*/
Expand Down
3 changes: 1 addition & 2 deletions include/xgboost/predictor.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2017-2023 by Contributors
* Copyright 2017-2024, XGBoost Contributors
* \file predictor.h
* \brief Interface of predictor,
* performs predictions for a gradient booster.
Expand All @@ -15,7 +15,6 @@
#include <functional> // for function
#include <memory> // for shared_ptr
#include <string>
#include <utility> // for make_pair
#include <vector>

// Forward declarations
Expand Down
24 changes: 15 additions & 9 deletions src/data/ellpack_page.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,26 @@ struct EllpackDeviceAccessor {
min_fvalue = cuts->min_vals_.ConstHostSpan();
}
}
// Get a matrix element, uses binary search for look up Return NaN if missing
// Given a row index and a feature index, returns the corresponding cut value
[[nodiscard]] __device__ int32_t GetBinIndex(size_t ridx, size_t fidx) const {
ridx -= base_rowid;
/**
* @brief Given a row index and a feature index, returns the corresponding cut value.
*
* Uses binary search for look up. Returns NaN if missing.
*
* @tparam global_ridx Whether the row index is global to all ellpack batches or it's
* local to the current batch.
*/
template <bool global_ridx = true>
[[nodiscard]] __device__ bst_bin_t GetBinIndex(size_t ridx, size_t fidx) const {
if (global_ridx) {
ridx -= base_rowid;
}
auto row_begin = row_stride * ridx;
auto row_end = row_begin + row_stride;
auto gidx = -1;
bst_bin_t gidx = -1;
if (is_dense) {
gidx = gidx_iter[row_begin + fidx];
} else {
gidx = common::BinarySearchBin(row_begin,
row_end,
gidx_iter,
feature_segments[fidx],
gidx = common::BinarySearchBin(row_begin, row_end, gidx_iter, feature_segments[fidx],
feature_segments[fidx + 1]);
}
return gidx;
Expand Down
Loading

0 comments on commit 2ecc85f

Please sign in to comment.