From 2ecc85ffada1d57502df121858fc4256d22158ed Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 13 Aug 2024 12:21:11 +0800 Subject: [PATCH] [EM] Support ExtMemQdm in the GPU predictor. (#10694) --- include/xgboost/c_api.h | 2 +- include/xgboost/data.h | 2 +- include/xgboost/predictor.h | 3 +- src/data/ellpack_page.cuh | 24 ++-- src/predictor/gpu_predictor.cu | 155 +++++++++------------- tests/cpp/predictor/test_gpu_predictor.cu | 67 ++++++---- 6 files changed, 124 insertions(+), 129 deletions(-) diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index 16817bf5ad1c..9f72d1e1368c 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -494,7 +494,7 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy * - missing: Which value to represent missing value * - nthread (optional): Number of threads used for initializing DMatrix. * - max_bin (optional): Maximum number of bins for building histogram. - * \param out The created Device Quantile DMatrix + * \param out The created Quantile DMatrix. * * \return 0 when success, -1 when failure happens */ diff --git a/include/xgboost/data.h b/include/xgboost/data.h index bc38400e9b9a..10329f87b074 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -72,7 +72,7 @@ class MetaInfo { * if specified, xgboost will start from this init margin * can be used to specify initial prediction to boost from. */ - linalg::Tensor base_margin_; // NOLINT + linalg::Matrix base_margin_; // NOLINT /*! * \brief lower bound of the label, to be used for survival analysis (censored regression) */ diff --git a/include/xgboost/predictor.h b/include/xgboost/predictor.h index 6a38d6496fd4..555ded55fb02 100644 --- a/include/xgboost/predictor.h +++ b/include/xgboost/predictor.h @@ -1,5 +1,5 @@ /** - * Copyright 2017-2023 by Contributors + * Copyright 2017-2024, XGBoost Contributors * \file predictor.h * \brief Interface of predictor, * performs predictions for a gradient booster. @@ -15,7 +15,6 @@ #include // for function #include // for shared_ptr #include -#include // for make_pair #include // Forward declarations diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index 97167ad5c64b..e494afb3e9a4 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -60,20 +60,26 @@ struct EllpackDeviceAccessor { min_fvalue = cuts->min_vals_.ConstHostSpan(); } } - // Get a matrix element, uses binary search for look up Return NaN if missing - // Given a row index and a feature index, returns the corresponding cut value - [[nodiscard]] __device__ int32_t GetBinIndex(size_t ridx, size_t fidx) const { - ridx -= base_rowid; + /** + * @brief Given a row index and a feature index, returns the corresponding cut value. + * + * Uses binary search for look up. Returns NaN if missing. + * + * @tparam global_ridx Whether the row index is global to all ellpack batches or it's + * local to the current batch. 
+ */ + template + [[nodiscard]] __device__ bst_bin_t GetBinIndex(size_t ridx, size_t fidx) const { + if (global_ridx) { + ridx -= base_rowid; + } auto row_begin = row_stride * ridx; auto row_end = row_begin + row_stride; - auto gidx = -1; + bst_bin_t gidx = -1; if (is_dense) { gidx = gidx_iter[row_begin + fidx]; } else { - gidx = common::BinarySearchBin(row_begin, - row_end, - gidx_iter, - feature_segments[fidx], + gidx = common::BinarySearchBin(row_begin, row_end, gidx_iter, feature_segments[fidx], feature_segments[fidx + 1]); } return gidx; diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index fe46e19ec63b..570872aa52ad 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -3,10 +3,8 @@ */ #include #include -#include #include #include -#include #include // for any, any_cast #include @@ -102,7 +100,7 @@ struct SparsePageView { } } // Value is missing - return nanf(""); + return std::numeric_limits::quiet_NaN(); } [[nodiscard]] XGBOOST_DEVICE size_t NumRows() const { return d_row_ptr.size() - 1; } [[nodiscard]] XGBOOST_DEVICE size_t NumCols() const { return num_features; } @@ -114,22 +112,21 @@ struct SparsePageLoader { float* smem; __device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features, - bst_idx_t num_rows, size_t entry_start, float) - : use_shared(use_shared), - data(data) { + bst_idx_t num_rows, float) + : use_shared(use_shared), data(data) { extern __shared__ float _smem[]; smem = _smem; // Copy instances if (use_shared) { bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x; int shared_elements = blockDim.x * data.num_features; - dh::BlockFill(smem, shared_elements, nanf("")); + dh::BlockFill(smem, shared_elements, std::numeric_limits::quiet_NaN()); __syncthreads(); if (global_idx < num_rows) { bst_uint elem_begin = data.d_row_ptr[global_idx]; bst_uint elem_end = data.d_row_ptr[global_idx + 1]; for (bst_uint elem_idx = elem_begin; elem_idx < elem_end; elem_idx++) { - Entry elem = data.d_data[elem_idx - entry_start]; + Entry elem = data.d_data[elem_idx]; smem[threadIdx.x * data.num_features + elem.index] = elem.fvalue; } } @@ -148,12 +145,12 @@ struct SparsePageLoader { struct EllpackLoader { EllpackDeviceAccessor const& matrix; XGBOOST_DEVICE EllpackLoader(EllpackDeviceAccessor const& m, bool, bst_feature_t, bst_idx_t, - size_t, float) + float) : matrix{m} {} - [[nodiscard]] __device__ __forceinline__ float GetElement(size_t ridx, size_t fidx) const { - auto gidx = matrix.GetBinIndex(ridx, fidx); + [[nodiscard]] XGBOOST_DEV_INLINE float GetElement(size_t ridx, size_t fidx) const { + auto gidx = matrix.GetBinIndex(ridx, fidx); if (gidx == -1) { - return nan(""); + return std::numeric_limits::quiet_NaN(); } if (common::IsCat(matrix.feature_types, fidx)) { return matrix.gidx_fvalue_map[gidx]; @@ -179,14 +176,14 @@ struct DeviceAdapterLoader { XGBOOST_DEV_INLINE DeviceAdapterLoader(Batch const batch, bool use_shared, bst_feature_t num_features, bst_idx_t num_rows, - size_t entry_start, float missing) + float missing) : batch{batch}, columns{num_features}, use_shared{use_shared}, is_valid{missing} { extern __shared__ float _smem[]; smem = _smem; if (use_shared) { uint32_t global_idx = blockDim.x * blockIdx.x + threadIdx.x; size_t shared_elements = blockDim.x * num_features; - dh::BlockFill(smem, shared_elements, nanf("")); + dh::BlockFill(smem, shared_elements, std::numeric_limits::quiet_NaN()); __syncthreads(); if (global_idx < num_rows) { auto beg = global_idx * columns; @@ -210,21 
+207,19 @@ struct DeviceAdapterLoader { if (is_valid(value)) { return value; } else { - return nan(""); + return std::numeric_limits::quiet_NaN(); } } }; template -__device__ bst_node_t GetLeafIndex(bst_idx_t ridx, TreeView const &tree, - Loader *loader) { +__device__ bst_node_t GetLeafIndex(bst_idx_t ridx, TreeView const& tree, Loader* loader) { bst_node_t nidx = 0; RegTree::Node n = tree.d_tree[nidx]; while (!n.IsLeaf()) { float fvalue = loader->GetElement(ridx, n.SplitIndex()); bool is_missing = common::CheckNAN(fvalue); - nidx = GetNextNode(n, nidx, fvalue, - is_missing, tree.cats); + nidx = GetNextNode(n, nidx, fvalue, is_missing, tree.cats); n = tree.d_tree[nidx]; } return nidx; @@ -253,14 +248,14 @@ PredictLeafKernel(Data data, common::Span d_nodes, common::Span d_cat_node_segments, common::Span d_categories, - size_t tree_begin, size_t tree_end, size_t num_features, - size_t num_rows, size_t entry_start, bool use_shared, + size_t tree_begin, size_t tree_end, bst_feature_t num_features, + size_t num_rows, bool use_shared, float missing) { bst_idx_t ridx = blockDim.x * blockIdx.x + threadIdx.x; if (ridx >= num_rows) { return; } - Loader loader(data, use_shared, num_features, num_rows, entry_start, missing); + Loader loader{data, use_shared, num_features, num_rows, missing}; for (size_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { TreeView d_tree{ tree_begin, tree_idx, d_nodes, @@ -288,10 +283,11 @@ PredictKernel(Data data, common::Span d_nodes, common::Span d_cat_node_segments, common::Span d_categories, size_t tree_begin, size_t tree_end, size_t num_features, size_t num_rows, - size_t entry_start, bool use_shared, int num_group, float missing) { + bool use_shared, int num_group, float missing) { bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x; - Loader loader(data, use_shared, num_features, num_rows, entry_start, missing); + Loader loader(data, use_shared, num_features, num_rows, missing); if (global_idx >= num_rows) return; + if (num_group == 1) { float sum = 0; for (size_t tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { @@ -627,10 +623,10 @@ __global__ void MaskBitVectorKernel( common::Span d_cat_tree_segments, common::Span d_cat_node_segments, common::Span d_categories, BitVector decision_bits, BitVector missing_bits, - std::size_t tree_begin, std::size_t tree_end, std::size_t num_features, std::size_t num_rows, - std::size_t entry_start, std::size_t num_nodes, bool use_shared, float missing) { + std::size_t tree_begin, std::size_t tree_end, bst_feature_t num_features, std::size_t num_rows, + std::size_t num_nodes, bool use_shared, float missing) { // This needs to be always instantiated since the data is loaded cooperatively by all threads. 
- SparsePageLoader loader(data, use_shared, num_features, num_rows, entry_start, missing); + SparsePageLoader loader{data, use_shared, num_features, num_rows, missing}; auto const row_idx = blockIdx.x * blockDim.x + threadIdx.x; if (row_idx >= num_rows) { return; @@ -789,17 +785,16 @@ class ColumnSplitHelper { batch.offset.SetDevice(ctx_->Device()); batch.data.SetDevice(ctx_->Device()); - std::size_t entry_start = 0; SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), num_features); auto const grid = static_cast(common::DivRoundUp(num_rows, kBlockThreads)); - dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes, ctx_->CUDACtx()->Stream()} ( + dh::LaunchKernel {grid, kBlockThreads, shared_memory_bytes, ctx_->CUDACtx()->Stream()}( MaskBitVectorKernel, data, model.nodes.ConstDeviceSpan(), model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(), model.categories_tree_segments.ConstDeviceSpan(), model.categories_node_segments.ConstDeviceSpan(), model.categories.ConstDeviceSpan(), decision_bits, missing_bits, model.tree_beg_, model.tree_end_, num_features, num_rows, - entry_start, num_nodes, use_shared, nan("")); + num_nodes, use_shared, std::numeric_limits::quiet_NaN()); AllReduceBitVectors(&decision_storage, &missing_storage); @@ -852,36 +847,30 @@ class ColumnSplitHelper { class GPUPredictor : public xgboost::Predictor { private: - void PredictInternal(const SparsePage& batch, - DeviceModel const& model, - size_t num_features, - HostDeviceVector* predictions, - size_t batch_offset, bool is_dense) const { + void PredictInternal(const SparsePage& batch, DeviceModel const& model, size_t num_features, + HostDeviceVector* predictions, size_t batch_offset, + bool is_dense) const { batch.offset.SetDevice(ctx_->Device()); batch.data.SetDevice(ctx_->Device()); const uint32_t BLOCK_THREADS = 128; - size_t num_rows = batch.Size(); + bst_idx_t num_rows = batch.Size(); auto GRID_SIZE = static_cast(common::DivRoundUp(num_rows, BLOCK_THREADS)); auto max_shared_memory_bytes = ConfigureDevice(ctx_->Device()); size_t shared_memory_bytes = SharedMemoryBytes(num_features, max_shared_memory_bytes); bool use_shared = shared_memory_bytes != 0; - size_t entry_start = 0; SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), num_features); auto const kernel = [&](auto predict_fn) { - dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} ( + dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes}( predict_fn, data, model.nodes.ConstDeviceSpan(), - predictions->DeviceSpan().subspan(batch_offset), - model.tree_segments.ConstDeviceSpan(), - model.tree_group.ConstDeviceSpan(), - model.split_types.ConstDeviceSpan(), + predictions->DeviceSpan().subspan(batch_offset), model.tree_segments.ConstDeviceSpan(), + model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(), model.categories_tree_segments.ConstDeviceSpan(), - model.categories_node_segments.ConstDeviceSpan(), - model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_, - num_features, num_rows, entry_start, use_shared, model.num_group, - nan("")); + model.categories_node_segments.ConstDeviceSpan(), model.categories.ConstDeviceSpan(), + model.tree_beg_, model.tree_end_, num_features, num_rows, use_shared, model.num_group, + std::numeric_limits::quiet_NaN()); }; if (is_dense) { kernel(PredictKernel); @@ -889,27 +878,23 @@ class GPUPredictor : public xgboost::Predictor { kernel(PredictKernel); } } - void 
PredictInternal(EllpackDeviceAccessor const& batch, - DeviceModel const& model, - HostDeviceVector* out_preds, - size_t batch_offset) const { + + void PredictInternal(EllpackDeviceAccessor const& batch, DeviceModel const& model, + HostDeviceVector* out_preds, bst_idx_t batch_offset) const { const uint32_t BLOCK_THREADS = 256; size_t num_rows = batch.n_rows; auto GRID_SIZE = static_cast(common::DivRoundUp(num_rows, BLOCK_THREADS)); DeviceModel d_model; bool use_shared = false; - size_t entry_start = 0; - dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS} ( - PredictKernel, batch, - model.nodes.ConstDeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset), - model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(), - model.split_types.ConstDeviceSpan(), + dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS}( + PredictKernel, batch, model.nodes.ConstDeviceSpan(), + out_preds->DeviceSpan().subspan(batch_offset), model.tree_segments.ConstDeviceSpan(), + model.tree_group.ConstDeviceSpan(), model.split_types.ConstDeviceSpan(), model.categories_tree_segments.ConstDeviceSpan(), - model.categories_node_segments.ConstDeviceSpan(), - model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_, - batch.NumFeatures(), num_rows, entry_start, use_shared, - model.num_group, nan("")); + model.categories_node_segments.ConstDeviceSpan(), model.categories.ConstDeviceSpan(), + model.tree_beg_, model.tree_end_, batch.NumFeatures(), num_rows, use_shared, + model.num_group, std::numeric_limits::quiet_NaN()); } void DevicePredictInternal(DMatrix* dmat, HostDeviceVector* out_preds, @@ -928,24 +913,22 @@ class GPUPredictor : public xgboost::Predictor { return; } + CHECK_LE(dmat->Info().num_col_, model.learner_model_param->num_feature); if (dmat->PageExists()) { - size_t batch_offset = 0; - for (auto &batch : dmat->GetBatches()) { - this->PredictInternal(batch, d_model, model.learner_model_param->num_feature, - out_preds, batch_offset, dmat->IsDense()); - batch_offset += batch.Size() * model.learner_model_param->num_output_group; + bst_idx_t batch_offset = 0; + for (auto& batch : dmat->GetBatches()) { + this->PredictInternal(batch, d_model, model.learner_model_param->num_feature, out_preds, + batch_offset, dmat->IsDense()); + batch_offset += batch.Size() * model.learner_model_param->OutputLength(); } } else { - size_t batch_offset = 0; + bst_idx_t batch_offset = 0; for (auto const& page : dmat->GetBatches(ctx_, BatchParam{})) { dmat->Info().feature_types.SetDevice(ctx_->Device()); auto feature_types = dmat->Info().feature_types.ConstDeviceSpan(); - this->PredictInternal( - page.Impl()->GetDeviceAccessor(ctx_->Device(), feature_types), - d_model, - out_preds, - batch_offset); - batch_offset += page.Impl()->n_rows; + this->PredictInternal(page.Impl()->GetDeviceAccessor(ctx_->Device(), feature_types), + d_model, out_preds, batch_offset); + batch_offset += page.Size() * model.learner_model_param->OutputLength(); } } } @@ -1004,17 +987,14 @@ class GPUPredictor : public xgboost::Predictor { d_model.Init(model, tree_begin, tree_end, m->Device()); bool use_shared = shared_memory_bytes != 0; - size_t entry_start = 0; - dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} ( - PredictKernel, m->Value(), - d_model.nodes.ConstDeviceSpan(), out_preds->predictions.DeviceSpan(), - d_model.tree_segments.ConstDeviceSpan(), d_model.tree_group.ConstDeviceSpan(), - d_model.split_types.ConstDeviceSpan(), + dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes}( + PredictKernel, m->Value(), 
d_model.nodes.ConstDeviceSpan(), + out_preds->predictions.DeviceSpan(), d_model.tree_segments.ConstDeviceSpan(), + d_model.tree_group.ConstDeviceSpan(), d_model.split_types.ConstDeviceSpan(), d_model.categories_tree_segments.ConstDeviceSpan(), - d_model.categories_node_segments.ConstDeviceSpan(), - d_model.categories.ConstDeviceSpan(), tree_begin, tree_end, m->NumColumns(), - m->NumRows(), entry_start, use_shared, output_groups, missing); + d_model.categories_node_segments.ConstDeviceSpan(), d_model.categories.ConstDeviceSpan(), + tree_begin, tree_end, m->NumColumns(), m->NumRows(), use_shared, output_groups, missing); } bool InplacePredict(std::shared_ptr p_m, const gbm::GBTreeModel& model, float missing, @@ -1043,8 +1023,8 @@ class GPUPredictor : public xgboost::Predictor { std::vector const* tree_weights, bool approximate, int, unsigned) const override { - std::string not_implemented{"contribution is not implemented in GPU " - "predictor, use `cpu_predictor` instead."}; + std::string not_implemented{ + "contribution is not implemented in the GPU predictor, use CPU instead."}; if (approximate) { LOG(FATAL) << "Approximated " << not_implemented; } @@ -1199,7 +1179,6 @@ class GPUPredictor : public xgboost::Predictor { info.num_col_, max_shared_memory_bytes); bool use_shared = shared_memory_bytes != 0; bst_feature_t num_features = info.num_col_; - size_t entry_start = 0; if (p_fmat->PageExists()) { for (auto const& batch : p_fmat->GetBatches()) { @@ -1223,7 +1202,7 @@ class GPUPredictor : public xgboost::Predictor { d_model.categories.ConstDeviceSpan(), d_model.tree_beg_, d_model.tree_end_, num_features, num_rows, - entry_start, use_shared, nan("")); + use_shared, std::numeric_limits::quiet_NaN()); batch_offset += batch.Size(); } } else { @@ -1245,16 +1224,12 @@ class GPUPredictor : public xgboost::Predictor { d_model.categories.ConstDeviceSpan(), d_model.tree_beg_, d_model.tree_end_, num_features, num_rows, - entry_start, use_shared, nan("")); + use_shared, std::numeric_limits::quiet_NaN()); batch_offset += batch.Size(); } } } - void Configure(const std::vector>& cfg) override { - Predictor::Configure(cfg); - } - private: /*! \brief Reconfigure the device when GPU is changed. 
*/ static size_t ConfigureDevice(DeviceOrd device) { diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 4895fb63fb79..01de15fe8bc8 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -147,39 +147,54 @@ TEST(GPUPredictor, EllpackTraining) { TestTrainingPrediction(&ctx, kRows, kBins, p_full, p_ellpack); } -TEST(GPUPredictor, ExternalMemoryTest) { - auto lparam = MakeCUDACtx(0); +namespace { +template +void TestDecisionStumpExternalMemory(Context const* ctx, bst_feature_t n_features, + Create create_fn) { + std::int32_t n_classes = 3; + LearnerModelParam mparam{MakeMP(n_features, .5, n_classes, ctx->Device())}; + auto model = CreateTestModel(&mparam, ctx, n_classes); std::unique_ptr gpu_predictor = - std::unique_ptr(Predictor::Create("gpu_predictor", &lparam)); + std::unique_ptr(Predictor::Create("gpu_predictor", ctx)); gpu_predictor->Configure({}); - const int n_classes = 3; - Context ctx = MakeCUDACtx(0); - LearnerModelParam mparam{MakeMP(5, .5, n_classes, ctx.Device())}; - - gbm::GBTreeModel model = CreateTestModel(&mparam, &ctx, n_classes); - std::vector> dmats; - - dmats.push_back(CreateSparsePageDMatrix(400)); - dmats.push_back(CreateSparsePageDMatrix(800)); - dmats.push_back(CreateSparsePageDMatrix(8000)); - - for (const auto& dmat: dmats) { - dmat->Info().base_margin_ = decltype(dmat->Info().base_margin_){ - {dmat->Info().num_row_, static_cast(n_classes)}, DeviceOrd::CUDA(0)}; - dmat->Info().base_margin_.Data()->Fill(0.5); + for (auto p_fmat : {create_fn(400), create_fn(800), create_fn(2048)}) { + p_fmat->Info().base_margin_ = linalg::Constant(ctx, 0.5f, p_fmat->Info().num_row_, n_classes); PredictionCacheEntry out_predictions; - gpu_predictor->InitOutPredictions(dmat->Info(), &out_predictions.predictions, model); - gpu_predictor->PredictBatch(dmat.get(), &out_predictions, model, 0); - EXPECT_EQ(out_predictions.predictions.Size(), dmat->Info().num_row_ * n_classes); - const std::vector &host_vector = out_predictions.predictions.ConstHostVector(); - for (size_t i = 0; i < host_vector.size() / n_classes; i++) { - ASSERT_EQ(host_vector[i * n_classes], 2.0); - ASSERT_EQ(host_vector[i * n_classes + 1], 0.5); - ASSERT_EQ(host_vector[i * n_classes + 2], 0.5); + gpu_predictor->InitOutPredictions(p_fmat->Info(), &out_predictions.predictions, model); + gpu_predictor->PredictBatch(p_fmat.get(), &out_predictions, model, 0); + ASSERT_EQ(out_predictions.predictions.Size(), p_fmat->Info().num_row_ * n_classes); + auto const& h_predt = out_predictions.predictions.ConstHostVector(); + for (size_t i = 0; i < h_predt.size() / n_classes; i++) { + ASSERT_EQ(h_predt[i * n_classes], 2.0); + ASSERT_EQ(h_predt[i * n_classes + 1], 0.5); + ASSERT_EQ(h_predt[i * n_classes + 2], 0.5); } } } +} // namespace + +TEST(GPUPredictor, ExternalMemory) { + auto ctx = MakeCUDACtx(0); + + bst_bin_t max_bin = 128; + bst_feature_t n_features = 32; + + TestDecisionStumpExternalMemory(&ctx, n_features, [&](bst_idx_t n_samples) { + return RandomDataGenerator{n_samples, n_features, 0.0f} + .Batches(4) + .Device(ctx.Device()) + .Bins(max_bin) + .GenerateSparsePageDMatrix("temp", false); + }); + TestDecisionStumpExternalMemory(&ctx, n_features, [&](bst_idx_t n_samples) { + return RandomDataGenerator{n_samples, n_features, 0.0f} + .Batches(4) + .Device(ctx.Device()) + .Bins(max_bin) + .GenerateExtMemQuantileDMatrix("temp", false); + }); +} TEST(GPUPredictor, InplacePredictCupy) { auto ctx = MakeCUDACtx(0);
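
For reference, the sketch below condenses the call sequence of the new TestDecisionStumpExternalMemory test into a single function, showing how the GPU predictor is driven with an external-memory quantile DMatrix after this patch. It is illustrative only and not part of the patch: it assumes the test helpers used in the diff (MakeCUDACtx, MakeMP, CreateTestModel, RandomDataGenerator) are in scope with the same includes as tests/cpp/predictor/test_gpu_predictor.cu, and the wrapper function name is hypothetical.

// Illustrative sketch, not part of the patch. Assumes the XGBoost C++ test helpers
// (MakeCUDACtx, MakeMP, CreateTestModel, RandomDataGenerator) are available, as in
// tests/cpp/predictor/test_gpu_predictor.cu; the function name below is made up.
#include <memory>  // for unique_ptr

namespace xgboost {
void SketchExtMemQdmGpuPredict() {  // hypothetical wrapper
  auto ctx = MakeCUDACtx(0);

  // External-memory QuantileDMatrix: 2048 rows, 32 features, split into 4 ellpack
  // batches with 128-bin histograms, mirroring the new test above.
  auto p_fmat = RandomDataGenerator{2048, 32, 0.0f}
                    .Batches(4)
                    .Device(ctx.Device())
                    .Bins(128)
                    .GenerateExtMemQuantileDMatrix("temp", false);

  // A three-class decision-stump model, as built by CreateTestModel in the test.
  LearnerModelParam mparam{MakeMP(32, 0.5, 3, ctx.Device())};
  auto model = CreateTestModel(&mparam, &ctx, 3);

  // With this patch the GPU predictor iterates over EllpackPage batches on the
  // device instead of requiring a single resident SparsePage.
  std::unique_ptr<Predictor> predictor{Predictor::Create("gpu_predictor", &ctx)};
  predictor->Configure({});

  PredictionCacheEntry out_predictions;
  predictor->InitOutPredictions(p_fmat->Info(), &out_predictions.predictions, model);
  predictor->PredictBatch(p_fmat.get(), &out_predictions, model, 0);
  // out_predictions.predictions now holds num_row * 3 per-class margins.
}
}  // namespace xgboost

As the diff shows, predictions for each batch are written into the output vector at an offset advanced by the batch size times the model's output length, so the external-memory path produces the same layout as the in-core SparsePage path.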