Go to the documentation of this file.
7 #ifndef XGBOOST_COMMON_ROW_SET_H_
8 #define XGBOOST_COMMON_ROW_SET_H_
27 const size_t*
end{
nullptr};
37 inline size_t Size()
const {
47 inline std::vector<Elem>::const_iterator
begin()
const {
48 return elem_of_each_node_.begin();
51 inline std::vector<Elem>::const_iterator
end()
const {
52 return elem_of_each_node_.end();
57 const Elem& e = elem_of_each_node_[node_id];
58 CHECK(e.
begin !=
nullptr)
59 <<
"access element that is not in the set";
65 Elem& e = elem_of_each_node_[node_id];
71 elem_of_each_node_.clear();
75 CHECK_EQ(elem_of_each_node_.size(), 0U);
77 if (row_indices_.empty()) {
83 const size_t*
begin =
reinterpret_cast<size_t*
>(20);
89 const size_t*
begin = dmlc::BeginPtr(row_indices_);
90 const size_t*
end = dmlc::BeginPtr(row_indices_) + row_indices_.size();
94 std::vector<size_t>*
Data() {
return &row_indices_; }
97 unsigned left_node_id,
98 unsigned right_node_id,
101 const Elem e = elem_of_each_node_[node_id];
102 CHECK(e.
begin !=
nullptr);
103 size_t* all_begin = dmlc::BeginPtr(row_indices_);
104 size_t*
begin = all_begin + (e.
begin - all_begin);
106 CHECK_EQ(n_left + n_right, e.
Size());
108 CHECK_EQ(
begin + n_left + n_right, e.
end);
110 if (left_node_id >= elem_of_each_node_.size()) {
111 elem_of_each_node_.resize(left_node_id + 1,
Elem(
nullptr,
nullptr, -1));
113 if (right_node_id >= elem_of_each_node_.size()) {
114 elem_of_each_node_.resize(right_node_id + 1,
Elem(
nullptr,
nullptr, -1));
117 elem_of_each_node_[left_node_id] =
Elem(
begin,
begin + n_left, left_node_id);
118 elem_of_each_node_[right_node_id] =
Elem(
begin + n_left, e.
end, right_node_id);
119 elem_of_each_node_[node_id] =
Elem(
nullptr,
nullptr, -1);
124 std::vector<size_t> row_indices_;
126 std::vector<Elem> elem_of_each_node_;
135 template<
size_t BlockSize>
138 template<
typename Func>
139 void Init(
const size_t n_tasks,
size_t n_nodes, Func funcNTaks) {
144 for (
size_t i = 1; i < n_nodes+1; ++i) {
158 CHECK_NE(local_block_ptr, (
BlockInfo*)
nullptr);
164 const size_t task_idx =
GetTaskIdx(nid, begin);
165 return {
mem_blocks_.at(task_idx)->Left(), end - begin };
169 const size_t task_idx =
GetTaskIdx(nid, begin);
170 return {
mem_blocks_.at(task_idx)->Right(), end - begin };
213 size_t* left_result = rows_indexes +
mem_blocks_[task_idx]->n_offset_left;
214 size_t* right_result = rows_indexes +
mem_blocks_[task_idx]->n_offset_right;
216 const size_t* left =
mem_blocks_[task_idx]->Left();
217 const size_t* right =
mem_blocks_[task_idx]->Right();
219 std::copy_n(left,
mem_blocks_[task_idx]->n_left, left_result);
220 std::copy_n(right,
mem_blocks_[task_idx]->n_right, right_result);
236 return &left_data_[0];
240 return &right_data_[0];
243 size_t left_data_[BlockSize];
244 size_t right_data_[BlockSize];
256 #endif // XGBOOST_COMMON_ROW_SET_H_
void MergeToArray(int nid, size_t begin, size_t *rows_indexes)
Definition: row_set.h:210
std::vector< size_t > right
Definition: row_set.h:44
std::vector< Elem >::const_iterator end() const
Definition: row_set.h:51
size_t * Right()
Definition: row_set.h:239
Definition: row_set.h:136
common::Span< size_t > GetRightBuffer(int nid, size_t begin, size_t end)
Definition: row_set.h:168
const Elem & operator[](unsigned node_id) const
Return the element set corresponding to the given node_id.
Definition: row_set.h:56
size_t Size() const
Definition: row_set.h:37
size_t max_n_tasks_
Definition: row_set.h:249
void Init(const size_t n_tasks, size_t n_nodes, Func funcNTaks)
Definition: row_set.h:139
std::vector< std::shared_ptr< BlockInfo > > mem_blocks_
Definition: row_set.h:248
size_t n_offset_right
Definition: row_set.h:233
collection of row sets
Definition: row_set.h:20
size_t n_offset_left
Definition: row_set.h:232
std::vector< Elem >::const_iterator begin() const
Definition: row_set.h:47
size_t n_left
Definition: row_set.h:229
std::vector< size_t > left
Definition: row_set.h:43
std::vector< std::pair< size_t, size_t > > left_right_nodes_sizes_
Definition: row_set.h:246
int32_t bst_node_t
Type for tree node index.
Definition: base.h:132
Elem & operator[](unsigned node_id)
Return the element set corresponding to the given node_id.
Definition: row_set.h:64
bst_node_t node_id
Definition: row_set.h:28
common::Span< size_t > GetLeftBuffer(int nid, size_t begin, size_t end)
Definition: row_set.h:163
Definition: row_set.h:228
std::vector< size_t > * Data()
Definition: row_set.h:94
void AddSplit(unsigned node_id, unsigned left_node_id, unsigned right_node_id, size_t n_left, size_t n_right)
Definition: row_set.h:96
size_t n_right
Definition: row_set.h:230
size_t * Left()
Definition: row_set.h:235
const size_t * end
Definition: row_set.h:27
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:137
size_t GetTaskIdx(int nid, size_t begin)
Definition: row_set.h:223
The input data structure of xgboost.
void SetNRightElems(int nid, size_t begin, size_t end, size_t n_right)
Definition: row_set.h:178
void AllocateForTask(size_t id)
Definition: row_set.h:155
size_t GetNLeftElems(int nid) const
Definition: row_set.h:184
void Clear()
Definition: row_set.h:70
std::vector< size_t > blocks_offsets_
Definition: row_set.h:247
const size_t * begin
Definition: row_set.h:26
void CalculateRowOffsets()
Definition: row_set.h:194
Elem(const size_t *begin, const size_t *end, bst_node_t node_id=-1)
Definition: row_set.h:32
auto get(U &json) -> decltype(detail::GetImpl(*Cast< T >(&json.GetValue())))&
Get Json value.
Definition: json.h:546
void Init()
Definition: row_set.h:74
void SetNLeftElems(int nid, size_t begin, size_t end, size_t n_left)
Definition: row_set.h:173
data structure to store an instance set, a subset of rows (instances) associated with a particular no...
Definition: row_set.h:25
size_t GetNRightElems(int nid) const
Definition: row_set.h:188
namespace of xgboost
Definition: base.h:110