10 #include <dmlc/io.h>
11 #include <dmlc/parameter.h>
12 #include <xgboost/base.h>
13 #include <xgboost/data.h>
14 #include <xgboost/feature_map.h>
15 #include <xgboost/linalg.h> // for VectorView
16 #include <xgboost/logging.h>
17 #include <xgboost/model.h>
18 #include <xgboost/multi_target_tree_model.h> // for MultiTargetTree
20 #include <algorithm>
21 #include <cstring>
22 #include <limits>
23 #include <memory> // for make_unique
24 #include <stack>
25 #include <string>
26 #include <tuple>
27 #include <vector>
29 namespace xgboost {
30 class Json;
32 // FIXME(trivialfis): Once binary IO is gone, make this parameter internal as it should
33 // not be configured by users.
35 struct TreeParam : public dmlc::Parameter<TreeParam> {
39  int num_nodes{1};
41  int num_deleted{0};
52  int reserved[31];
55  // assert compact alignment
56  static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int), "TreeParam: 64 bit align");
57  std::memset(reserved, 0, sizeof(reserved));
58  }
60  // Swap byte order for all fields. Useful for transporting models between machines with different
61  // endianness (big endian vs little endian)
62  [[nodiscard]] TreeParam ByteSwap() const {
63  TreeParam x = *this;
64  dmlc::ByteSwap(&x.deprecated_num_roots, sizeof(x.deprecated_num_roots), 1);
65  dmlc::ByteSwap(&x.num_nodes, sizeof(x.num_nodes), 1);
66  dmlc::ByteSwap(&x.num_deleted, sizeof(x.num_deleted), 1);
67  dmlc::ByteSwap(&x.deprecated_max_depth, sizeof(x.deprecated_max_depth), 1);
68  dmlc::ByteSwap(&x.num_feature, sizeof(x.num_feature), 1);
69  dmlc::ByteSwap(&x.size_leaf_vector, sizeof(x.size_leaf_vector), 1);
70  dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0]));
71  return x;
72  }
74  // declare the parameters
76  // only declare the parameters that can be set by the user.
77  // other arguments are set by the algorithm.
78  DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1);
79  DMLC_DECLARE_FIELD(num_feature)
80  .set_default(0)
81  .describe("Number of features used in tree construction.");
82  DMLC_DECLARE_FIELD(num_deleted).set_default(0);
83  DMLC_DECLARE_FIELD(size_leaf_vector)
84  .set_lower_bound(0)
85  .set_default(1)
86  .describe("Size of leaf vector, reserved for vector tree");
87  }
89  bool operator==(const TreeParam& b) const {
90  return num_nodes == b.num_nodes && num_deleted == b.num_deleted &&
92  }
93 };
96 struct RTreeNodeStat {
104  int leaf_child_cnt {0};
106  RTreeNodeStat() = default;
107  RTreeNodeStat(float loss_chg, float sum_hess, float weight) :
109  bool operator==(const RTreeNodeStat& b) const {
110  return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
112  }
113  // Swap byte order for all fields. Useful for transporting models between machines with different
114  // endianness (big endian vs little endian)
115  [[nodiscard]] RTreeNodeStat ByteSwap() const {
116  RTreeNodeStat x = *this;
117  dmlc::ByteSwap(&x.loss_chg, sizeof(x.loss_chg), 1);
118  dmlc::ByteSwap(&x.sum_hess, sizeof(x.sum_hess), 1);
119  dmlc::ByteSwap(&x.base_weight, sizeof(x.base_weight), 1);
120  dmlc::ByteSwap(&x.leaf_child_cnt, sizeof(x.leaf_child_cnt), 1);
121  return x;
122  }
123 };
128 template <typename T>
130  std::unique_ptr<T> ptr_{nullptr};
132  public:
133  CopyUniquePtr() = default;
135  ptr_.reset(nullptr);
136  if (that.ptr_) {
137  ptr_ = std::make_unique<T>(*that);
138  }
139  }
  /*! \brief Raw pointer held by the wrapped std::unique_ptr; null when nothing is held. */
  T* get() const noexcept { return ptr_.get(); } // NOLINT
142  T& operator*() { return *ptr_; }
143  T* operator->() noexcept { return this->get(); }
145  T const& operator*() const { return *ptr_; }
146  T const* operator->() const noexcept { return this->get(); }
148  explicit operator bool() const { return static_cast<bool>(ptr_); }
149  bool operator!() const { return !ptr_; }
  /*! \brief Hand `ptr` to the wrapped std::unique_ptr, which then owns it in place of the previous object. */
  void reset(T* ptr) { ptr_.reset(ptr); } // NOLINT
