Skip to content

Commit

Permalink
Small cleanup for rowset collection. (dmlc#10401)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Jun 19, 2024
1 parent e5f1720 commit 2b400b1
Showing 1 changed file with 55 additions and 60 deletions.
115 changes: 55 additions & 60 deletions src/common/row_set.h
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
/*!
* Copyright 2017-2022 by Contributors
/**
* Copyright 2017-2024, XGBoost Contributors
* \file row_set.h
* \brief Quick Utility to compute subset of rows
* \author Philip Cho, Tianqi Chen
*/
#ifndef XGBOOST_COMMON_ROW_SET_H_
#define XGBOOST_COMMON_ROW_SET_H_

#include <xgboost/data.h>
#include <algorithm>
#include <vector>
#include <utility>
#include <memory>
#include <cstddef> // for size_t
#include <iterator> // for distance
#include <vector> // for vector

namespace xgboost {
namespace common {
/*! \brief collection of rowset */
#include "xgboost/base.h" // for bst_node_t
#include "xgboost/logging.h" // for CHECK

namespace xgboost::common {
/**
* @brief Collection of rows for each tree node.
*/
class RowSetCollection {
public:
RowSetCollection() = default;
Expand All @@ -24,110 +26,103 @@ class RowSetCollection {
RowSetCollection& operator=(RowSetCollection const&) = delete;
RowSetCollection& operator=(RowSetCollection&&) = default;

/*! \brief data structure to store an instance set, a subset of
* rows (instances) associated with a particular node in a decision
* tree. */
/**
* @brief data structure to store an instance set, a subset of rows (instances)
* associated with a particular node in a decision tree.
*/
struct Elem {
const size_t* begin{nullptr};
const size_t* end{nullptr};
std::size_t const* begin{nullptr};
std::size_t const* end{nullptr};
bst_node_t node_id{-1};
// id of node associated with this instance set; -1 means uninitialized
Elem()
= default;
Elem(const size_t* begin,
const size_t* end,
bst_node_t node_id = -1)
// id of node associated with this instance set; -1 means uninitialized
Elem() = default;
Elem(std::size_t const* begin, std::size_t const* end, bst_node_t node_id = -1)
: begin(begin), end(end), node_id(node_id) {}

inline size_t Size() const {
return end - begin;
}
std::size_t Size() const { return end - begin; }
};

std::vector<Elem>::const_iterator begin() const { // NOLINT
return elem_of_each_node_.begin();
[[nodiscard]] std::vector<Elem>::const_iterator begin() const { // NOLINT
return elem_of_each_node_.cbegin();
}

std::vector<Elem>::const_iterator end() const { // NOLINT
return elem_of_each_node_.end();
[[nodiscard]] std::vector<Elem>::const_iterator end() const { // NOLINT
return elem_of_each_node_.cend();
}

size_t Size() const { return std::distance(begin(), end()); }
[[nodiscard]] std::size_t Size() const { return std::distance(begin(), end()); }

/*! \brief return corresponding element set given the node_id */
inline const Elem& operator[](unsigned node_id) const {
const Elem& e = elem_of_each_node_[node_id];
/** @brief return corresponding element set given the node_id */
[[nodiscard]] Elem const& operator[](bst_node_t node_id) const {
Elem const& e = elem_of_each_node_[node_id];
return e;
}

/*! \brief return corresponding element set given the node_id */
inline Elem& operator[](unsigned node_id) {
/** @brief return corresponding element set given the node_id */
[[nodiscard]] Elem& operator[](bst_node_t node_id) {
Elem& e = elem_of_each_node_[node_id];
return e;
}

// clear up things
inline void Clear() {
void Clear() {
elem_of_each_node_.clear();
}
// initialize node id 0->everything
inline void Init() {
CHECK_EQ(elem_of_each_node_.size(), 0U);
void Init() {
CHECK(elem_of_each_node_.empty());

if (row_indices_.empty()) { // edge case: empty instance set
constexpr size_t* kBegin = nullptr;
constexpr size_t* kEnd = nullptr;
constexpr std::size_t* kBegin = nullptr;
constexpr std::size_t* kEnd = nullptr;
static_assert(kEnd - kBegin == 0);
elem_of_each_node_.emplace_back(kBegin, kEnd, 0);
return;
}

const size_t* begin = dmlc::BeginPtr(row_indices_);
const size_t* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
const std::size_t* begin = dmlc::BeginPtr(row_indices_);
const std::size_t* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
elem_of_each_node_.emplace_back(begin, end, 0);
}

std::vector<size_t>* Data() { return &row_indices_; }
std::vector<size_t> const* Data() const { return &row_indices_; }
[[nodiscard]] std::vector<std::size_t>* Data() { return &row_indices_; }
[[nodiscard]] std::vector<std::size_t> const* Data() const { return &row_indices_; }

// split rowset into two
inline void AddSplit(unsigned node_id, unsigned left_node_id, unsigned right_node_id,
size_t n_left, size_t n_right) {
void AddSplit(bst_node_t node_id, bst_node_t left_node_id, bst_node_t right_node_id,
bst_idx_t n_left, bst_idx_t n_right) {
const Elem e = elem_of_each_node_[node_id];

size_t* all_begin{nullptr};
size_t* begin{nullptr};
std::size_t* all_begin{nullptr};
std::size_t* begin{nullptr};
if (e.begin == nullptr) {
CHECK_EQ(n_left, 0);
CHECK_EQ(n_right, 0);
} else {
all_begin = dmlc::BeginPtr(row_indices_);
all_begin = row_indices_.data();
begin = all_begin + (e.begin - all_begin);
}

CHECK_EQ(n_left + n_right, e.Size());
CHECK_LE(begin + n_left, e.end);
CHECK_EQ(begin + n_left + n_right, e.end);

if (left_node_id >= elem_of_each_node_.size()) {
elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1));
if (left_node_id >= static_cast<bst_node_t>(elem_of_each_node_.size())) {
elem_of_each_node_.resize(left_node_id + 1, Elem{nullptr, nullptr, -1});
}
if (right_node_id >= elem_of_each_node_.size()) {
elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1));
if (right_node_id >= static_cast<bst_node_t>(elem_of_each_node_.size())) {
elem_of_each_node_.resize(right_node_id + 1, Elem{nullptr, nullptr, -1});
}

elem_of_each_node_[left_node_id] = Elem(begin, begin + n_left, left_node_id);
elem_of_each_node_[right_node_id] = Elem(begin + n_left, e.end, right_node_id);
elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1);
elem_of_each_node_[left_node_id] = Elem{begin, begin + n_left, left_node_id};
elem_of_each_node_[right_node_id] = Elem{begin + n_left, e.end, right_node_id};
elem_of_each_node_[node_id] = Elem{nullptr, nullptr, -1};
}

private:
// stores the row indexes in the set
std::vector<size_t> row_indices_;
std::vector<std::size_t> row_indices_;
// vector: node_id -> elements
std::vector<Elem> elem_of_each_node_;
};
} // namespace common
} // namespace xgboost
} // namespace xgboost::common

#endif // XGBOOST_COMMON_ROW_SET_H_

0 comments on commit 2b400b1

Please sign in to comment.