xgboost
row_set.h
Go to the documentation of this file.
1 
7 #ifndef XGBOOST_COMMON_ROW_SET_H_
8 #define XGBOOST_COMMON_ROW_SET_H_
9 
10 #include <xgboost/data.h>
11 #include <algorithm>
12 #include <vector>
13 #include <utility>
14 #include <memory>
15 
16 namespace xgboost {
17 namespace common {
18 
21  public:
25  struct Elem {
26  const size_t* begin{nullptr};
27  const size_t* end{nullptr};
29  // id of node associated with this instance set; -1 means uninitialized
30  Elem()
31  = default;
32  Elem(const size_t* begin,
33  const size_t* end,
34  bst_node_t node_id = -1)
35  : begin(begin), end(end), node_id(node_id) {}
36 
37  inline size_t Size() const {
38  return end - begin;
39  }
40  };
/*!
 * \brief Describes how a row set is divided in two: one list of row
 *        indices for each side of the split.
 */
struct Split {
  /*! \brief row indices that go to the left side */
  std::vector<size_t> left;
  /*! \brief row indices that go to the right side */
  std::vector<size_t> right;
};
46 
47  inline std::vector<Elem>::const_iterator begin() const { // NOLINT
48  return elem_of_each_node_.begin();
49  }
50 
51  inline std::vector<Elem>::const_iterator end() const { // NOLINT
52  return elem_of_each_node_.end();
53  }
54 
56  inline const Elem& operator[](unsigned node_id) const {
57  const Elem& e = elem_of_each_node_[node_id];
58  return e;
59  }
60 
62  inline Elem& operator[](unsigned node_id) {
63  Elem& e = elem_of_each_node_[node_id];
64  return e;
65  }
66 
67  // clear up things
68  inline void Clear() {
69  elem_of_each_node_.clear();
70  }
71  // initialize node id 0->everything
72  inline void Init() {
73  CHECK_EQ(elem_of_each_node_.size(), 0U);
74 
75  if (row_indices_.empty()) { // edge case: empty instance set
76  constexpr size_t* kBegin = nullptr;
77  constexpr size_t* kEnd = nullptr;
78  static_assert(kEnd - kBegin == 0, "");
79  elem_of_each_node_.emplace_back(Elem(kBegin, kEnd, 0));
80  return;
81  }
82 
83  const size_t* begin = dmlc::BeginPtr(row_indices_);
84  const size_t* end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
85  elem_of_each_node_.emplace_back(Elem(begin, end, 0));
86  }
87 
88  std::vector<size_t>* Data() { return &row_indices_; }
89  // split rowset into two
90  inline void AddSplit(unsigned node_id, unsigned left_node_id, unsigned right_node_id,
91  size_t n_left, size_t n_right) {
92  const Elem e = elem_of_each_node_[node_id];
93 
94  size_t* all_begin{nullptr};
95  size_t* begin{nullptr};
96  if (e.begin == nullptr) {
97  CHECK_EQ(n_left, 0);
98  CHECK_EQ(n_right, 0);
99  } else {
100  all_begin = dmlc::BeginPtr(row_indices_);
101  begin = all_begin + (e.begin - all_begin);
102  }
103 
104  CHECK_EQ(n_left + n_right, e.Size());
105  CHECK_LE(begin + n_left, e.end);
106  CHECK_EQ(begin + n_left + n_right, e.end);
107 
108  if (left_node_id >= elem_of_each_node_.size()) {
109  elem_of_each_node_.resize(left_node_id + 1, Elem(nullptr, nullptr, -1));
110  }
111  if (right_node_id >= elem_of_each_node_.size()) {
112  elem_of_each_node_.resize(right_node_id + 1, Elem(nullptr, nullptr, -1));
113  }
114 
115  elem_of_each_node_[left_node_id] = Elem(begin, begin + n_left, left_node_id);
116  elem_of_each_node_[right_node_id] = Elem(begin + n_left, e.end, right_node_id);
117  elem_of_each_node_[node_id] = Elem(nullptr, nullptr, -1);
118  }
119 
120  private:
121  // stores the row indexes in the set
122  std::vector<size_t> row_indices_;
123  // vector: node_id -> elements
124  std::vector<Elem> elem_of_each_node_;
125 };
126 
127 } // namespace common
128 } // namespace xgboost
129 
130 #endif // XGBOOST_COMMON_ROW_SET_H_
xgboost::common::RowSetCollection::Split::right
std::vector< size_t > right
Definition: row_set.h:44
xgboost::common::RowSetCollection::end
std::vector< Elem >::const_iterator end() const
Definition: row_set.h:51
xgboost::common::RowSetCollection::operator[]
const Elem & operator[](unsigned node_id) const
return corresponding element set given the node_id
Definition: row_set.h:56
xgboost::common::RowSetCollection::Elem::Size
size_t Size() const
Definition: row_set.h:37
xgboost::common::RowSetCollection::Split
Definition: row_set.h:42
xgboost::common::RowSetCollection::Elem::Elem
Elem()=default
xgboost::common::RowSetCollection
collection of rowset
Definition: row_set.h:20
xgboost::common::RowSetCollection::begin
std::vector< Elem >::const_iterator begin() const
Definition: row_set.h:47
xgboost::common::RowSetCollection::Split::left
std::vector< size_t > left
Definition: row_set.h:43
xgboost::bst_node_t
int32_t bst_node_t
Type for tree node index.
Definition: base.h:132
xgboost::common::RowSetCollection::operator[]
Elem & operator[](unsigned node_id)
return corresponding element set given the node_id
Definition: row_set.h:62
xgboost::common::RowSetCollection::Elem::node_id
bst_node_t node_id
Definition: row_set.h:28
xgboost::common::RowSetCollection::Data
std::vector< size_t > * Data()
Definition: row_set.h:88
xgboost::common::RowSetCollection::AddSplit
void AddSplit(unsigned node_id, unsigned left_node_id, unsigned right_node_id, size_t n_left, size_t n_right)
Definition: row_set.h:90
xgboost::common::RowSetCollection::Elem::end
const size_t * end
Definition: row_set.h:27
data.h
The input data structure of xgboost.
xgboost::common::RowSetCollection::Clear
void Clear()
Definition: row_set.h:68
xgboost::common::RowSetCollection::Elem::begin
const size_t * begin
Definition: row_set.h:26
xgboost::common::RowSetCollection::Elem::Elem
Elem(const size_t *begin, const size_t *end, bst_node_t node_id=-1)
Definition: row_set.h:32
xgboost::common::RowSetCollection::Init
void Init()
Definition: row_set.h:72
xgboost::common::RowSetCollection::Elem
data structure to store an instance set, a subset of rows (instances) associated with a particular node
Definition: row_set.h:25
xgboost
namespace of xgboost
Definition: base.h:110