151 };
158 class RegTree : public Model {
159  public:
162  static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
163  static constexpr bst_node_t kRoot{0};
166  class Node {
167  public:
169  // assert compact alignment
170  static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
171  "Node: 64 bit align");
172  }
173  Node(int32_t cleft, int32_t cright, int32_t parent,
174  uint32_t split_ind, float split_cond, bool default_left) :
175  parent_{parent}, cleft_{cleft}, cright_{cright} {
176  this->SetParent(parent_);
177  this->SetSplit(split_ind, split_cond, default_left);
178  }
181  [[nodiscard]] XGBOOST_DEVICE int LeftChild() const { return this->cleft_; }
183  [[nodiscard]] XGBOOST_DEVICE int RightChild() const { return this->cright_; }
185  [[nodiscard]] XGBOOST_DEVICE int DefaultChild() const {
186  return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
187  }
189  [[nodiscard]] XGBOOST_DEVICE unsigned SplitIndex() const {
190  return sindex_ & ((1U << 31) - 1U);
191  }
193  [[nodiscard]] XGBOOST_DEVICE bool DefaultLeft() const { return (sindex_ >> 31) != 0; }
195  [[nodiscard]] XGBOOST_DEVICE bool IsLeaf() const { return cleft_ == kInvalidNodeId; }
197  [[nodiscard]] XGBOOST_DEVICE float LeafValue() const { return (this->info_).leaf_value; }
199  [[nodiscard]] XGBOOST_DEVICE SplitCondT SplitCond() const { return (this->info_).split_cond; }
201  [[nodiscard]] XGBOOST_DEVICE int Parent() const { return parent_ & ((1U << 31) - 1); }
203  [[nodiscard]] XGBOOST_DEVICE bool IsLeftChild() const { return (parent_ & (1U << 31)) != 0; }
205  [[nodiscard]] XGBOOST_DEVICE bool IsDeleted() const { return sindex_ == kDeletedNodeMarker; }
207  [[nodiscard]] XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
212  XGBOOST_DEVICE void SetLeftChild(int nid) {
213  this->cleft_ = nid;
214  }
220  this->cright_ = nid;
221  }
228  XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond,
229  bool default_left = false) {
230  if (default_left) split_index |= (1U << 31);
231  this->sindex_ = split_index;
232  (this->info_).split_cond = split_cond;
233  }
240  XGBOOST_DEVICE void SetLeaf(bst_float value, int right = kInvalidNodeId) {
241  (this->info_).leaf_value = value;
242  this->cleft_ = kInvalidNodeId;
243  this->cright_ = right;
244  }
247  this->sindex_ = kDeletedNodeMarker;
248  }
251  this->sindex_ = 0;
252  }
253  // set parent
254  XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child = true) {
255  if (is_left_child) pidx |= (1U << 31);
256  this->parent_ = pidx;
257  }
258  bool operator==(const Node& b) const {
259  return parent_ == b.parent_ && cleft_ == b.cleft_ &&
260  cright_ == b.cright_ && sindex_ == b.sindex_ &&
261  info_.leaf_value == b.info_.leaf_value;
262  }
264  [[nodiscard]] Node ByteSwap() const {
265  Node x = *this;
266  dmlc::ByteSwap(&x.parent_, sizeof(x.parent_), 1);
267  dmlc::ByteSwap(&x.cleft_, sizeof(x.cleft_), 1);
268  dmlc::ByteSwap(&x.cright_, sizeof(x.cright_), 1);
269  dmlc::ByteSwap(&x.sindex_, sizeof(x.sindex_), 1);
270  dmlc::ByteSwap(&x.info_, sizeof(x.info_), 1);
271  return x;
272  }
274  private:
  // Per-node payload. Both members share the same storage, so a node holds
  // either a leaf value or a split condition.
  union Info{
    bst_float leaf_value;   // written by SetLeaf(), read by LeafValue()
    SplitCondT split_cond;  // written by SetSplit(), read by SplitCond()
  };
