xgboost
row_set.h
Go to the documentation of this file.
1 
7 #ifndef XGBOOST_COMMON_ROW_SET_H_
8 #define XGBOOST_COMMON_ROW_SET_H_
9 
10 #include <xgboost/data.h>
11 #include <algorithm>
12 #include <vector>
13 
14 namespace xgboost {
15 namespace common {
16 
19  public:
/*!
 * \brief A contiguous run of row indices [begin, end) associated with one node.
 *
 * A default-constructed Elem has null pointers and node_id == -1.
 */
struct Elem {
  const size_t* begin{nullptr};
  const size_t* end{nullptr};
  // Id of the node associated with this instance set; -1 means uninitialized.
  int node_id{-1};

  Elem() = default;
  Elem(const size_t* begin, const size_t* end, int node_id)
      : begin{begin}, end{end}, node_id{node_id} {}

  /*! \brief Number of row indices in [begin, end). */
  inline size_t Size() const {
    const auto span = end - begin;
    return static_cast<size_t>(span);
  }
};
/*!
 * \brief Specifies how to split a row set into two: the indices that go to
 *        the left node and the indices that go to the right node.
 */
struct Split {
  std::vector<size_t> left;   // row indices assigned to the left node
  std::vector<size_t> right;  // row indices assigned to the right node
};
44 
45  inline std::vector<Elem>::const_iterator begin() const { // NOLINT
46  return elem_of_each_node_.begin();
47  }
48 
49  inline std::vector<Elem>::const_iterator end() const { // NOLINT
50  return elem_of_each_node_.end();
51  }
52 
54  inline const Elem& operator[](unsigned node_id) const {
55  const Elem& e = elem_of_each_node_[node_id];
56  CHECK(e.begin != nullptr)
57  << "access element that is not in the set";
58  return e;
59  }
60  // clear up things
61  inline void Clear() {
62  elem_of_each_node_.clear();
63  }
64  // initialize node id 0->everything
65  inline void Init() {
66  CHECK_EQ(elem_of_each_node_.size(), 0U);
67 
68  if (row_indices_.empty()) { // edge case: empty instance set
69  // assign arbitrary address here, to bypass nullptr check
70  // (nullptr usually indicates a nonexistent rowset, but we want to
71  // indicate a valid rowset that happens to have zero length and occupies
72  // the whole instance set)
73  // this is okay, as BuildHist will compute (end-begin) as the set size
74  const size_t* begin = reinterpret_cast<size_t*>(20);
75  const size_t* end = begin;
76  elem_of_each_node_.emplace_back(Elem(begin, end, 0));
77  return;
78  }
79 
80  const size_t* begin = dmlc::BeginPtr(row_indices_);
81  const size_t* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
82  elem_of_each_node_.emplace_back(Elem(begin, end, 0));
83  }
84  // split rowset into two
85  inline void AddSplit(unsigned node_id,
86  const std::vector<Split>& row_split_tloc,
87  unsigned left_node_id,
88  unsigned right_node_id) {
89  const Elem e = elem_of_each_node_[node_id];
90  const auto nthread = static_cast<bst_omp_uint>(row_split_tloc.size());
91  CHECK(e.begin != nullptr);
92  size_t* all_begin = dmlc::BeginPtr(row_indices_);
93  size_t* begin = all_begin + (e.begin - all_begin);
94 
95  size_t* it = begin;
96  for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
97  std::copy(row_split_tloc[tid].left.begin(), row_split_tloc[tid].left.end(), it);
98  it += row_split_tloc[tid].left.size();
99  }
100  size_t* split_pt = it;
101  for (bst_omp_uint tid = 0; tid < nthread; ++tid) {
102  std::copy(row_split_tloc[tid].right.begin(), row_split_tloc[tid].right.end(), it);
103  it += row_split_tloc[tid].right.size();
104  }
105 
106  if (left_node_id >= elem_of_each_node_.size()) {
107  elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1));
108  }
109  if (right_node_id >= elem_of_each_node_.size()) {
110  elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1));
111  }
112 
113  elem_of_each_node_[left_node_id] = Elem(begin, split_pt, left_node_id);
114  elem_of_each_node_[right_node_id] = Elem(split_pt, e.end, right_node_id);
115  elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1);
116  }
117 
118  // stores the row indices in the set
119  std::vector<size_t> row_indices_;
120 
121  private:
122  // vector: node_id -> elements
123  std::vector<Elem> elem_of_each_node_;
124 };
125 
126 } // namespace common
127 } // namespace xgboost
128 
129 #endif // XGBOOST_COMMON_ROW_SET_H_
Collection of row sets.
Definition: row_set.h:18
std::vector< Elem >::const_iterator begin() const
Definition: row_set.h:45
Elem(const size_t *begin, const size_t *end, int node_id)
Definition: row_set.h:30
The input data structure of xgboost.
dmlc::omp_uint bst_omp_uint
Defines the unsigned integer type used for OpenMP loops.
Definition: base.h:246
void Init()
Definition: row_set.h:65
std::vector< Elem >::const_iterator end() const
Definition: row_set.h:49
void AddSplit(unsigned node_id, const std::vector< Split > &row_split_tloc, unsigned left_node_id, unsigned right_node_id)
Definition: row_set.h:85
int node_id
Definition: row_set.h:26
namespace of xgboost
Definition: base.h:102
Data structure to store an instance set: a subset of rows (instances) associated with a particular tree node.
Definition: row_set.h:23
const size_t * begin
Definition: row_set.h:24
const Elem & operator[](unsigned node_id) const
Returns the element set corresponding to the given node_id.
Definition: row_set.h:54
const size_t * end
Definition: row_set.h:25
std::vector< size_t > row_indices_
Definition: row_set.h:119
std::vector< size_t > left
Definition: row_set.h:41
std::vector< size_t > right
Definition: row_set.h:42
void Clear()
Definition: row_set.h:61
size_t Size() const
Definition: row_set.h:35