xgboost
tree_model.h
Go to the documentation of this file.
1 
7 #ifndef XGBOOST_TREE_MODEL_H_
8 #define XGBOOST_TREE_MODEL_H_
9 
10 #include <dmlc/io.h>
11 #include <dmlc/parameter.h>
12 #include <xgboost/base.h>
13 #include <xgboost/data.h>
14 #include <xgboost/feature_map.h>
15 #include <xgboost/linalg.h> // for VectorView
16 #include <xgboost/logging.h>
17 #include <xgboost/model.h>
18 #include <xgboost/multi_target_tree_model.h> // for MultiTargetTree
19 
20 #include <algorithm>
21 #include <cstring>
22 #include <limits>
23 #include <memory> // for make_unique
24 #include <stack>
25 #include <string>
26 #include <vector>
27 
28 namespace xgboost {
29 class Json;
30 
31 // FIXME(trivialfis): Once binary IO is gone, make this parameter internal as it should
32 // not be configured by users.
34 struct TreeParam : public dmlc::Parameter<TreeParam> {
38  int num_nodes{1};
40  int num_deleted{0};
51  int reserved[31];
54  // assert compact alignment
55  static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int), "TreeParam: 64 bit align");
56  std::memset(reserved, 0, sizeof(reserved));
57  }
58 
59  // Swap byte order for all fields. Useful for transporting models between machines with different
60  // endianness (big endian vs little endian)
61  [[nodiscard]] TreeParam ByteSwap() const {
62  TreeParam x = *this;
63  dmlc::ByteSwap(&x.deprecated_num_roots, sizeof(x.deprecated_num_roots), 1);
64  dmlc::ByteSwap(&x.num_nodes, sizeof(x.num_nodes), 1);
65  dmlc::ByteSwap(&x.num_deleted, sizeof(x.num_deleted), 1);
66  dmlc::ByteSwap(&x.deprecated_max_depth, sizeof(x.deprecated_max_depth), 1);
67  dmlc::ByteSwap(&x.num_feature, sizeof(x.num_feature), 1);
68  dmlc::ByteSwap(&x.size_leaf_vector, sizeof(x.size_leaf_vector), 1);
69  dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0]));
70  return x;
71  }
72 
73  // declare the parameters
75  // only declare the parameters that can be set by the user.
76  // other arguments are set by the algorithm.
77  DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1);
78  DMLC_DECLARE_FIELD(num_feature)
79  .set_default(0)
80  .describe("Number of features used in tree construction.");
81  DMLC_DECLARE_FIELD(num_deleted).set_default(0);
82  DMLC_DECLARE_FIELD(size_leaf_vector)
83  .set_lower_bound(0)
84  .set_default(1)
85  .describe("Size of leaf vector, reserved for vector tree");
86  }
87 
88  bool operator==(const TreeParam& b) const {
89  return num_nodes == b.num_nodes && num_deleted == b.num_deleted &&
91  }
92 };
93 
95 struct RTreeNodeStat {
103  int leaf_child_cnt {0};
104 
105  RTreeNodeStat() = default;
106  RTreeNodeStat(float loss_chg, float sum_hess, float weight) :
108  bool operator==(const RTreeNodeStat& b) const {
109  return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
111  }
112  // Swap byte order for all fields. Useful for transporting models between machines with different
113  // endianness (big endian vs little endian)
114  [[nodiscard]] RTreeNodeStat ByteSwap() const {
115  RTreeNodeStat x = *this;
116  dmlc::ByteSwap(&x.loss_chg, sizeof(x.loss_chg), 1);
117  dmlc::ByteSwap(&x.sum_hess, sizeof(x.sum_hess), 1);
118  dmlc::ByteSwap(&x.base_weight, sizeof(x.base_weight), 1);
119  dmlc::ByteSwap(&x.leaf_child_cnt, sizeof(x.leaf_child_cnt), 1);
120  return x;
121  }
122 };
123 
127 template <typename T>
129  std::unique_ptr<T> ptr_{nullptr};
130 
131  public:
132  CopyUniquePtr() = default;
134  ptr_.reset(nullptr);
135  if (that.ptr_) {
136  ptr_ = std::make_unique<T>(*that);
137  }
138  }
139  T* get() const noexcept { return ptr_.get(); } // NOLINT
140 
141  T& operator*() { return *ptr_; }
142  T* operator->() noexcept { return this->get(); }
143 
144  T const& operator*() const { return *ptr_; }
145  T const* operator->() const noexcept { return this->get(); }
146 
147  explicit operator bool() const { return static_cast<bool>(ptr_); }
148  bool operator!() const { return !ptr_; }
149  void reset(T* ptr) { ptr_.reset(ptr); } // NOLINT
150 };
151 
157 class RegTree : public Model {
158  public:
161  static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
162  static constexpr bst_node_t kRoot{0};
163 
165  class Node {
166  public:
168  // assert compact alignment
169  static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
170  "Node: 64 bit align");
171  }
172  Node(int32_t cleft, int32_t cright, int32_t parent,
173  uint32_t split_ind, float split_cond, bool default_left) :
174  parent_{parent}, cleft_{cleft}, cright_{cright} {
175  this->SetParent(parent_);
176  this->SetSplit(split_ind, split_cond, default_left);
177  }
178 
180  [[nodiscard]] XGBOOST_DEVICE int LeftChild() const { return this->cleft_; }
182  [[nodiscard]] XGBOOST_DEVICE int RightChild() const { return this->cright_; }
184  [[nodiscard]] XGBOOST_DEVICE int DefaultChild() const {
185  return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
186  }
188  [[nodiscard]] XGBOOST_DEVICE bst_feature_t SplitIndex() const {
189  static_assert(!std::is_signed_v<bst_feature_t>);
190  return sindex_ & ((1U << 31) - 1U);
191  }
193  [[nodiscard]] XGBOOST_DEVICE bool DefaultLeft() const { return (sindex_ >> 31) != 0; }
195  [[nodiscard]] XGBOOST_DEVICE bool IsLeaf() const { return cleft_ == kInvalidNodeId; }
197  [[nodiscard]] XGBOOST_DEVICE float LeafValue() const { return (this->info_).leaf_value; }
199  [[nodiscard]] XGBOOST_DEVICE SplitCondT SplitCond() const { return (this->info_).split_cond; }
201  [[nodiscard]] XGBOOST_DEVICE int Parent() const { return parent_ & ((1U << 31) - 1); }
203  [[nodiscard]] XGBOOST_DEVICE bool IsLeftChild() const { return (parent_ & (1U << 31)) != 0; }
205  [[nodiscard]] XGBOOST_DEVICE bool IsDeleted() const { return sindex_ == kDeletedNodeMarker; }
207  [[nodiscard]] XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
212  XGBOOST_DEVICE void SetLeftChild(int nid) {
213  this->cleft_ = nid;
214  }
220  this->cright_ = nid;
221  }
228  XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond,
229  bool default_left = false) {
230  if (default_left) split_index |= (1U << 31);
231  this->sindex_ = split_index;
232  (this->info_).split_cond = split_cond;
233  }
240  XGBOOST_DEVICE void SetLeaf(bst_float value, int right = kInvalidNodeId) {
241  (this->info_).leaf_value = value;
242  this->cleft_ = kInvalidNodeId;
243  this->cright_ = right;
244  }
247  this->sindex_ = kDeletedNodeMarker;
248  }
251  this->sindex_ = 0;
252  }
253  // set parent
254  XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child = true) {
255  if (is_left_child) pidx |= (1U << 31);
256  this->parent_ = pidx;
257  }
258  bool operator==(const Node& b) const {
259  return parent_ == b.parent_ && cleft_ == b.cleft_ &&
260  cright_ == b.cright_ && sindex_ == b.sindex_ &&
261  info_.leaf_value == b.info_.leaf_value;
262  }
263 
264  [[nodiscard]] Node ByteSwap() const {
265  Node x = *this;
266  dmlc::ByteSwap(&x.parent_, sizeof(x.parent_), 1);
267  dmlc::ByteSwap(&x.cleft_, sizeof(x.cleft_), 1);
268  dmlc::ByteSwap(&x.cright_, sizeof(x.cright_), 1);
269  dmlc::ByteSwap(&x.sindex_, sizeof(x.sindex_), 1);
270  dmlc::ByteSwap(&x.info_, sizeof(x.info_), 1);
271  return x;
272  }
273 
274  private:
279  union Info{
280  bst_float leaf_value;
281  SplitCondT split_cond;
282  };
283  // pointer to parent, highest bit is used to
284  // indicate whether it's a left child or not
285  int32_t parent_{kInvalidNodeId};
286  // pointer to left, right
287  int32_t cleft_{kInvalidNodeId}, cright_{kInvalidNodeId};
288  // split feature index, left split or right split depends on the highest bit
289  uint32_t sindex_{0};
290  // extra info
291  Info info_;
292  };
293 
299  void ChangeToLeaf(int rid, bst_float value) {
300  CHECK(nodes_[nodes_[rid].LeftChild() ].IsLeaf());
301  CHECK(nodes_[nodes_[rid].RightChild()].IsLeaf());
302  this->DeleteNode(nodes_[rid].LeftChild());
303  this->DeleteNode(nodes_[rid].RightChild());
304  nodes_[rid].SetLeaf(value);
305  }
311  void CollapseToLeaf(int rid, bst_float value) {
312  if (nodes_[rid].IsLeaf()) return;
313  if (!nodes_[nodes_[rid].LeftChild() ].IsLeaf()) {
314  CollapseToLeaf(nodes_[rid].LeftChild(), 0.0f);
315  }
316  if (!nodes_[nodes_[rid].RightChild() ].IsLeaf()) {
317  CollapseToLeaf(nodes_[rid].RightChild(), 0.0f);
318  }
319  this->ChangeToLeaf(rid, value);
320  }
321 
323  param_.Init(Args{});
324  nodes_.resize(param_.num_nodes);
325  stats_.resize(param_.num_nodes);
326  split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
327  split_categories_segments_.resize(param_.num_nodes);
328  for (int i = 0; i < param_.num_nodes; i++) {
329  nodes_[i].SetLeaf(0.0f);
330  nodes_[i].SetParent(kInvalidNodeId);
331  }
332  }
336  explicit RegTree(bst_target_t n_targets, bst_feature_t n_features) : RegTree{} {
337  param_.num_feature = n_features;
338  param_.size_leaf_vector = n_targets;
339  if (n_targets > 1) {
340  this->p_mt_tree_.reset(new MultiTargetTree{&param_});
341  }
342  }
343 
345  Node& operator[](int nid) {
346  return nodes_[nid];
347  }
349  const Node& operator[](int nid) const {
350  return nodes_[nid];
351  }
352 
354  [[nodiscard]] const std::vector<Node>& GetNodes() const { return nodes_; }
355 
357  [[nodiscard]] const std::vector<RTreeNodeStat>& GetStats() const { return stats_; }
358 
360  RTreeNodeStat& Stat(int nid) {
361  return stats_[nid];
362  }
364  [[nodiscard]] const RTreeNodeStat& Stat(int nid) const {
365  return stats_[nid];
366  }
367 
372  void Load(dmlc::Stream* fi);
377  void Save(dmlc::Stream* fo) const;
378 
379  void LoadModel(Json const& in) override;
380  void SaveModel(Json* out) const override;
381 
382  bool operator==(const RegTree& b) const {
383  return nodes_ == b.nodes_ && stats_ == b.stats_ &&
384  deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
385  }
386  /* \brief Iterate through all nodes in this tree.
387  *
388  * \param Function that accepts a node index, and returns false when iteration should
389  * stop, otherwise returns true.
390  */
391  template <typename Func> void WalkTree(Func func) const {
392  std::stack<bst_node_t> nodes;
393  nodes.push(kRoot);
394  auto &self = *this;
395  while (!nodes.empty()) {
396  auto nidx = nodes.top();
397  nodes.pop();
398  if (!func(nidx)) {
399  return;
400  }
401  auto left = self.LeftChild(nidx);
402  auto right = self.RightChild(nidx);
403  if (left != RegTree::kInvalidNodeId) {
404  nodes.push(left);
405  }
406  if (right != RegTree::kInvalidNodeId) {
407  nodes.push(right);
408  }
409  }
410  }
417  [[nodiscard]] bool Equal(const RegTree& b) const;
418 
436  void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value,
437  bool default_left, bst_float base_weight,
438  bst_float left_leaf_weight, bst_float right_leaf_weight,
439  bst_float loss_change, float sum_hess, float left_sum,
440  float right_sum,
441  bst_node_t leaf_right_child = kInvalidNodeId);
445  void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left,
448  linalg::VectorView<float const> right_weight);
449 
466  common::Span<const uint32_t> split_cat, bool default_left,
467  bst_float base_weight, bst_float left_leaf_weight,
468  bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
469  float left_sum, float right_sum);
473  [[nodiscard]] bool HasCategoricalSplit() const { return !split_categories_.empty(); }
477  [[nodiscard]] bool IsMultiTarget() const { return static_cast<bool>(p_mt_tree_); }
481  [[nodiscard]] bst_target_t NumTargets() const { return param_.size_leaf_vector; }
485  [[nodiscard]] auto GetMultiTargetTree() const {
486  CHECK(IsMultiTarget());
487  return p_mt_tree_.get();
488  }
492  [[nodiscard]] bst_feature_t NumFeatures() const noexcept { return param_.num_feature; }
496  [[nodiscard]] bst_node_t NumNodes() const noexcept { return param_.num_nodes; }
500  [[nodiscard]] bst_node_t NumValidNodes() const noexcept {
501  return param_.num_nodes - param_.num_deleted;
502  }
506  [[nodiscard]] bst_node_t NumExtraNodes() const noexcept {
507  return param_.num_nodes - 1 - param_.num_deleted;
508  }
509  /* \brief Count number of leaves in tree. */
510  [[nodiscard]] bst_node_t GetNumLeaves() const;
511  [[nodiscard]] bst_node_t GetNumSplitNodes() const;
512 
517  [[nodiscard]] std::int32_t GetDepth(bst_node_t nid) const {
518  if (IsMultiTarget()) {
519  return this->p_mt_tree_->Depth(nid);
520  }
521  int depth = 0;
522  while (!nodes_[nid].IsRoot()) {
523  ++depth;
524  nid = nodes_[nid].Parent();
525  }
526  return depth;
527  }
532  CHECK(IsMultiTarget());
533  return this->p_mt_tree_->SetLeaf(nidx, weight);
534  }
535 
540  [[nodiscard]] int MaxDepth(int nid) const {
541  if (nodes_[nid].IsLeaf()) return 0;
542  return std::max(MaxDepth(nodes_[nid].LeftChild()) + 1, MaxDepth(nodes_[nid].RightChild()) + 1);
543  }
544 
548  int MaxDepth() { return MaxDepth(0); }
549 
554  struct FVec {
559  void Init(size_t size);
564  void Fill(SparsePage::Inst const& inst);
565 
570  void Drop();
575  [[nodiscard]] size_t Size() const;
581  [[nodiscard]] bst_float GetFvalue(size_t i) const;
587  [[nodiscard]] bool IsMissing(size_t i) const;
588  [[nodiscard]] bool HasMissing() const;
589  void HasMissing(bool has_missing) { this->has_missing_ = has_missing; }
590 
591  [[nodiscard]] common::Span<float> Data() { return data_; }
592 
593  private:
599  std::vector<float> data_;
600  bool has_missing_;
601  };
602 
609  std::vector<float>* mean_values,
610  bst_float* out_contribs) const;
618  [[nodiscard]] std::string DumpModel(const FeatureMap& fmap, bool with_stats,
619  std::string format) const;
625  [[nodiscard]] FeatureType NodeSplitType(bst_node_t nidx) const { return split_types_.at(nidx); }
629  [[nodiscard]] std::vector<FeatureType> const& GetSplitTypes() const {
630  return split_types_;
631  }
633  return split_categories_;
634  }
639  auto node_ptr = GetCategoriesMatrix().node_ptr;
640  auto categories = GetCategoriesMatrix().categories;
641  auto segment = node_ptr[nidx];
642  auto node_cats = categories.subspan(segment.beg, segment.size);
643  return node_cats;
644  }
645  [[nodiscard]] auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
646 
655  struct Segment {
656  std::size_t beg{0};
657  std::size_t size{0};
658  };
662  };
663 
667  view.categories = this->GetSplitCategories();
668  view.node_ptr = common::Span<CategoricalSplitMatrix::Segment const>(split_categories_segments_);
669  return view;
670  }
671 
672  [[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const {
673  if (IsMultiTarget()) {
674  return this->p_mt_tree_->SplitIndex(nidx);
675  }
676  return (*this)[nidx].SplitIndex();
677  }
678  [[nodiscard]] float SplitCond(bst_node_t nidx) const {
679  if (IsMultiTarget()) {
680  return this->p_mt_tree_->SplitCond(nidx);
681  }
682  return (*this)[nidx].SplitCond();
683  }
684  [[nodiscard]] bool DefaultLeft(bst_node_t nidx) const {
685  if (IsMultiTarget()) {
686  return this->p_mt_tree_->DefaultLeft(nidx);
687  }
688  return (*this)[nidx].DefaultLeft();
689  }
690  [[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
691  return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
692  }
693  [[nodiscard]] bool IsRoot(bst_node_t nidx) const {
694  if (IsMultiTarget()) {
695  return nidx == kRoot;
696  }
697  return (*this)[nidx].IsRoot();
698  }
699  [[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
700  if (IsMultiTarget()) {
701  return this->p_mt_tree_->IsLeaf(nidx);
702  }
703  return (*this)[nidx].IsLeaf();
704  }
705  [[nodiscard]] bst_node_t Parent(bst_node_t nidx) const {
706  if (IsMultiTarget()) {
707  return this->p_mt_tree_->Parent(nidx);
708  }
709  return (*this)[nidx].Parent();
710  }
711  [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const {
712  if (IsMultiTarget()) {
713  return this->p_mt_tree_->LeftChild(nidx);
714  }
715  return (*this)[nidx].LeftChild();
716  }
717  [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const {
718  if (IsMultiTarget()) {
719  return this->p_mt_tree_->RightChild(nidx);
720  }
721  return (*this)[nidx].RightChild();
722  }
723  [[nodiscard]] bool IsLeftChild(bst_node_t nidx) const {
724  if (IsMultiTarget()) {
725  CHECK_NE(nidx, kRoot);
726  auto p = this->p_mt_tree_->Parent(nidx);
727  return nidx == this->p_mt_tree_->LeftChild(p);
728  }
729  return (*this)[nidx].IsLeftChild();
730  }
731  [[nodiscard]] bst_node_t Size() const {
732  if (IsMultiTarget()) {
733  return this->p_mt_tree_->Size();
734  }
735  return this->nodes_.size();
736  }
737 
738  private:
739  template <bool typed>
740  void LoadCategoricalSplit(Json const& in);
741  void SaveCategoricalSplit(Json* p_out) const;
743  TreeParam param_;
744  // vector of nodes
745  std::vector<Node> nodes_;
746  // free node space, used during training process
747  std::vector<int> deleted_nodes_;
748  // stats of nodes
749  std::vector<RTreeNodeStat> stats_;
750  std::vector<FeatureType> split_types_;
751 
752  // Categories for each internal node.
753  std::vector<uint32_t> split_categories_;
754  // Ptr to split categories of each node.
755  std::vector<CategoricalSplitMatrix::Segment> split_categories_segments_;
756  // ptr to multi-target tree with vector leaf.
758  // allocate a new node,
759  // !!!!!! NOTE: may cause BUG here, nodes.resize
760  bst_node_t AllocNode() {
761  if (param_.num_deleted != 0) {
762  int nid = deleted_nodes_.back();
763  deleted_nodes_.pop_back();
764  nodes_[nid].Reuse();
765  --param_.num_deleted;
766  return nid;
767  }
768  int nd = param_.num_nodes++;
769  CHECK_LT(param_.num_nodes, std::numeric_limits<int>::max())
770  << "number of nodes in the tree exceed 2^31";
771  nodes_.resize(param_.num_nodes);
772  stats_.resize(param_.num_nodes);
773  split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
774  split_categories_segments_.resize(param_.num_nodes);
775  return nd;
776  }
777  // delete a tree node, keep the parent field to allow trace back
778  void DeleteNode(int nid) {
779  CHECK_GE(nid, 1);
780  auto pid = (*this)[nid].Parent();
781  if (nid == (*this)[pid].LeftChild()) {
782  (*this)[pid].SetLeftChild(kInvalidNodeId);
783  } else {
784  (*this)[pid].SetRightChild(kInvalidNodeId);
785  }
786 
787  deleted_nodes_.push_back(nid);
788  nodes_[nid].MarkDelete();
789  ++param_.num_deleted;
790  }
791 };
792 
793 inline void RegTree::FVec::Init(size_t size) {
794  data_.resize(size);
795  std::fill(data_.begin(), data_.end(), std::numeric_limits<float>::quiet_NaN());
796  has_missing_ = true;
797 }
798 
799 inline void RegTree::FVec::Fill(SparsePage::Inst const& inst) {
800  auto p_data = inst.data();
801  auto p_out = data_.data();
802 
803  for (std::size_t i = 0, n = inst.size(); i < n; ++i) {
804  auto const& entry = p_data[i];
805  p_out[entry.index] = entry.fvalue;
806  }
807  has_missing_ = data_.size() != inst.size();
808 }
809 
810 inline void RegTree::FVec::Drop() { this->Init(this->Size()); }
811 
812 inline size_t RegTree::FVec::Size() const {
813  return data_.size();
814 }
815 
816 inline float RegTree::FVec::GetFvalue(size_t i) const {
817  return data_[i];
818 }
819 
820 inline bool RegTree::FVec::IsMissing(size_t i) const { return std::isnan(data_[i]); }
821 
822 inline bool RegTree::FVec::HasMissing() const { return has_missing_; }
823 
824 // Multi-target tree not yet implemented error
826  return " support for multi-target tree is not yet implemented.";
827 }
828 } // namespace xgboost
829 #endif // XGBOOST_TREE_MODEL_H_
Defines configuration macros and basic types for xgboost.
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:64
Helper for defining copyable data structure that contains unique pointers.
Definition: tree_model.h:128
T const * operator->() const noexcept
Definition: tree_model.h:145
T * get() const noexcept
Definition: tree_model.h:139
bool operator!() const
Definition: tree_model.h:148
CopyUniquePtr(CopyUniquePtr const &that)
Definition: tree_model.h:133
T * operator->() noexcept
Definition: tree_model.h:142
T & operator*()
Definition: tree_model.h:141
T const & operator*() const
Definition: tree_model.h:144
void reset(T *ptr)
Definition: tree_model.h:149
Feature map data structure to help text model dump. TODO(tqchen) consider make it even more lightweig...
Definition: feature_map.h:22
Data structure representing JSON format.
Definition: json.h:378
Tree structure for multi-target model.
Definition: multi_target_tree_model.h:23
static constexpr bst_node_t InvalidNodeId()
Definition: multi_target_tree_model.h:25
tree node
Definition: tree_model.h:165
XGBOOST_DEVICE int Parent() const
get parent of the node
Definition: tree_model.h:201
XGBOOST_DEVICE void MarkDelete()
mark that this node is deleted
Definition: tree_model.h:246
XGBOOST_DEVICE bool IsRoot() const
whether current node is root
Definition: tree_model.h:207
XGBOOST_DEVICE int RightChild() const
index of right child
Definition: tree_model.h:182
XGBOOST_DEVICE float LeafValue() const
Definition: tree_model.h:197
XGBOOST_DEVICE Node()
Definition: tree_model.h:167
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child=true)
Definition: tree_model.h:254
Node ByteSwap() const
Definition: tree_model.h:264
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=kInvalidNodeId)
set the leaf value of the node
Definition: tree_model.h:240
XGBOOST_DEVICE bool IsLeftChild() const
whether current node is left child
Definition: tree_model.h:203
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
set split condition of current node
Definition: tree_model.h:228
XGBOOST_DEVICE void SetLeftChild(int nid)
set the left child
Definition: tree_model.h:212
XGBOOST_DEVICE bst_feature_t SplitIndex() const
feature index of split condition
Definition: tree_model.h:188
XGBOOST_DEVICE bool IsDeleted() const
whether this node is deleted
Definition: tree_model.h:205
XGBOOST_DEVICE bool IsLeaf() const
whether current node is leaf node
Definition: tree_model.h:195
bool operator==(const Node &b) const
Definition: tree_model.h:258
Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond, bool default_left)
Definition: tree_model.h:172
XGBOOST_DEVICE void Reuse()
Reuse this deleted node.
Definition: tree_model.h:250
XGBOOST_DEVICE void SetRightChild(int nid)
set the right child
Definition: tree_model.h:219
XGBOOST_DEVICE bool DefaultLeft() const
when feature is unknown, whether goes to left child
Definition: tree_model.h:193
XGBOOST_DEVICE int LeftChild() const
index of left child
Definition: tree_model.h:180
XGBOOST_DEVICE int DefaultChild() const
index of default child when feature is missing
Definition: tree_model.h:184
XGBOOST_DEVICE SplitCondT SplitCond() const
Definition: tree_model.h:199
define regression tree to be the most common tree model.
Definition: tree_model.h:157
int MaxDepth(int nid) const
get maximum depth
Definition: tree_model.h:540
void SaveModel(Json *out) const override
saves the model config to a JSON object
bst_target_t NumTargets() const
The size of leaf weight.
Definition: tree_model.h:481
void WalkTree(Func func) const
Definition: tree_model.h:391
void Save(dmlc::Stream *fo) const
save model to stream
bool IsLeaf(bst_node_t nidx) const
Definition: tree_model.h:699
bool operator==(const RegTree &b) const
Definition: tree_model.h:382
const RTreeNodeStat & Stat(int nid) const
get node statistics given nid
Definition: tree_model.h:364
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight)
Expands a leaf node into two additional leaf nodes for a multi-target tree.
bst_node_t Parent(bst_node_t nidx) const
Definition: tree_model.h:705
bst_node_t NumNodes() const noexcept
Get the total number of nodes including deleted ones in this tree.
Definition: tree_model.h:496
const Node & operator[](int nid) const
get node given nid
Definition: tree_model.h:349
bst_node_t DefaultChild(bst_node_t nidx) const
Definition: tree_model.h:690
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
Expands a leaf node into two additional leaf nodes.
Node & operator[](int nid)
get node given nid
Definition: tree_model.h:345
RegTree()
Definition: tree_model.h:322
static constexpr bst_node_t kInvalidNodeId
Definition: tree_model.h:160
bst_feature_t SplitIndex(bst_node_t nidx) const
Definition: tree_model.h:672
bool IsRoot(bst_node_t nidx) const
Definition: tree_model.h:693
static constexpr uint32_t kDeletedNodeMarker
Definition: tree_model.h:161
bool IsMultiTarget() const
Whether this is a multi-target tree.
Definition: tree_model.h:477
bst_node_t NumExtraNodes() const noexcept
number of extra nodes besides the root
Definition: tree_model.h:506
bool DefaultLeft(bst_node_t nidx) const
Definition: tree_model.h:684
auto GetMultiTargetTree() const
Get the underlying implementation of multi-target tree.
Definition: tree_model.h:485
void Load(dmlc::Stream *fi)
load model from stream
bst_node_t LeftChild(bst_node_t nidx) const
Definition: tree_model.h:711
bst_node_t GetNumLeaves() const
RegTree(bst_target_t n_targets, bst_feature_t n_features)
Constructor that initializes the tree model with shape.
Definition: tree_model.h:336
bst_node_t RightChild(bst_node_t nidx) const
Definition: tree_model.h:717
common::Span< uint32_t const > NodeCats(bst_node_t nidx) const
Get the bit storage for categories.
Definition: tree_model.h:638
bool IsLeftChild(bst_node_t nidx) const
Definition: tree_model.h:723
CategoricalSplitMatrix GetCategoriesMatrix() const
Definition: tree_model.h:664
RTreeNodeStat & Stat(int nid)
get node statistics given nid
Definition: tree_model.h:360
bst_float SplitCondT
Definition: tree_model.h:159
void ExpandCategorical(bst_node_t nid, bst_feature_t split_index, common::Span< const uint32_t > split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum)
Expands a leaf node with categories.
bool Equal(const RegTree &b) const
Compares whether 2 trees are equal from a user's perspective. The equality compares only non-deleted ...
std::vector< FeatureType > const & GetSplitTypes() const
Get split types for all nodes.
Definition: tree_model.h:629
void CollapseToLeaf(int rid, bst_float value)
collapse a non leaf node to a leaf node, delete its children
Definition: tree_model.h:311
bst_node_t NumValidNodes() const noexcept
Get the total number of valid nodes in this tree.
Definition: tree_model.h:500
void ChangeToLeaf(int rid, bst_float value)
change a non leaf node to a leaf node, delete its children
Definition: tree_model.h:299
const std::vector< RTreeNodeStat > & GetStats() const
get const reference to stats
Definition: tree_model.h:357
void SetLeaf(bst_node_t nidx, linalg::VectorView< float const > weight)
Set the leaf weight for a multi-target tree.
Definition: tree_model.h:531
const std::vector< Node > & GetNodes() const
get const reference to nodes
Definition: tree_model.h:354
void CalculateContributionsApprox(const RegTree::FVec &feat, std::vector< float > *mean_values, bst_float *out_contribs) const
calculate the approximate feature contributions for the given root
void LoadModel(Json const &in) override
load the model from a JSON object
std::string DumpModel(const FeatureMap &fmap, bool with_stats, std::string format) const
dump the model in the requested format as a text string
FeatureType NodeSplitType(bst_node_t nidx) const
Get split type for a node.
Definition: tree_model.h:625
bst_feature_t NumFeatures() const noexcept
Get the number of features.
Definition: tree_model.h:492
common::Span< uint32_t const > GetSplitCategories() const
Definition: tree_model.h:632
bool HasCategoricalSplit() const
Whether this tree has categorical split.
Definition: tree_model.h:473
std::int32_t GetDepth(bst_node_t nid) const
get current depth
Definition: tree_model.h:517
static constexpr bst_node_t kRoot
Definition: tree_model.h:162
bst_node_t GetNumSplitNodes() const
auto const & GetSplitCategoriesPtr() const
Definition: tree_model.h:645
float SplitCond(bst_node_t nidx) const
Definition: tree_model.h:678
int MaxDepth()
get maximum depth
Definition: tree_model.h:548
bst_node_t Size() const
Definition: tree_model.h:731
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:431
constexpr XGBOOST_DEVICE pointer data() const __span_noexcept
Definition: span.h:550
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
Definition: span.h:597
constexpr XGBOOST_DEVICE index_type size() const __span_noexcept
Definition: span.h:555
A tensor view with static type and dimension. It implements indexing and slicing.
Definition: linalg.h:294
The input data structure of xgboost.
Feature map data structure to help visualization and model dump.
Linear algebra related utilities.
Defines the abstract interface for different components in XGBoost.
Core data structure for multi-target trees.
Definition: base.h:89
std::vector< std::pair< std::string, std::string > > Args
Definition: base.h:316
std::int32_t bst_node_t
Type for tree node index.
Definition: base.h:111
FeatureType
Definition: data.h:41
std::uint32_t bst_target_t
Type for indexing into output targets.
Definition: base.h:119
std::uint32_t bst_feature_t
Type for data column (feature) index.
Definition: base.h:99
float bst_float
float type, used for storing statistics
Definition: base.h:95
StringView MTNotImplemented()
Definition: tree_model.h:825
Definition: model.h:17
node statistics used in regression tree
Definition: tree_model.h:95
RTreeNodeStat(float loss_chg, float sum_hess, float weight)
Definition: tree_model.h:106
bst_float loss_chg
loss change caused by current split
Definition: tree_model.h:97
int leaf_child_cnt
number of children that are leaf nodes known up to now
Definition: tree_model.h:103
bst_float sum_hess
sum of hessian values, used to measure coverage of data
Definition: tree_model.h:99
bool operator==(const RTreeNodeStat &b) const
Definition: tree_model.h:108
RTreeNodeStat ByteSwap() const
Definition: tree_model.h:114
bst_float base_weight
weight of current node
Definition: tree_model.h:101
std::size_t size
Definition: tree_model.h:657
std::size_t beg
Definition: tree_model.h:656
CSR-like matrix for categorical splits.
Definition: tree_model.h:654
common::Span< uint32_t const > categories
Definition: tree_model.h:660
common::Span< Segment const > node_ptr
Definition: tree_model.h:661
common::Span< FeatureType const > split_type
Definition: tree_model.h:659
dense feature vector that can be taken by RegTree and can be constructed from a sparse feature vector.
Definition: tree_model.h:554
void HasMissing(bool has_missing)
Definition: tree_model.h:589
void Drop()
drop the trace after fill, must be called after fill.
Definition: tree_model.h:810
bool HasMissing() const
Definition: tree_model.h:822
bool IsMissing(size_t i) const
check whether i-th entry is missing
Definition: tree_model.h:820
size_t Size() const
returns the size of the feature vector
Definition: tree_model.h:812
void Init(size_t size)
initialize the vector with the given size
Definition: tree_model.h:793
common::Span< float > Data()
Definition: tree_model.h:591
void Fill(SparsePage::Inst const &inst)
fill the vector with sparse vector
Definition: tree_model.h:799
bst_float GetFvalue(size_t i) const
get ith value
Definition: tree_model.h:816
Definition: string_view.h:16
meta parameters of the tree
Definition: tree_model.h:34
bst_feature_t num_feature
number of features used for tree construction
Definition: tree_model.h:44
int num_nodes
total number of nodes
Definition: tree_model.h:38
int num_deleted
number of deleted nodes
Definition: tree_model.h:40
bool operator==(const TreeParam &b) const
Definition: tree_model.h:88
int reserved[31]
reserved part, make sure alignment works for 64bit
Definition: tree_model.h:51
TreeParam ByteSwap() const
Definition: tree_model.h:61
TreeParam()
constructor
Definition: tree_model.h:53
DMLC_DECLARE_PARAMETER(TreeParam)
Definition: tree_model.h:74
bst_target_t size_leaf_vector
leaf vector size, used for vector trees to store more than one-dimensional information in the tree
Definition: tree_model.h:49
int deprecated_num_roots
(Deprecated) number of start root
Definition: tree_model.h:36
int deprecated_max_depth
maximum depth, this is a statistic of the tree
Definition: tree_model.h:42