283  // pointer to parent, highest bit is used to
284  // indicate whether it's a left child or not
285  int32_t parent_{kInvalidNodeId};
286  // pointer to left, right
287  int32_t cleft_{kInvalidNodeId}, cright_{kInvalidNodeId};
288  // split feature index, left split or right split depends on the highest bit
289  uint32_t sindex_{0};
290  // extra info
291  Info info_;
292  };
299  void ChangeToLeaf(int rid, bst_float value) {
300  CHECK(nodes_[nodes_[rid].LeftChild() ].IsLeaf());
301  CHECK(nodes_[nodes_[rid].RightChild()].IsLeaf());
302  this->DeleteNode(nodes_[rid].LeftChild());
303  this->DeleteNode(nodes_[rid].RightChild());
304  nodes_[rid].SetLeaf(value);
305  }
311  void CollapseToLeaf(int rid, bst_float value) {
312  if (nodes_[rid].IsLeaf()) return;
313  if (!nodes_[nodes_[rid].LeftChild() ].IsLeaf()) {
314  CollapseToLeaf(nodes_[rid].LeftChild(), 0.0f);
315  }
316  if (!nodes_[nodes_[rid].RightChild() ].IsLeaf()) {
317  CollapseToLeaf(nodes_[rid].RightChild(), 0.0f);
318  }
319  this->ChangeToLeaf(rid, value);
320  }
323  param_.Init(Args{});
324  nodes_.resize(param_.num_nodes);
325  stats_.resize(param_.num_nodes);
326  split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
327  split_categories_segments_.resize(param_.num_nodes);
328  for (int i = 0; i < param_.num_nodes; i++) {
329  nodes_[i].SetLeaf(0.0f);
330  nodes_[i].SetParent(kInvalidNodeId);
331  }
332  }
336  explicit RegTree(bst_target_t n_targets, bst_feature_t n_features) : RegTree{} {
337  param_.num_feature = n_features;
338  param_.size_leaf_vector = n_targets;
339  if (n_targets > 1) {
340  this->p_mt_tree_.reset(new MultiTargetTree{&param_});
341  }
342  }
345  Node& operator[](int nid) {
346  return nodes_[nid];
347  }
349  const Node& operator[](int nid) const {
350  return nodes_[nid];
351  }
354  [[nodiscard]] const std::vector<Node>& GetNodes() const { return nodes_; }
357  [[nodiscard]] const std::vector<RTreeNodeStat>& GetStats() const { return stats_; }
360  RTreeNodeStat& Stat(int nid) {
361  return stats_[nid];
362  }
364  [[nodiscard]] const RTreeNodeStat& Stat(int nid) const {
365  return stats_[nid];
366  }
372  void Load(dmlc::Stream* fi);
377  void Save(dmlc::Stream* fo) const;
379  void LoadModel(Json const& in) override;
380  void SaveModel(Json* out) const override;
382  bool operator==(const RegTree& b) const {
383  return nodes_ == b.nodes_ && stats_ == b.stats_ &&
384  deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
385  }
386  /* \brief Iterate through all nodes in this tree.
387  *
388  * \param Function that accepts a node index, and returns false when iteration should
389  * stop, otherwise returns true.
390  */
391  template <typename Func> void WalkTree(Func func) const {
392  std::stack<bst_node_t> nodes;
393  nodes.push(kRoot);
394  auto &self = *this;
395  while (!nodes.empty()) {
396  auto nidx = nodes.top();
397  nodes.pop();
398  if (!func(nidx)) {
399  return;
400  }
401  auto left = self[nidx].LeftChild();
402  auto right = self[nidx].RightChild();
403  if (left != RegTree::kInvalidNodeId) {
404  nodes.push(left);
405  }
406  if (right != RegTree::kInvalidNodeId) {
407  nodes.push(right);
408  }
409  }
410  }
417  [[nodiscard]] bool Equal(const RegTree& b) const;
436  void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value,
437  bool default_left, bst_float base_weight,
438  bst_float left_leaf_weight, bst_float right_leaf_weight,
439  bst_float loss_change, float sum_hess, float left_sum,
440  float right_sum,
441  bst_node_t leaf_right_child = kInvalidNodeId);
445  void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left,
448  linalg::VectorView<float const> right_weight);
466  common::Span<const uint32_t> split_cat, bool default_left,
467  bst_float base_weight, bst_float left_leaf_weight,
468  bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
469  float left_sum, float right_sum);
473  [[nodiscard]] bool HasCategoricalSplit() const { return !split_categories_.empty(); }
477  [[nodiscard]] bool IsMultiTarget() const { return static_cast<bool>(p_mt_tree_); }
481  [[nodiscard]] bst_target_t NumTargets() const { return param_.size_leaf_vector; }
485  [[nodiscard]] auto GetMultiTargetTree() const {
486  CHECK(IsMultiTarget());
487  return p_mt_tree_.get();
488  }
492  [[nodiscard]] bst_feature_t NumFeatures() const noexcept { return param_.num_feature; }
496  [[nodiscard]] bst_node_t NumNodes() const noexcept { return param_.num_nodes; }
500  [[nodiscard]] bst_node_t NumValidNodes() const noexcept {
501  return param_.num_nodes - param_.num_deleted;
502  }
506  [[nodiscard]] bst_node_t NumExtraNodes() const noexcept {
507  return param_.num_nodes - 1 - param_.num_deleted;
508  }
509  /* \brief Count number of leaves in tree. */
510  [[nodiscard]] bst_node_t GetNumLeaves() const;
511  [[nodiscard]] bst_node_t GetNumSplitNodes() const;
517  [[nodiscard]] std::int32_t GetDepth(bst_node_t nid) const {
518  if (IsMultiTarget()) {
519  return this->p_mt_tree_->Depth(nid);
520  }
521  int depth = 0;
522  while (!nodes_[nid].IsRoot()) {
523  ++depth;
524  nid = nodes_[nid].Parent();
525  }
526  return depth;
527  }
532  CHECK(IsMultiTarget());
533  return this->p_mt_tree_->SetLeaf(nidx, weight);
534  }
540  [[nodiscard]] int MaxDepth(int nid) const {
541  if (nodes_[nid].IsLeaf()) return 0;
542  return std::max(MaxDepth(nodes_[nid].LeftChild()) + 1, MaxDepth(nodes_[nid].RightChild()) + 1);
543  }
548  int MaxDepth() { return MaxDepth(0); }
554  struct FVec {
559  void Init(size_t size);
564  void Fill(const SparsePage::Inst& inst);
570  void Drop();
575  [[nodiscard]] size_t Size() const;
581  [[nodiscard]] bst_float GetFvalue(size_t i) const;
587  [[nodiscard]] bool IsMissing(size_t i) const;
588  [[nodiscard]] bool HasMissing() const;
591  private:
596  union Entry {
597  bst_float fvalue;
598  int flag;
599  };
600  std::vector<Entry> data_;
601  bool has_missing_;
602  };
610  std::vector<float>* mean_values,
611  bst_float* out_contribs) const;
619  [[nodiscard]] std::string DumpModel(const FeatureMap& fmap, bool with_stats,
620  std::string format) const;
626  [[nodiscard]] FeatureType NodeSplitType(bst_node_t nidx) const { return split_types_.at(nidx); }
630  [[nodiscard]] std::vector<FeatureType> const& GetSplitTypes() const {
631  return split_types_;
632  }
634  return split_categories_;
635  }
640  auto node_ptr = GetCategoriesMatrix().node_ptr;
641  auto categories = GetCategoriesMatrix().categories;
642  auto segment = node_ptr[nidx];
643  auto node_cats = categories.subspan(segment.beg, segment.size);
644  return node_cats;
645  }
646  [[nodiscard]] auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
656  struct Segment {
657  std::size_t beg{0};
658  std::size_t size{0};
659  };
663  };
668  view.categories = this->GetSplitCategories();
669  view.node_ptr = common::Span<CategoricalSplitMatrix::Segment const>(split_categories_segments_);
670  return view;
671  }
673  [[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const {
674  if (IsMultiTarget()) {
675  return this->p_mt_tree_->SplitIndex(nidx);
676  }
677  return (*this)[nidx].SplitIndex();
678  }
679  [[nodiscard]] float SplitCond(bst_node_t nidx) const {
680  if (IsMultiTarget()) {
681  return this->p_mt_tree_->SplitCond(nidx);
682  }
683  return (*this)[nidx].SplitCond();
684  }
685  [[nodiscard]] bool DefaultLeft(bst_node_t nidx) const {
686  if (IsMultiTarget()) {
687  return this->p_mt_tree_->DefaultLeft(nidx);
688  }
689  return (*this)[nidx].DefaultLeft();
690  }
691  [[nodiscard]] bool IsRoot(bst_node_t nidx) const {
692  if (IsMultiTarget()) {
693  return nidx == kRoot;
694  }
695  return (*this)[nidx].IsRoot();
696  }
697  [[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
698  if (IsMultiTarget()) {
699  return this->p_mt_tree_->IsLeaf(nidx);
700  }
701  return (*this)[nidx].IsLeaf();
702  }
703  [[nodiscard]] bst_node_t Parent(bst_node_t nidx) const {
704  if (IsMultiTarget()) {
705  return this->p_mt_tree_->Parent(nidx);
706  }
707  return (*this)[nidx].Parent();
708  }
709  [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const {
710  if (IsMultiTarget()) {
711  return this->p_mt_tree_->LeftChild(nidx);
712  }
713  return (*this)[nidx].LeftChild();
714  }
715  [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const {
716  if (IsMultiTarget()) {
717  return this->p_mt_tree_->RightChild(nidx);
718  }
719  return (*this)[nidx].RightChild();
720  }
721  [[nodiscard]] bool IsLeftChild(bst_node_t nidx) const {
722  if (IsMultiTarget()) {
723  CHECK_NE(nidx, kRoot);
724  auto p = this->p_mt_tree_->Parent(nidx);
725  return nidx == this->p_mt_tree_->LeftChild(p);
726  }
727  return (*this)[nidx].IsLeftChild();
728  }
729  [[nodiscard]] bst_node_t Size() const {
730  if (IsMultiTarget()) {
731  return this->p_mt_tree_->Size();
732  }
733  return this->nodes_.size();
734  }
736  private:
737  template <bool typed>
738  void LoadCategoricalSplit(Json const& in);
739  void SaveCategoricalSplit(Json* p_out) const;
741  TreeParam param_;
742  // vector of nodes
743  std::vector<Node> nodes_;
744  // free node space, used during training process
745  std::vector<int> deleted_nodes_;
746  // stats of nodes
747  std::vector<RTreeNodeStat> stats_;
748  std::vector<FeatureType> split_types_;
750  // Categories for each internal node.
751  std::vector<uint32_t> split_categories_;
752  // Ptr to split categories of each node.
753  std::vector<CategoricalSplitMatrix::Segment> split_categories_segments_;
754  // ptr to multi-target tree with vector leaf.
756  // allocate a new node,
757  // !!!!!! NOTE: may cause BUG here, nodes.resize
758  bst_node_t AllocNode() {
759  if (param_.num_deleted != 0) {
760  int nid = deleted_nodes_.back();
761  deleted_nodes_.pop_back();
762  nodes_[nid].Reuse();
763  --param_.num_deleted;
764  return nid;
765  }
766  int nd = param_.num_nodes++;
767  CHECK_LT(param_.num_nodes, std::numeric_limits<int>::max())
768  << "number of nodes in the tree exceed 2^31";
769  nodes_.resize(param_.num_nodes);
770  stats_.resize(param_.num_nodes);
771  split_types_.resize(param_.num_nodes, FeatureType::kNumerical);
772  split_categories_segments_.resize(param_.num_nodes);
773  return nd;
774  }
775  // delete a tree node, keep the parent field to allow trace back
776  void DeleteNode(int nid) {
777  CHECK_GE(nid, 1);
778  auto pid = (*this)[nid].Parent();
779  if (nid == (*this)[pid].LeftChild()) {
780  (*this)[pid].SetLeftChild(kInvalidNodeId);
781  } else {
782  (*this)[pid].SetRightChild(kInvalidNodeId);
783  }
785  deleted_nodes_.push_back(nid);
786  nodes_[nid].MarkDelete();
787  ++param_.num_deleted;
788  }
789 };
791 inline void RegTree::FVec::Init(size_t size) {
792  Entry e; e.flag = -1;
793  data_.resize(size);
794  std::fill(data_.begin(), data_.end(), e);
795  has_missing_ = true;
796 }
798 inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
799  size_t feature_count = 0;
800  for (auto const& entry : inst) {
801  if (entry.index >= data_.size()) {
802  continue;
803  }
804  data_[entry.index].fvalue = entry.fvalue;
805  ++feature_count;
806  }
807  has_missing_ = data_.size() != feature_count;
808 }
810 inline void RegTree::FVec::Drop() {
811  Entry e{};
812  e.flag = -1;
813  std::fill_n(data_.data(), data_.size(), e);
814  has_missing_ = true;
815 }
817 inline size_t RegTree::FVec::Size() const {
818  return data_.size();
819 }
821 inline bst_float RegTree::FVec::GetFvalue(size_t i) const {
822  return data_[i].fvalue;
823 }
825 inline bool RegTree::FVec::IsMissing(size_t i) const {
826  return data_[i].flag == -1;
827 }
829 inline bool RegTree::FVec::HasMissing() const {
830  return has_missing_;
831 }
833 // Multi-target tree not yet implemented error
835  return " support for multi-target tree is not yet implemented.";
836 }
837 } // namespace xgboost
838 #endif // XGBOOST_TREE_MODEL_H_
