xgboost
tree_model.h
Go to the documentation of this file.
1 
7 #ifndef XGBOOST_TREE_MODEL_H_
8 #define XGBOOST_TREE_MODEL_H_
9 
10 #include <dmlc/io.h>
11 #include <dmlc/parameter.h>
12 #include <xgboost/base.h>
13 #include <xgboost/data.h>
14 #include <xgboost/feature_map.h>
15 #include <xgboost/linalg.h> // for VectorView
16 #include <xgboost/logging.h>
17 #include <xgboost/model.h>
18 #include <xgboost/multi_target_tree_model.h> // for MultiTargetTree
19 
20 #include <algorithm>
21 #include <cstring>
22 #include <limits>
23 #include <memory> // for make_unique
24 #include <stack>
25 #include <string>
26 #include <tuple>
27 #include <vector>
28 
29 namespace xgboost {
30 class Json;
31 
32 // FIXME(trivialfis): Once binary IO is gone, make this parameter internal as it should
33 // not be configured by users.
35 struct TreeParam : public dmlc::Parameter<TreeParam> {
39  int num_nodes{1};
41  int num_deleted{0};
52  int reserved[31];
55  // assert compact alignment
56  static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int), "TreeParam: 64 bit align");
57  std::memset(reserved, 0, sizeof(reserved));
58  }
59 
60  // Swap byte order for all fields. Useful for transporting models between machines with different
61  // endianness (big endian vs little endian)
62  [[nodiscard]] TreeParam ByteSwap() const {
63  TreeParam x = *this;
64  dmlc::ByteSwap(&x.deprecated_num_roots, sizeof(x.deprecated_num_roots), 1);
65  dmlc::ByteSwap(&x.num_nodes, sizeof(x.num_nodes), 1);
66  dmlc::ByteSwap(&x.num_deleted, sizeof(x.num_deleted), 1);
67  dmlc::ByteSwap(&x.deprecated_max_depth, sizeof(x.deprecated_max_depth), 1);
68  dmlc::ByteSwap(&x.num_feature, sizeof(x.num_feature), 1);
69  dmlc::ByteSwap(&x.size_leaf_vector, sizeof(x.size_leaf_vector), 1);
70  dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0]));
71  return x;
72  }
73 
74  // declare the parameters
76  // only declare the parameters that can be set by the user.
77  // other arguments are set by the algorithm.
78  DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1);
79  DMLC_DECLARE_FIELD(num_feature)
80  .set_default(0)
81  .describe("Number of features used in tree construction.");
82  DMLC_DECLARE_FIELD(num_deleted).set_default(0);
83  DMLC_DECLARE_FIELD(size_leaf_vector)
84  .set_lower_bound(0)
85  .set_default(1)
86  .describe("Size of leaf vector, reserved for vector tree");
87  }
88 
89  bool operator==(const TreeParam& b) const {
90  return num_nodes == b.num_nodes && num_deleted == b.num_deleted &&
92  }
93 };
94 
96 struct RTreeNodeStat {
104  int leaf_child_cnt {0};
105 
106  RTreeNodeStat() = default;
107  RTreeNodeStat(float loss_chg, float sum_hess, float weight) :
109  bool operator==(const RTreeNodeStat& b) const {
110  return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
112  }
113  // Swap byte order for all fields. Useful for transporting models between machines with different
114  // endianness (big endian vs little endian)
115  [[nodiscard]] RTreeNodeStat ByteSwap() const {
116  RTreeNodeStat x = *this;
117  dmlc::ByteSwap(&x.loss_chg, sizeof(x.loss_chg), 1);
118  dmlc::ByteSwap(&x.sum_hess, sizeof(x.sum_hess), 1);
119  dmlc::ByteSwap(&x.base_weight, sizeof(x.base_weight), 1);
120  dmlc::ByteSwap(&x.leaf_child_cnt, sizeof(x.leaf_child_cnt), 1);
121  return x;
122  }
123 };
124 
128 template <typename T>
130  std::unique_ptr<T> ptr_{nullptr};
131 
132  public:
133  CopyUniquePtr() = default;
135  ptr_.reset(nullptr);
136  if (that.ptr_) {
137  ptr_ = std::make_unique<T>(*that);
138  }
139  }
140  T* get() const noexcept { return ptr_.get(); } // NOLINT
141 
142  T& operator*() { return *ptr_; }
143  T* operator->() noexcept { return this->get(); }
144 
145  T const& operator*() const { return *ptr_; }
146  T const* operator->() const noexcept { return this->get(); }
147 
148  explicit operator bool() const { return static_cast<bool>(ptr_); }
149  bool operator!() const { return !ptr_; }
150  void reset(T* ptr) { ptr_.reset(ptr); } // NOLINT
151 };
152 
158 class RegTree : public Model {
159  public:
162  static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
163  static constexpr bst_node_t kRoot{0};
164 
166  class Node {
167  public:
169  // assert compact alignment
170  static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
171  "Node: 64 bit align");
172  }
173  Node(int32_t cleft, int32_t cright, int32_t parent,
174  uint32_t split_ind, float split_cond, bool default_left) :
175  parent_{parent}, cleft_{cleft}, cright_{cright} {
176  this->SetParent(parent_);
177  this->SetSplit(split_ind, split_cond, default_left);
178  }
179 
181  [[nodiscard]] XGBOOST_DEVICE int LeftChild() const { return this->cleft_; }
183  [[nodiscard]] XGBOOST_DEVICE int RightChild() const { return this->cright_; }
185  [[nodiscard]] XGBOOST_DEVICE int DefaultChild() const {
186  return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
187  }
189  [[nodiscard]] XGBOOST_DEVICE unsigned SplitIndex() const {
190  return sindex_ & ((1U << 31) - 1U);
191  }
193  [[nodiscard]] XGBOOST_DEVICE bool DefaultLeft() const { return (sindex_ >> 31) != 0; }
195  [[nodiscard]] XGBOOST_DEVICE bool IsLeaf() const { return cleft_ == kInvalidNodeId; }
197  [[nodiscard]] XGBOOST_DEVICE float LeafValue() const { return (this->info_).leaf_value; }
199  [[nodiscard]] XGBOOST_DEVICE SplitCondT SplitCond() const { return (this->info_).split_cond; }
201  [[nodiscard]] XGBOOST_DEVICE int Parent() const { return parent_ & ((1U << 31) - 1); }
203  [[nodiscard]] XGBOOST_DEVICE bool IsLeftChild() const { return (parent_ & (1U << 31)) != 0; }
205  [[nodiscard]] XGBOOST_DEVICE bool IsDeleted() const { return sindex_ == kDeletedNodeMarker; }
207  [[nodiscard]] XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
212  XGBOOST_DEVICE void SetLeftChild(int nid) {
213  this->cleft_ = nid;
214  }
220  this->cright_ = nid;
221  }
228  XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond,
229  bool default_left = false) {
230  if (default_left) split_index |= (1U << 31);
231  this->sindex_ = split_index;
232  (this->info_).split_cond = split_cond;
233  }
240  XGBOOST_DEVICE void SetLeaf(bst_float value, int right = kInvalidNodeId) {
241  (this->info_).leaf_value = value;
242  this->cleft_ = kInvalidNodeId;
243  this->cright_ = right;
244  }
247  this->sindex_ = kDeletedNodeMarker;
248  }
251  this->sindex_ = 0;
252  }
253  // set parent
254  XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child = true) {
255  if (is_left_child) pidx |= (1U << 31);
256  this->parent_ = pidx;
257  }
258  bool operator==(const Node& b) const {
259  return parent_ == b.parent_ && cleft_ == b.cleft_ &&
260  cright_ == b.cright_ && sindex_ == b.sindex_ &&
261  info_.leaf_value == b.info_.leaf_value;
262  }
263 
264  [[nodiscard]] Node ByteSwap() const {
265  Node x = *this;
266  dmlc::ByteSwap(&x.parent_, sizeof(x.parent_), 1);
267  dmlc::ByteSwap(&x.cleft_, sizeof(x.cleft_), 1);
268  dmlc::ByteSwap(&x.cright_, sizeof(x.cright_), 1);
269  dmlc::ByteSwap(&x.sindex_, sizeof(x.sindex_), 1);
270  dmlc::ByteSwap(&x.info_, sizeof(x.info_), 1);
271  return x;
272  }
273 
274  private:
279  union Info{
280  bst_float leaf_value;
281  SplitCondT split_cond;
282  };
283  // pointer to parent, highest bit is used to
284  // indicate whether it's a left child or not
285  int32_t parent_{kInvalidNodeId};
286  // pointer to left, right
287  int32_t cleft_{kInvalidNodeId}, cright_{kInvalidNodeId};
288  // split feature index, left split or right split depends on the highest bit
289  uint32_t sindex_{0};
290  // extra info
291  Info info_;
292  };
293 
299  void ChangeToLeaf(int rid, bst_float value) {
300  CHECK(nodes_[nodes_[rid].LeftChild() ].IsLeaf());
301  CHECK(nodes_[nodes_[rid].RightChild()].IsLeaf());
302  this->DeleteNode(nodes_[rid].LeftChild());
303  this->DeleteNode(nodes_[rid].RightChild());
304  nodes_[rid].SetLeaf(value);
305  }
311  void CollapseToLeaf(int rid, bst_float value) {
312  if (nodes_[rid].IsLeaf()) return;
313  if (!nodes_[nodes_[rid].LeftChild() ].IsLeaf()) {
314  CollapseToLeaf(nodes_[rid].LeftChild(), 0.0f);
315  }
316  if (!nodes_[nodes_[rid].RightChild() ].IsLeaf()) {
317  CollapseToLeaf(nodes_[rid].RightChild(), 0.0f);
318  }
319  this->ChangeToLeaf(rid, value);
320  }
321 
323  param_.Init(Args{});
324  nodes_.resize(param_.num_nodes);
325  stats_.resize(param_.num_nodes);
326  split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
327  split_categories_segments_.resize(param_.num_nodes);
328  for (int i = 0; i < param_.num_nodes; i++) {
329  nodes_[i].SetLeaf(0.0f);
330  nodes_[i].SetParent(kInvalidNodeId);
331  }
332  }
336  explicit RegTree(bst_target_t n_targets, bst_feature_t n_features) : RegTree{} {
337  param_.num_feature = n_features;
338  param_.size_leaf_vector = n_targets;
339  if (n_targets > 1) {
340  this->p_mt_tree_.reset(new MultiTargetTree{&param_});
341  }
342  }
343 
345  Node& operator[](int nid) {
346  return nodes_[nid];
347  }
349  const Node& operator[](int nid) const {
350  return nodes_[nid];
351  }
352 
354  [[nodiscard]] const std::vector<Node>& GetNodes() const { return nodes_; }
355 
357  [[nodiscard]] const std::vector<RTreeNodeStat>& GetStats() const { return stats_; }
358 
360  RTreeNodeStat& Stat(int nid) {
361  return stats_[nid];
362  }
364  [[nodiscard]] const RTreeNodeStat& Stat(int nid) const {
365  return stats_[nid];
366  }
367 
372  void Load(dmlc::Stream* fi);
377  void Save(dmlc::Stream* fo) const;
378 
379  void LoadModel(Json const& in) override;
380  void SaveModel(Json* out) const override;
381 
382  bool operator==(const RegTree& b) const {
383  return nodes_ == b.nodes_ && stats_ == b.stats_ &&
384  deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
385  }
386  /* \brief Iterate through all nodes in this tree.
387  *
388  * \param Function that accepts a node index, and returns false when iteration should
389  * stop, otherwise returns true.
390  */
391  template <typename Func> void WalkTree(Func func) const {
392  std::stack<bst_node_t> nodes;
393  nodes.push(kRoot);
394  auto &self = *this;
395  while (!nodes.empty()) {
396  auto nidx = nodes.top();
397  nodes.pop();
398  if (!func(nidx)) {
399  return;
400  }
401  auto left = self[nidx].LeftChild();
402  auto right = self[nidx].RightChild();
403  if (left != RegTree::kInvalidNodeId) {
404  nodes.push(left);
405  }
406  if (right != RegTree::kInvalidNodeId) {
407  nodes.push(right);
408  }
409  }
410  }
417  [[nodiscard]] bool Equal(const RegTree& b) const;
418 
436  void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value,
437  bool default_left, bst_float base_weight,
438  bst_float left_leaf_weight, bst_float right_leaf_weight,
439  bst_float loss_change, float sum_hess, float left_sum,
440  float right_sum,
441  bst_node_t leaf_right_child = kInvalidNodeId);
445  void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left,
448  linalg::VectorView<float const> right_weight);
449 
466  common::Span<const uint32_t> split_cat, bool default_left,
467  bst_float base_weight, bst_float left_leaf_weight,
468  bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
469  float left_sum, float right_sum);
473  [[nodiscard]] bool HasCategoricalSplit() const { return !split_categories_.empty(); }
477  [[nodiscard]] bool IsMultiTarget() const { return static_cast<bool>(p_mt_tree_); }
481  [[nodiscard]] bst_target_t NumTargets() const { return param_.size_leaf_vector; }
485  [[nodiscard]] auto GetMultiTargetTree() const {
486  CHECK(IsMultiTarget());
487  return p_mt_tree_.get();
488  }
492  [[nodiscard]] bst_feature_t NumFeatures() const noexcept { return param_.num_feature; }
496  [[nodiscard]] bst_node_t NumNodes() const noexcept { return param_.num_nodes; }
500  [[nodiscard]] bst_node_t NumValidNodes() const noexcept {
501  return param_.num_nodes - param_.num_deleted;
502  }
506  [[nodiscard]] bst_node_t NumExtraNodes() const noexcept {
507  return param_.num_nodes - 1 - param_.num_deleted;
508  }
509  /* \brief Count number of leaves in tree. */
510  [[nodiscard]] bst_node_t GetNumLeaves() const;
511  [[nodiscard]] bst_node_t GetNumSplitNodes() const;
512 
517  [[nodiscard]] std::int32_t GetDepth(bst_node_t nid) const {
518  if (IsMultiTarget()) {
519  return this->p_mt_tree_->Depth(nid);
520  }
521  int depth = 0;
522  while (!nodes_[nid].IsRoot()) {
523  ++depth;
524  nid = nodes_[nid].Parent();
525  }
526  return depth;
527  }
532  CHECK(IsMultiTarget());
533  return this->p_mt_tree_->SetLeaf(nidx, weight);
534  }
535 
540  [[nodiscard]] int MaxDepth(int nid) const {
541  if (nodes_[nid].IsLeaf()) return 0;
542  return std::max(MaxDepth(nodes_[nid].LeftChild()) + 1, MaxDepth(nodes_[nid].RightChild()) + 1);
543  }
544 
548  int MaxDepth() { return MaxDepth(0); }
549 
554  struct FVec {
559  void Init(size_t size);
564  void Fill(const SparsePage::Inst& inst);
565 
570  void Drop();
575  [[nodiscard]] size_t Size() const;
581  [[nodiscard]] bst_float GetFvalue(size_t i) const;
587  [[nodiscard]] bool IsMissing(size_t i) const;
588  [[nodiscard]] bool HasMissing() const;
589 
590 
591  private:
596  union Entry {
597  bst_float fvalue;
598  int flag;
599  };
600  std::vector<Entry> data_;
601  bool has_missing_;
602  };
603 
610  std::vector<float>* mean_values,
611  bst_float* out_contribs) const;
619  [[nodiscard]] std::string DumpModel(const FeatureMap& fmap, bool with_stats,
620  std::string format) const;
626  [[nodiscard]] FeatureType NodeSplitType(bst_node_t nidx) const { return split_types_.at(nidx); }
630  [[nodiscard]] std::vector<FeatureType> const& GetSplitTypes() const {
631  return split_types_;
632  }
634  return split_categories_;
635  }
640  auto node_ptr = GetCategoriesMatrix().node_ptr;
641  auto categories = GetCategoriesMatrix().categories;
642  auto segment = node_ptr[nidx];
643  auto node_cats = categories.subspan(segment.beg, segment.size);
644  return node_cats;
645  }
646  [[nodiscard]] auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
647 
// Range [beg, beg + size) of one node's entries inside the `categories` span.
struct Segment {
  std::size_t beg{0};   // offset of the node's first entry
  std::size_t size{0};  // number of entries belonging to the node
};
663  };
664 
668  view.categories = this->GetSplitCategories();
669  view.node_ptr = common::Span<CategoricalSplitMatrix::Segment const>(split_categories_segments_);
670  return view;
671  }
672 
673  [[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const {
674  if (IsMultiTarget()) {
675  return this->p_mt_tree_->SplitIndex(nidx);
676  }
677  return (*this)[nidx].SplitIndex();
678  }
679  [[nodiscard]] float SplitCond(bst_node_t nidx) const {
680  if (IsMultiTarget()) {
681  return this->p_mt_tree_->SplitCond(nidx);
682  }
683  return (*this)[nidx].SplitCond();
684  }
685  [[nodiscard]] bool DefaultLeft(bst_node_t nidx) const {
686  if (IsMultiTarget()) {
687  return this->p_mt_tree_->DefaultLeft(nidx);
688  }
689  return (*this)[nidx].DefaultLeft();
690  }
691  [[nodiscard]] bool IsRoot(bst_node_t nidx) const {
692  if (IsMultiTarget()) {
693  return nidx == kRoot;
694  }
695  return (*this)[nidx].IsRoot();
696  }
697  [[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
698  if (IsMultiTarget()) {
699  return this->p_mt_tree_->IsLeaf(nidx);
700  }
701  return (*this)[nidx].IsLeaf();
702  }
703  [[nodiscard]] bst_node_t Parent(bst_node_t nidx) const {
704  if (IsMultiTarget()) {
705  return this->p_mt_tree_->Parent(nidx);
706  }
707  return (*this)[nidx].Parent();
708  }
709  [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const {
710  if (IsMultiTarget()) {
711  return this->p_mt_tree_->LeftChild(nidx);
712  }
713  return (*this)[nidx].LeftChild();
714  }
715  [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const {
716  if (IsMultiTarget()) {
717  return this->p_mt_tree_->RightChild(nidx);
718  }
719  return (*this)[nidx].RightChild();
720  }
721  [[nodiscard]] bool IsLeftChild(bst_node_t nidx) const {
722  if (IsMultiTarget()) {
723  CHECK_NE(nidx, kRoot);
724  auto p = this->p_mt_tree_->Parent(nidx);
725  return nidx == this->p_mt_tree_->LeftChild(p);
726  }
727  return (*this)[nidx].IsLeftChild();
728  }
729  [[nodiscard]] bst_node_t Size() const {
730  if (IsMultiTarget()) {
731  return this->p_mt_tree_->Size();
732  }
733  return this->nodes_.size();
734  }
735 
736  private:
737  template <bool typed>
738  void LoadCategoricalSplit(Json const& in);
739  void SaveCategoricalSplit(Json* p_out) const;
741  TreeParam param_;
742  // vector of nodes
743  std::vector<Node> nodes_;
744  // free node space, used during training process
745  std::vector<int> deleted_nodes_;
746  // stats of nodes
747  std::vector<RTreeNodeStat> stats_;
748  std::vector<FeatureType> split_types_;
749 
750  // Categories for each internal node.
751  std::vector<uint32_t> split_categories_;
752  // Ptr to split categories of each node.
753  std::vector<CategoricalSplitMatrix::Segment> split_categories_segments_;
754  // ptr to multi-target tree with vector leaf.
756  // allocate a new node,
757  // !!!!!! NOTE: may cause BUG here, nodes.resize
758  bst_node_t AllocNode() {
759  if (param_.num_deleted != 0) {
760  int nid = deleted_nodes_.back();
761  deleted_nodes_.pop_back();
762  nodes_[nid].Reuse();
763  --param_.num_deleted;
764  return nid;
765  }
766  int nd = param_.num_nodes++;
767  CHECK_LT(param_.num_nodes, std::numeric_limits<int>::max())
768  << "number of nodes in the tree exceed 2^31";
769  nodes_.resize(param_.num_nodes);
770  stats_.resize(param_.num_nodes);
771  split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
772  split_categories_segments_.resize(param_.num_nodes);
773  return nd;
774  }
775  // delete a tree node, keep the parent field to allow trace back
776  void DeleteNode(int nid) {
777  CHECK_GE(nid, 1);
778  auto pid = (*this)[nid].Parent();
779  if (nid == (*this)[pid].LeftChild()) {
780  (*this)[pid].SetLeftChild(kInvalidNodeId);
781  } else {
782  (*this)[pid].SetRightChild(kInvalidNodeId);
783  }
784 
785  deleted_nodes_.push_back(nid);
786  nodes_[nid].MarkDelete();
787  ++param_.num_deleted;
788  }
789 };
790 
791 inline void RegTree::FVec::Init(size_t size) {
792  Entry e; e.flag = -1;
793  data_.resize(size);
794  std::fill(data_.begin(), data_.end(), e);
795  has_missing_ = true;
796 }
797 
798 inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
799  size_t feature_count = 0;
800  for (auto const& entry : inst) {
801  if (entry.index >= data_.size()) {
802  continue;
803  }
804  data_[entry.index].fvalue = entry.fvalue;
805  ++feature_count;
806  }
807  has_missing_ = data_.size() != feature_count;
808 }
809 
810 inline void RegTree::FVec::Drop() {
811  Entry e{};
812  e.flag = -1;
813  std::fill_n(data_.data(), data_.size(), e);
814  has_missing_ = true;
815 }
816 
817 inline size_t RegTree::FVec::Size() const {
818  return data_.size();
819 }
820 
821 inline bst_float RegTree::FVec::GetFvalue(size_t i) const {
822  return data_[i].fvalue;
823 }
824 
825 inline bool RegTree::FVec::IsMissing(size_t i) const {
826  return data_[i].flag == -1;
827 }
828 
829 inline bool RegTree::FVec::HasMissing() const {
830  return has_missing_;
831 }
832 
833 // Multi-target tree not yet implemented error
835  return " support for multi-target tree is not yet implemented.";
836 }
837 } // namespace xgboost
838 #endif // XGBOOST_TREE_MODEL_H_
Defines configuration macros and basic types for xgboost.
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:64
Helper for defining copyable data structure that contains unique pointers.
Definition: tree_model.h:129
T const * operator->() const noexcept
Definition: tree_model.h:146
T * get() const noexcept
Definition: tree_model.h:140
bool operator!() const
Definition: tree_model.h:149
CopyUniquePtr(CopyUniquePtr const &that)
Definition: tree_model.h:134
T * operator->() noexcept
Definition: tree_model.h:143
T & operator*()
Definition: tree_model.h:142
T const & operator*() const
Definition: tree_model.h:145
void reset(T *ptr)
Definition: tree_model.h:150
Feature map data structure to help text model dump. TODO(tqchen): consider making it even more lightweig...
Definition: feature_map.h:22
Data structure representing JSON format.
Definition: json.h:357
Tree structure for multi-target model.
Definition: multi_target_tree_model.h:23
static constexpr bst_node_t InvalidNodeId()
Definition: multi_target_tree_model.h:25
tree node
Definition: tree_model.h:166
XGBOOST_DEVICE int Parent() const
get parent of the node
Definition: tree_model.h:201
XGBOOST_DEVICE void MarkDelete()
mark that this node is deleted
Definition: tree_model.h:246
XGBOOST_DEVICE bool IsRoot() const
whether current node is root
Definition: tree_model.h:207
XGBOOST_DEVICE int RightChild() const
index of right child
Definition: tree_model.h:183
XGBOOST_DEVICE float LeafValue() const
Definition: tree_model.h:197
XGBOOST_DEVICE Node()
Definition: tree_model.h:168
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child=true)
Definition: tree_model.h:254
XGBOOST_DEVICE unsigned SplitIndex() const
feature index of split condition
Definition: tree_model.h:189
Node ByteSwap() const
Definition: tree_model.h:264
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=kInvalidNodeId)
set the leaf value of the node
Definition: tree_model.h:240
XGBOOST_DEVICE bool IsLeftChild() const
whether current node is left child
Definition: tree_model.h:203
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
set split condition of current node
Definition: tree_model.h:228
XGBOOST_DEVICE void SetLeftChild(int nid)
set the left child
Definition: tree_model.h:212
XGBOOST_DEVICE bool IsDeleted() const
whether this node is deleted
Definition: tree_model.h:205
XGBOOST_DEVICE bool IsLeaf() const
whether current node is leaf node
Definition: tree_model.h:195
bool operator==(const Node &b) const
Definition: tree_model.h:258
Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond, bool default_left)
Definition: tree_model.h:173
XGBOOST_DEVICE void Reuse()
Reuse this deleted node.
Definition: tree_model.h:250
XGBOOST_DEVICE void SetRightChild(int nid)
set the right child
Definition: tree_model.h:219
XGBOOST_DEVICE bool DefaultLeft() const
whether samples go to the left child when the feature value is missing
Definition: tree_model.h:193
XGBOOST_DEVICE int LeftChild() const
index of left child
Definition: tree_model.h:181
XGBOOST_DEVICE int DefaultChild() const
index of default child when feature is missing
Definition: tree_model.h:185
XGBOOST_DEVICE SplitCondT SplitCond() const
Definition: tree_model.h:199
define regression tree to be the most common tree model.
Definition: tree_model.h:158
int MaxDepth(int nid) const
get maximum depth
Definition: tree_model.h:540
void SaveModel(Json *out) const override
saves the model config to a JSON object
bst_target_t NumTargets() const
The size of leaf weight.
Definition: tree_model.h:481
void WalkTree(Func func) const
Definition: tree_model.h:391
void Save(dmlc::Stream *fo) const
save model to stream
bool IsLeaf(bst_node_t nidx) const
Definition: tree_model.h:697
bool operator==(const RegTree &b) const
Definition: tree_model.h:382
const RTreeNodeStat & Stat(int nid) const
get node statistics given nid
Definition: tree_model.h:364
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight)
Expands a leaf node into two additional leaf nodes for a multi-target tree.
bst_node_t Parent(bst_node_t nidx) const
Definition: tree_model.h:703
bst_node_t NumNodes() const noexcept
Get the total number of nodes including deleted ones in this tree.
Definition: tree_model.h:496
const Node & operator[](int nid) const
get node given nid
Definition: tree_model.h:349
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
Expands a leaf node into two additional leaf nodes.
Node & operator[](int nid)
get node given nid
Definition: tree_model.h:345
RegTree()
Definition: tree_model.h:322
static constexpr bst_node_t kInvalidNodeId
Definition: tree_model.h:161
bst_feature_t SplitIndex(bst_node_t nidx) const
Definition: tree_model.h:673
bool IsRoot(bst_node_t nidx) const
Definition: tree_model.h:691
static constexpr uint32_t kDeletedNodeMarker
Definition: tree_model.h:162
bool IsMultiTarget() const
Whether this is a multi-target tree.
Definition: tree_model.h:477
bst_node_t NumExtraNodes() const noexcept
number of extra nodes besides the root
Definition: tree_model.h:506
bool DefaultLeft(bst_node_t nidx) const
Definition: tree_model.h:685
auto GetMultiTargetTree() const
Get the underlying implementation of the multi-target tree.
Definition: tree_model.h:485
void Load(dmlc::Stream *fi)
load model from stream
bst_node_t LeftChild(bst_node_t nidx) const
Definition: tree_model.h:709
bst_node_t GetNumLeaves() const
RegTree(bst_target_t n_targets, bst_feature_t n_features)
Constructor that initializes the tree model with shape.
Definition: tree_model.h:336
bst_node_t RightChild(bst_node_t nidx) const
Definition: tree_model.h:715
common::Span< uint32_t const > NodeCats(bst_node_t nidx) const
Get the bit storage for categories.
Definition: tree_model.h:639
bool IsLeftChild(bst_node_t nidx) const
Definition: tree_model.h:721
CategoricalSplitMatrix GetCategoriesMatrix() const
Definition: tree_model.h:665
RTreeNodeStat & Stat(int nid)
get node statistics given nid
Definition: tree_model.h:360
bst_float SplitCondT
Definition: tree_model.h:160
void ExpandCategorical(bst_node_t nid, bst_feature_t split_index, common::Span< const uint32_t > split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum)
Expands a leaf node with categories.
bool Equal(const RegTree &b) const
Compares whether 2 trees are equal from a user's perspective. The equality compares only non-deleted ...
std::vector< FeatureType > const & GetSplitTypes() const
Get split types for all nodes.
Definition: tree_model.h:630
void CollapseToLeaf(int rid, bst_float value)
collapse a non-leaf node into a leaf node, deleting its children
Definition: tree_model.h:311
bst_node_t NumValidNodes() const noexcept
Get the total number of valid nodes in this tree.
Definition: tree_model.h:500
void ChangeToLeaf(int rid, bst_float value)
change a non-leaf node into a leaf node, deleting its children
Definition: tree_model.h:299
const std::vector< RTreeNodeStat > & GetStats() const
get const reference to stats
Definition: tree_model.h:357
void SetLeaf(bst_node_t nidx, linalg::VectorView< float const > weight)
Set the leaf weight for a multi-target tree.
Definition: tree_model.h:531
const std::vector< Node > & GetNodes() const
get const reference to nodes
Definition: tree_model.h:354
void CalculateContributionsApprox(const RegTree::FVec &feat, std::vector< float > *mean_values, bst_float *out_contribs) const
calculate the approximate feature contributions for the given root
void LoadModel(Json const &in) override
load the model from a JSON object
std::string DumpModel(const FeatureMap &fmap, bool with_stats, std::string format) const
dump the model in the requested format as a text string
FeatureType NodeSplitType(bst_node_t nidx) const
Get split type for a node.
Definition: tree_model.h:626
bst_feature_t NumFeatures() const noexcept
Get the number of features.
Definition: tree_model.h:492
common::Span< uint32_t const > GetSplitCategories() const
Definition: tree_model.h:633
bool HasCategoricalSplit() const
Whether this tree has categorical split.
Definition: tree_model.h:473
std::int32_t GetDepth(bst_node_t nid) const
get current depth
Definition: tree_model.h:517
static constexpr bst_node_t kRoot
Definition: tree_model.h:163
bst_node_t GetNumSplitNodes() const
auto const & GetSplitCategoriesPtr() const
Definition: tree_model.h:646
float SplitCond(bst_node_t nidx) const
Definition: tree_model.h:679
int MaxDepth()
get maximum depth
Definition: tree_model.h:548
bst_node_t Size() const
Definition: tree_model.h:729
span class implementation, based on ISO C++20 std::span<T>. The interface should be the same.
Definition: span.h:424
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
Definition: span.h:596
A tensor view with static type and dimension. It implements indexing and slicing.
Definition: linalg.h:293
The input data structure of xgboost.
Feature map data structure to help visualization and model dump.
Linear algebra related utilities.
Defines the abstract interface for different components in XGBoost.
namespace of xgboost
Definition: base.h:90
std::vector< std::pair< std::string, std::string > > Args
Definition: base.h:316
uint32_t bst_feature_t
Type for data column (feature) index.
Definition: base.h:101
std::int32_t bst_node_t
Type for tree node index.
Definition: base.h:112
FeatureType
Definition: data.h:41
std::uint32_t bst_target_t
Type for indexing into output targets.
Definition: base.h:118
float bst_float
float type, used for storing statistics
Definition: base.h:97
StringView MTNotImplemented()
Definition: tree_model.h:834
Definition: model.h:17
node statistics used in regression tree
Definition: tree_model.h:96
RTreeNodeStat(float loss_chg, float sum_hess, float weight)
Definition: tree_model.h:107
bst_float loss_chg
loss change caused by current split
Definition: tree_model.h:98
int leaf_child_cnt
number of children that are known to be leaf nodes so far
Definition: tree_model.h:104
bst_float sum_hess
sum of hessian values, used to measure coverage of data
Definition: tree_model.h:100
bool operator==(const RTreeNodeStat &b) const
Definition: tree_model.h:109
RTreeNodeStat ByteSwap() const
Definition: tree_model.h:115
bst_float base_weight
weight of current node
Definition: tree_model.h:102
std::size_t size
Definition: tree_model.h:658
std::size_t beg
Definition: tree_model.h:657
CSR-like matrix for categorical splits.
Definition: tree_model.h:655
common::Span< uint32_t const > categories
Definition: tree_model.h:661
common::Span< Segment const > node_ptr
Definition: tree_model.h:662
common::Span< FeatureType const > split_type
Definition: tree_model.h:660
dense feature vector that can be consumed by RegTree and can be constructed from a sparse feature vector.
Definition: tree_model.h:554
void Drop()
drop the trace after fill, must be called after fill.
Definition: tree_model.h:810
bool HasMissing() const
Definition: tree_model.h:829
void Fill(const SparsePage::Inst &inst)
fill the vector with sparse vector
Definition: tree_model.h:798
bool IsMissing(size_t i) const
check whether i-th entry is missing
Definition: tree_model.h:825
size_t Size() const
returns the size of the feature vector
Definition: tree_model.h:817
bst_float GetFvalue(size_t i) const
get the i-th value
Definition: tree_model.h:821
void Init(size_t size)
initialize the vector with the given size
Definition: tree_model.h:791
Definition: string_view.h:15
meta parameters of the tree
Definition: tree_model.h:35
bst_feature_t num_feature
number of features used for tree construction
Definition: tree_model.h:45
int num_nodes
total number of nodes
Definition: tree_model.h:39
int num_deleted
number of deleted nodes
Definition: tree_model.h:41
bool operator==(const TreeParam &b) const
Definition: tree_model.h:89
int reserved[31]
reserved part, ensures the alignment works for 64-bit
Definition: tree_model.h:52
TreeParam ByteSwap() const
Definition: tree_model.h:62
TreeParam()
constructor
Definition: tree_model.h:54
DMLC_DECLARE_PARAMETER(TreeParam)
Definition: tree_model.h:75
bst_target_t size_leaf_vector
leaf vector size; used by vector trees to store more than one dimension of information in each leaf
Definition: tree_model.h:50
int deprecated_num_roots
(Deprecated) number of start root
Definition: tree_model.h:37
int deprecated_max_depth
maximum depth; this is a statistic of the tree
Definition: tree_model.h:43