xgboost
tree_model.h
Go to the documentation of this file.
1 
7 #ifndef XGBOOST_TREE_MODEL_H_
8 #define XGBOOST_TREE_MODEL_H_
9 
10 #include <xgboost/base.h>
11 #include <xgboost/data.h>
12 #include <xgboost/feature_map.h>
13 #include <xgboost/linalg.h> // for VectorView
14 #include <xgboost/logging.h>
15 #include <xgboost/model.h>
16 #include <xgboost/multi_target_tree_model.h> // for MultiTargetTree
17 
18 #include <algorithm>
19 #include <cstring>
20 #include <limits>
21 #include <memory> // for make_unique
22 #include <stack>
23 #include <string>
24 #include <vector>
25 
26 namespace xgboost {
27 class Json;
28 
30 struct TreeParam {
39 
40  bool operator==(const TreeParam& b) const {
41  return num_nodes == b.num_nodes && num_deleted == b.num_deleted &&
43  }
44 
45  void FromJson(Json const& in);
46  void ToJson(Json* p_out) const;
47 };
48 
50 struct RTreeNodeStat {
58  int leaf_child_cnt {0};
59 
60  RTreeNodeStat() = default;
61  RTreeNodeStat(float loss_chg, float sum_hess, float weight) :
63  bool operator==(const RTreeNodeStat& b) const {
64  return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
66  }
67 };
68 
72 template <typename T>
74  std::unique_ptr<T> ptr_{nullptr};
75 
76  public:
77  CopyUniquePtr() = default;
79  ptr_.reset(nullptr);
80  if (that.ptr_) {
81  ptr_ = std::make_unique<T>(*that);
82  }
83  }
84  T* get() const noexcept { return ptr_.get(); } // NOLINT
85 
86  T& operator*() { return *ptr_; }
87  T* operator->() noexcept { return this->get(); }
88 
89  T const& operator*() const { return *ptr_; }
90  T const* operator->() const noexcept { return this->get(); }
91 
92  explicit operator bool() const { return static_cast<bool>(ptr_); }
93  bool operator!() const { return !ptr_; }
94  void reset(T* ptr) { ptr_.reset(ptr); } // NOLINT
95 };
96 
102 class RegTree : public Model {
103  public:
106  static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
107  static constexpr bst_node_t kRoot{0};
108 
110  class Node {
111  public:
113  // assert compact alignment
114  static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info), "Node: 64 bit align");
115  }
116  Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond,
117  bool default_left)
118  : parent_{parent}, cleft_{cleft}, cright_{cright} {
119  this->SetParent(parent_);
120  this->SetSplit(split_ind, split_cond, default_left);
121  }
122 
124  [[nodiscard]] XGBOOST_DEVICE int LeftChild() const { return this->cleft_; }
126  [[nodiscard]] XGBOOST_DEVICE int RightChild() const { return this->cright_; }
128  [[nodiscard]] XGBOOST_DEVICE int DefaultChild() const {
129  return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
130  }
132  [[nodiscard]] XGBOOST_DEVICE bst_feature_t SplitIndex() const {
133  static_assert(!std::is_signed_v<bst_feature_t>);
134  return sindex_ & ((1U << 31) - 1U);
135  }
137  [[nodiscard]] XGBOOST_DEVICE bool DefaultLeft() const { return (sindex_ >> 31) != 0; }
139  [[nodiscard]] XGBOOST_DEVICE bool IsLeaf() const { return cleft_ == kInvalidNodeId; }
141  [[nodiscard]] XGBOOST_DEVICE float LeafValue() const { return (this->info_).leaf_value; }
143  [[nodiscard]] XGBOOST_DEVICE SplitCondT SplitCond() const { return (this->info_).split_cond; }
145  [[nodiscard]] XGBOOST_DEVICE int Parent() const { return parent_ & ((1U << 31) - 1); }
147  [[nodiscard]] XGBOOST_DEVICE bool IsLeftChild() const { return (parent_ & (1U << 31)) != 0; }
149  [[nodiscard]] XGBOOST_DEVICE bool IsDeleted() const { return sindex_ == kDeletedNodeMarker; }
151  [[nodiscard]] XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
156  XGBOOST_DEVICE void SetLeftChild(int nid) {
157  this->cleft_ = nid;
158  }
164  this->cright_ = nid;
165  }
172  XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond,
173  bool default_left = false) {
174  if (default_left) split_index |= (1U << 31);
175  this->sindex_ = split_index;
176  (this->info_).split_cond = split_cond;
177  }
184  XGBOOST_DEVICE void SetLeaf(bst_float value, int right = kInvalidNodeId) {
185  (this->info_).leaf_value = value;
186  this->cleft_ = kInvalidNodeId;
187  this->cright_ = right;
188  }
191  this->sindex_ = kDeletedNodeMarker;
192  }
195  this->sindex_ = 0;
196  }
197  // set parent
198  XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child = true) {
199  if (is_left_child) pidx |= (1U << 31);
200  this->parent_ = pidx;
201  }
202  bool operator==(const Node& b) const {
203  return parent_ == b.parent_ && cleft_ == b.cleft_ &&
204  cright_ == b.cright_ && sindex_ == b.sindex_ &&
205  info_.leaf_value == b.info_.leaf_value;
206  }
207 
208  private:
213  union Info{
214  bst_float leaf_value;
215  SplitCondT split_cond;
216  };
217  // pointer to parent, highest bit is used to
218  // indicate whether it's a left child or not
219  int32_t parent_{kInvalidNodeId};
220  // pointer to left, right
221  int32_t cleft_{kInvalidNodeId}, cright_{kInvalidNodeId};
222  // split feature index, left split or right split depends on the highest bit
223  uint32_t sindex_{0};
224  // extra info
225  Info info_;
226  };
227 
233  void ChangeToLeaf(int rid, bst_float value) {
234  CHECK(nodes_[nodes_[rid].LeftChild() ].IsLeaf());
235  CHECK(nodes_[nodes_[rid].RightChild()].IsLeaf());
236  this->DeleteNode(nodes_[rid].LeftChild());
237  this->DeleteNode(nodes_[rid].RightChild());
238  nodes_[rid].SetLeaf(value);
239  }
245  void CollapseToLeaf(int rid, bst_float value) {
246  if (nodes_[rid].IsLeaf()) return;
247  if (!nodes_[nodes_[rid].LeftChild() ].IsLeaf()) {
248  CollapseToLeaf(nodes_[rid].LeftChild(), 0.0f);
249  }
250  if (!nodes_[nodes_[rid].RightChild() ].IsLeaf()) {
251  CollapseToLeaf(nodes_[rid].RightChild(), 0.0f);
252  }
253  this->ChangeToLeaf(rid, value);
254  }
255 
257  nodes_.resize(param_.num_nodes);
258  stats_.resize(param_.num_nodes);
259  split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
260  split_categories_segments_.resize(param_.num_nodes);
261  for (int i = 0; i < param_.num_nodes; i++) {
262  nodes_[i].SetLeaf(0.0f);
263  nodes_[i].SetParent(kInvalidNodeId);
264  }
265  }
269  explicit RegTree(bst_target_t n_targets, bst_feature_t n_features) : RegTree{} {
270  param_.num_feature = n_features;
271  param_.size_leaf_vector = n_targets;
272  if (n_targets > 1) {
273  this->p_mt_tree_.reset(new MultiTargetTree{&param_});
274  }
275  }
276 
278  Node& operator[](int nid) {
279  return nodes_[nid];
280  }
282  const Node& operator[](int nid) const {
283  return nodes_[nid];
284  }
285 
287  [[nodiscard]] const std::vector<Node>& GetNodes() const { return nodes_; }
288 
290  [[nodiscard]] const std::vector<RTreeNodeStat>& GetStats() const { return stats_; }
291 
293  RTreeNodeStat& Stat(int nid) {
294  return stats_[nid];
295  }
297  [[nodiscard]] const RTreeNodeStat& Stat(int nid) const {
298  return stats_[nid];
299  }
300 
301  void LoadModel(Json const& in) override;
302  void SaveModel(Json* out) const override;
303 
304  bool operator==(const RegTree& b) const {
305  return nodes_ == b.nodes_ && stats_ == b.stats_ &&
306  deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
307  }
308  /* \brief Iterate through all nodes in this tree.
309  *
310  * \param Function that accepts a node index, and returns false when iteration should
311  * stop, otherwise returns true.
312  */
313  template <typename Func> void WalkTree(Func func) const {
314  std::stack<bst_node_t> nodes;
315  nodes.push(kRoot);
316  auto &self = *this;
317  while (!nodes.empty()) {
318  auto nidx = nodes.top();
319  nodes.pop();
320  if (!func(nidx)) {
321  return;
322  }
323  auto left = self.LeftChild(nidx);
324  auto right = self.RightChild(nidx);
325  if (left != RegTree::kInvalidNodeId) {
326  nodes.push(left);
327  }
328  if (right != RegTree::kInvalidNodeId) {
329  nodes.push(right);
330  }
331  }
332  }
339  [[nodiscard]] bool Equal(const RegTree& b) const;
340 
358  void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value,
359  bool default_left, bst_float base_weight,
360  bst_float left_leaf_weight, bst_float right_leaf_weight,
361  bst_float loss_change, float sum_hess, float left_sum,
362  float right_sum,
363  bst_node_t leaf_right_child = kInvalidNodeId);
367  void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left,
370  linalg::VectorView<float const> right_weight);
371 
388  common::Span<const uint32_t> split_cat, bool default_left,
389  bst_float base_weight, bst_float left_leaf_weight,
390  bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
391  float left_sum, float right_sum);
395  [[nodiscard]] bool HasCategoricalSplit() const { return !split_categories_.empty(); }
399  [[nodiscard]] bool IsMultiTarget() const { return static_cast<bool>(p_mt_tree_); }
403  [[nodiscard]] bst_target_t NumTargets() const { return param_.size_leaf_vector; }
407  [[nodiscard]] auto GetMultiTargetTree() const {
408  CHECK(IsMultiTarget());
409  return p_mt_tree_.get();
410  }
414  [[nodiscard]] bst_feature_t NumFeatures() const noexcept { return param_.num_feature; }
418  [[nodiscard]] bst_node_t NumNodes() const noexcept { return param_.num_nodes; }
422  [[nodiscard]] bst_node_t NumValidNodes() const noexcept {
423  return param_.num_nodes - param_.num_deleted;
424  }
428  [[nodiscard]] bst_node_t NumExtraNodes() const noexcept {
429  return param_.num_nodes - 1 - param_.num_deleted;
430  }
431  /* \brief Count number of leaves in tree. */
432  [[nodiscard]] bst_node_t GetNumLeaves() const;
433  [[nodiscard]] bst_node_t GetNumSplitNodes() const;
434 
439  [[nodiscard]] std::int32_t GetDepth(bst_node_t nid) const {
440  if (IsMultiTarget()) {
441  return this->p_mt_tree_->Depth(nid);
442  }
443  int depth = 0;
444  while (!nodes_[nid].IsRoot()) {
445  ++depth;
446  nid = nodes_[nid].Parent();
447  }
448  return depth;
449  }
454  CHECK(IsMultiTarget());
455  return this->p_mt_tree_->SetLeaf(nidx, weight);
456  }
457 
462  [[nodiscard]] int MaxDepth(int nid) const {
463  if (nodes_[nid].IsLeaf()) return 0;
464  return std::max(MaxDepth(nodes_[nid].LeftChild()) + 1, MaxDepth(nodes_[nid].RightChild()) + 1);
465  }
466 
470  int MaxDepth() { return MaxDepth(0); }
471 
476  struct FVec {
481  void Init(size_t size);
486  void Fill(SparsePage::Inst const& inst);
487 
492  void Drop();
497  [[nodiscard]] size_t Size() const;
503  [[nodiscard]] bst_float GetFvalue(size_t i) const;
509  [[nodiscard]] bool IsMissing(size_t i) const;
510  [[nodiscard]] bool HasMissing() const;
511  void HasMissing(bool has_missing) { this->has_missing_ = has_missing; }
512 
513  [[nodiscard]] common::Span<float> Data() { return data_; }
514 
515  private:
521  std::vector<float> data_;
522  bool has_missing_;
523  };
524 
532  [[nodiscard]] std::string DumpModel(const FeatureMap& fmap, bool with_stats,
533  std::string format) const;
539  [[nodiscard]] FeatureType NodeSplitType(bst_node_t nidx) const { return split_types_.at(nidx); }
543  [[nodiscard]] std::vector<FeatureType> const& GetSplitTypes() const {
544  return split_types_;
545  }
547  return split_categories_;
548  }
553  auto node_ptr = GetCategoriesMatrix().node_ptr;
554  auto categories = GetCategoriesMatrix().categories;
555  auto segment = node_ptr[nidx];
556  auto node_cats = categories.subspan(segment.beg, segment.size);
557  return node_cats;
558  }
559  [[nodiscard]] auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
560 
569  struct Segment {
570  std::size_t beg{0};
571  std::size_t size{0};
572  };
576  };
577 
581  view.categories = this->GetSplitCategories();
582  view.node_ptr = common::Span<CategoricalSplitMatrix::Segment const>(split_categories_segments_);
583  return view;
584  }
585 
586  [[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const {
587  if (IsMultiTarget()) {
588  return this->p_mt_tree_->SplitIndex(nidx);
589  }
590  return (*this)[nidx].SplitIndex();
591  }
592  [[nodiscard]] float SplitCond(bst_node_t nidx) const {
593  if (IsMultiTarget()) {
594  return this->p_mt_tree_->SplitCond(nidx);
595  }
596  return (*this)[nidx].SplitCond();
597  }
598  [[nodiscard]] bool DefaultLeft(bst_node_t nidx) const {
599  if (IsMultiTarget()) {
600  return this->p_mt_tree_->DefaultLeft(nidx);
601  }
602  return (*this)[nidx].DefaultLeft();
603  }
604  [[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
605  return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
606  }
607  [[nodiscard]] bool IsRoot(bst_node_t nidx) const {
608  if (IsMultiTarget()) {
609  return nidx == kRoot;
610  }
611  return (*this)[nidx].IsRoot();
612  }
613  [[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
614  if (IsMultiTarget()) {
615  return this->p_mt_tree_->IsLeaf(nidx);
616  }
617  return (*this)[nidx].IsLeaf();
618  }
619  [[nodiscard]] bst_node_t Parent(bst_node_t nidx) const {
620  if (IsMultiTarget()) {
621  return this->p_mt_tree_->Parent(nidx);
622  }
623  return (*this)[nidx].Parent();
624  }
625  [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const {
626  if (IsMultiTarget()) {
627  return this->p_mt_tree_->LeftChild(nidx);
628  }
629  return (*this)[nidx].LeftChild();
630  }
631  [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const {
632  if (IsMultiTarget()) {
633  return this->p_mt_tree_->RightChild(nidx);
634  }
635  return (*this)[nidx].RightChild();
636  }
637  [[nodiscard]] bool IsLeftChild(bst_node_t nidx) const {
638  if (IsMultiTarget()) {
639  CHECK_NE(nidx, kRoot);
640  auto p = this->p_mt_tree_->Parent(nidx);
641  return nidx == this->p_mt_tree_->LeftChild(p);
642  }
643  return (*this)[nidx].IsLeftChild();
644  }
645  [[nodiscard]] bst_node_t Size() const {
646  if (IsMultiTarget()) {
647  return this->p_mt_tree_->Size();
648  }
649  return this->nodes_.size();
650  }
651 
652  private:
653  template <bool typed>
654  void LoadCategoricalSplit(Json const& in);
655  void SaveCategoricalSplit(Json* p_out) const;
657  TreeParam param_;
658  // vector of nodes
659  std::vector<Node> nodes_;
660  // free node space, used during training process
661  std::vector<int> deleted_nodes_;
662  // stats of nodes
663  std::vector<RTreeNodeStat> stats_;
664  std::vector<FeatureType> split_types_;
665 
666  // Categories for each internal node.
667  std::vector<uint32_t> split_categories_;
668  // Ptr to split categories of each node.
669  std::vector<CategoricalSplitMatrix::Segment> split_categories_segments_;
670  // ptr to multi-target tree with vector leaf.
672  // allocate a new node,
673  // !!!!!! NOTE: may cause BUG here, nodes.resize
674  bst_node_t AllocNode() {
675  if (param_.num_deleted != 0) {
676  int nid = deleted_nodes_.back();
677  deleted_nodes_.pop_back();
678  nodes_[nid].Reuse();
679  --param_.num_deleted;
680  return nid;
681  }
682  int nd = param_.num_nodes++;
683  CHECK_LT(param_.num_nodes, std::numeric_limits<int>::max())
684  << "number of nodes in the tree exceed 2^31";
685  nodes_.resize(param_.num_nodes);
686  stats_.resize(param_.num_nodes);
687  split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
688  split_categories_segments_.resize(param_.num_nodes);
689  return nd;
690  }
691  // delete a tree node, keep the parent field to allow trace back
692  void DeleteNode(int nid) {
693  CHECK_GE(nid, 1);
694  auto pid = (*this)[nid].Parent();
695  if (nid == (*this)[pid].LeftChild()) {
696  (*this)[pid].SetLeftChild(kInvalidNodeId);
697  } else {
698  (*this)[pid].SetRightChild(kInvalidNodeId);
699  }
700 
701  deleted_nodes_.push_back(nid);
702  nodes_[nid].MarkDelete();
703  ++param_.num_deleted;
704  }
705 };
706 
707 inline void RegTree::FVec::Init(size_t size) {
708  data_.resize(size);
709  std::fill(data_.begin(), data_.end(), std::numeric_limits<float>::quiet_NaN());
710  has_missing_ = true;
711 }
712 
713 inline void RegTree::FVec::Fill(SparsePage::Inst const& inst) {
714  auto p_data = inst.data();
715  auto p_out = data_.data();
716 
717  for (std::size_t i = 0, n = inst.size(); i < n; ++i) {
718  auto const& entry = p_data[i];
719  p_out[entry.index] = entry.fvalue;
720  }
721  has_missing_ = data_.size() != inst.size();
722 }
723 
724 inline void RegTree::FVec::Drop() { this->Init(this->Size()); }
725 
726 inline size_t RegTree::FVec::Size() const {
727  return data_.size();
728 }
729 
730 inline float RegTree::FVec::GetFvalue(size_t i) const {
731  return data_[i];
732 }
733 
734 inline bool RegTree::FVec::IsMissing(size_t i) const { return std::isnan(data_[i]); }
735 
736 inline bool RegTree::FVec::HasMissing() const { return has_missing_; }
737 
738 // Multi-target tree not yet implemented error
740  return " support for multi-target tree is not yet implemented.";
741 }
742 } // namespace xgboost
743 #endif // XGBOOST_TREE_MODEL_H_
Defines configuration macros and basic types for xgboost.
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:64
Helper for defining copyable data structure that contains unique pointers.
Definition: tree_model.h:73
T const * operator->() const noexcept
Definition: tree_model.h:90
T * get() const noexcept
Definition: tree_model.h:84
bool operator!() const
Definition: tree_model.h:93
CopyUniquePtr(CopyUniquePtr const &that)
Definition: tree_model.h:78
T * operator->() noexcept
Definition: tree_model.h:87
T & operator*()
Definition: tree_model.h:86
T const & operator*() const
Definition: tree_model.h:89
void reset(T *ptr)
Definition: tree_model.h:94
Feature map data structure to help text model dump. TODO(tqchen) consider make it even more lightweig...
Definition: feature_map.h:22
Data structure representing JSON format.
Definition: json.h:392
Tree structure for multi-target model.
Definition: multi_target_tree_model.h:69
static constexpr bst_node_t InvalidNodeId()
Definition: multi_target_tree_model.h:71
tree node
Definition: tree_model.h:110
XGBOOST_DEVICE int Parent() const
get parent of the node
Definition: tree_model.h:145
XGBOOST_DEVICE void MarkDelete()
mark that this node is deleted
Definition: tree_model.h:190
XGBOOST_DEVICE bool IsRoot() const
whether current node is root
Definition: tree_model.h:151
XGBOOST_DEVICE int RightChild() const
index of right child
Definition: tree_model.h:126
XGBOOST_DEVICE float LeafValue() const
Definition: tree_model.h:141
XGBOOST_DEVICE Node()
Definition: tree_model.h:112
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child=true)
Definition: tree_model.h:198
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=kInvalidNodeId)
set the leaf value of the node
Definition: tree_model.h:184
XGBOOST_DEVICE bool IsLeftChild() const
whether current node is left child
Definition: tree_model.h:147
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
set split condition of current node
Definition: tree_model.h:172
XGBOOST_DEVICE void SetLeftChild(int nid)
set the left child
Definition: tree_model.h:156
XGBOOST_DEVICE bst_feature_t SplitIndex() const
feature index of split condition
Definition: tree_model.h:132
XGBOOST_DEVICE bool IsDeleted() const
whether this node is deleted
Definition: tree_model.h:149
XGBOOST_DEVICE bool IsLeaf() const
whether current node is leaf node
Definition: tree_model.h:139
bool operator==(const Node &b) const
Definition: tree_model.h:202
Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond, bool default_left)
Definition: tree_model.h:116
XGBOOST_DEVICE void Reuse()
Reuse this deleted node.
Definition: tree_model.h:194
XGBOOST_DEVICE void SetRightChild(int nid)
set the right child
Definition: tree_model.h:163
XGBOOST_DEVICE bool DefaultLeft() const
when feature is unknown, whether goes to left child
Definition: tree_model.h:137
XGBOOST_DEVICE int LeftChild() const
index of left child
Definition: tree_model.h:124
XGBOOST_DEVICE int DefaultChild() const
index of default child when feature is missing
Definition: tree_model.h:128
XGBOOST_DEVICE SplitCondT SplitCond() const
Definition: tree_model.h:143
define regression tree to be the most common tree model.
Definition: tree_model.h:102
int MaxDepth(int nid) const
get maximum depth
Definition: tree_model.h:462
void SaveModel(Json *out) const override
saves the model config to a JSON object
bst_target_t NumTargets() const
The size of leaf weight.
Definition: tree_model.h:403
void WalkTree(Func func) const
Definition: tree_model.h:313
bool IsLeaf(bst_node_t nidx) const
Definition: tree_model.h:613
bool operator==(const RegTree &b) const
Definition: tree_model.h:304
const RTreeNodeStat & Stat(int nid) const
get node statistics given nid
Definition: tree_model.h:297
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight)
Expands a leaf node into two additional leaf nodes for a multi-target tree.
bst_node_t Parent(bst_node_t nidx) const
Definition: tree_model.h:619
bst_node_t NumNodes() const noexcept
Get the total number of nodes including deleted ones in this tree.
Definition: tree_model.h:418
const Node & operator[](int nid) const
get node given nid
Definition: tree_model.h:282
bst_node_t DefaultChild(bst_node_t nidx) const
Definition: tree_model.h:604
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
Expands a leaf node into two additional leaf nodes.
Node & operator[](int nid)
get node given nid
Definition: tree_model.h:278
RegTree()
Definition: tree_model.h:256
static constexpr bst_node_t kInvalidNodeId
Definition: tree_model.h:105
bst_feature_t SplitIndex(bst_node_t nidx) const
Definition: tree_model.h:586
bool IsRoot(bst_node_t nidx) const
Definition: tree_model.h:607
static constexpr uint32_t kDeletedNodeMarker
Definition: tree_model.h:106
bool IsMultiTarget() const
Whether this is a multi-target tree.
Definition: tree_model.h:399
bst_node_t NumExtraNodes() const noexcept
number of extra nodes besides the root
Definition: tree_model.h:428
bool DefaultLeft(bst_node_t nidx) const
Definition: tree_model.h:598
auto GetMultiTargetTree() const
Get the underlying implementation of multi-target tree.
Definition: tree_model.h:407
bst_node_t LeftChild(bst_node_t nidx) const
Definition: tree_model.h:625
bst_node_t GetNumLeaves() const
RegTree(bst_target_t n_targets, bst_feature_t n_features)
Constructor that initializes the tree model with shape.
Definition: tree_model.h:269
bst_node_t RightChild(bst_node_t nidx) const
Definition: tree_model.h:631
common::Span< uint32_t const > NodeCats(bst_node_t nidx) const
Get the bit storage for categories.
Definition: tree_model.h:552
bool IsLeftChild(bst_node_t nidx) const
Definition: tree_model.h:637
CategoricalSplitMatrix GetCategoriesMatrix() const
Definition: tree_model.h:578
RTreeNodeStat & Stat(int nid)
get node statistics given nid
Definition: tree_model.h:293
bst_float SplitCondT
Definition: tree_model.h:104
void ExpandCategorical(bst_node_t nid, bst_feature_t split_index, common::Span< const uint32_t > split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum)
Expands a leaf node with categories.
bool Equal(const RegTree &b) const
Compares whether 2 trees are equal from a user's perspective. The equality compares only non-deleted ...
std::vector< FeatureType > const & GetSplitTypes() const
Get split types for all nodes.
Definition: tree_model.h:543
void CollapseToLeaf(int rid, bst_float value)
collapse a non leaf node to a leaf node, delete its children
Definition: tree_model.h:245
bst_node_t NumValidNodes() const noexcept
Get the total number of valid nodes in this tree.
Definition: tree_model.h:422
void ChangeToLeaf(int rid, bst_float value)
change a non leaf node to a leaf node, delete its children
Definition: tree_model.h:233
const std::vector< RTreeNodeStat > & GetStats() const
get const reference to stats
Definition: tree_model.h:290
void SetLeaf(bst_node_t nidx, linalg::VectorView< float const > weight)
Set the leaf weight for a multi-target tree.
Definition: tree_model.h:453
const std::vector< Node > & GetNodes() const
get const reference to nodes
Definition: tree_model.h:287
void LoadModel(Json const &in) override
load the model from a JSON object
std::string DumpModel(const FeatureMap &fmap, bool with_stats, std::string format) const
dump the model in the requested format as a text string
FeatureType NodeSplitType(bst_node_t nidx) const
Get split type for a node.
Definition: tree_model.h:539
bst_feature_t NumFeatures() const noexcept
Get the number of features.
Definition: tree_model.h:414
common::Span< uint32_t const > GetSplitCategories() const
Definition: tree_model.h:546
bool HasCategoricalSplit() const
Whether this tree has categorical split.
Definition: tree_model.h:395
std::int32_t GetDepth(bst_node_t nid) const
get current depth
Definition: tree_model.h:439
static constexpr bst_node_t kRoot
Definition: tree_model.h:107
bst_node_t GetNumSplitNodes() const
auto const & GetSplitCategoriesPtr() const
Definition: tree_model.h:559
float SplitCond(bst_node_t nidx) const
Definition: tree_model.h:592
int MaxDepth()
get maximum depth
Definition: tree_model.h:470
bst_node_t Size() const
Definition: tree_model.h:645
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:431
constexpr XGBOOST_DEVICE pointer data() const __span_noexcept
Definition: span.h:550
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
Definition: span.h:597
constexpr XGBOOST_DEVICE index_type size() const __span_noexcept
Definition: span.h:555
A tensor view with static type and dimension. It implements indexing and slicing.
Definition: linalg.h:277
The input data structure of xgboost.
Feature map data structure to help visualization and model dump.
Linear algebra related utilities.
Defines the abstract interface for different components in XGBoost.
Learner interface that integrates objective, gbm and evaluation together. This is the user facing XGB...
Definition: base.h:97
std::int32_t bst_node_t
Type for tree node index.
Definition: base.h:119
FeatureType
Definition: data.h:41
std::uint32_t bst_target_t
Type for indexing into output targets.
Definition: base.h:127
std::uint32_t bst_feature_t
Type for data column (feature) index.
Definition: base.h:107
float bst_float
float type, used for storing statistics
Definition: base.h:103
StringView MTNotImplemented()
Definition: tree_model.h:739
Definition: model.h:14
node statistics used in regression tree
Definition: tree_model.h:50
RTreeNodeStat(float loss_chg, float sum_hess, float weight)
Definition: tree_model.h:61
bst_float loss_chg
loss change caused by current split
Definition: tree_model.h:52
int leaf_child_cnt
number of children that are leaf nodes known up to now
Definition: tree_model.h:58
bst_float sum_hess
sum of hessian values, used to measure coverage of data
Definition: tree_model.h:54
bool operator==(const RTreeNodeStat &b) const
Definition: tree_model.h:63
bst_float base_weight
weight of current node
Definition: tree_model.h:56
std::size_t size
Definition: tree_model.h:571
std::size_t beg
Definition: tree_model.h:570
CSR-like matrix for categorical splits.
Definition: tree_model.h:568
common::Span< uint32_t const > categories
Definition: tree_model.h:574
common::Span< Segment const > node_ptr
Definition: tree_model.h:575
common::Span< FeatureType const > split_type
Definition: tree_model.h:573
dense feature vector that can be taken by RegTree and can be constructed from a sparse feature vector.
Definition: tree_model.h:476
void HasMissing(bool has_missing)
Definition: tree_model.h:511
void Drop()
drop the trace after fill, must be called after fill.
Definition: tree_model.h:724
bool HasMissing() const
Definition: tree_model.h:736
bool IsMissing(size_t i) const
check whether i-th entry is missing
Definition: tree_model.h:734
size_t Size() const
returns the size of the feature vector
Definition: tree_model.h:726
void Init(size_t size)
initialize the vector with the given size
Definition: tree_model.h:707
common::Span< float > Data()
Definition: tree_model.h:513
void Fill(SparsePage::Inst const &inst)
fill the vector with sparse vector
Definition: tree_model.h:713
bst_float GetFvalue(size_t i) const
get ith value
Definition: tree_model.h:730
Definition: string_view.h:16
meta parameters of the tree
Definition: tree_model.h:30
bst_node_t num_deleted
The number of deleted nodes.
Definition: tree_model.h:34
bst_feature_t num_feature
The number of features used for tree construction.
Definition: tree_model.h:36
bool operator==(const TreeParam &b) const
Definition: tree_model.h:40
bst_node_t num_nodes
The number of nodes.
Definition: tree_model.h:32
void ToJson(Json *p_out) const
void FromJson(Json const &in)
bst_target_t size_leaf_vector
leaf vector size. Used by the vector leaf.
Definition: tree_model.h:38