xgboost
tree_model.h
Go to the documentation of this file.
1 
7 #ifndef XGBOOST_TREE_MODEL_H_
8 #define XGBOOST_TREE_MODEL_H_
9 
10 #include <xgboost/base.h>
11 #include <xgboost/data.h>
12 #include <xgboost/feature_map.h>
13 #include <xgboost/host_device_vector.h> // for HostDeviceVector
14 #include <xgboost/linalg.h> // for VectorView
15 #include <xgboost/logging.h>
16 #include <xgboost/model.h>
17 #include <xgboost/multi_target_tree_model.h> // for MultiTargetTree
18 
#include <algorithm>    // for fill
#include <cmath>        // for isnan
#include <cstring>
#include <limits>       // for numeric_limits
#include <memory>       // for unique_ptr, make_unique
#include <string>
#include <type_traits>  // for is_signed_v
#include <vector>
26 
27 namespace xgboost {
28 
29 namespace tree {
30 struct ScalarTreeView;
31 struct MultiTargetTreeView;
32 } // namespace tree
33 
34 class Json;
35 
37 struct TreeParam {
46 
47  bool operator==(const TreeParam& b) const {
48  return num_nodes == b.num_nodes && num_deleted == b.num_deleted &&
50  }
51 
52  void FromJson(Json const& in);
53  void ToJson(Json* p_out) const;
54 };
55 
57 struct RTreeNodeStat {
59  float loss_chg;
61  float sum_hess;
63  float base_weight;
66 
67  RTreeNodeStat() = default;
68  RTreeNodeStat(float loss_chg, float sum_hess, float weight)
70  bool operator==(const RTreeNodeStat& b) const {
71  return loss_chg == b.loss_chg && sum_hess == b.sum_hess && base_weight == b.base_weight &&
73  }
74 };
75 
81 class RegTree : public Model {
82  public:
83  using SplitCondT = float;
85  static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
86  static constexpr bst_node_t kRoot{0};
87 
89  class Node {
90  public:
92  // assert compact alignment
93  static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info), "Node: 64 bit align");
94  }
95  Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond,
96  bool default_left)
97  : parent_{parent}, cleft_{cleft}, cright_{cright} {
98  this->SetParent(parent_);
99  this->SetSplit(split_ind, split_cond, default_left);
100  }
101 
103  [[nodiscard]] XGBOOST_DEVICE int LeftChild() const { return this->cleft_; }
105  [[nodiscard]] XGBOOST_DEVICE int RightChild() const { return this->cright_; }
107  [[nodiscard]] XGBOOST_DEVICE int DefaultChild() const {
108  return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
109  }
111  [[nodiscard]] XGBOOST_DEVICE bst_feature_t SplitIndex() const {
112  static_assert(!std::is_signed_v<bst_feature_t>);
113  return sindex_ & ((1U << 31) - 1U);
114  }
116  [[nodiscard]] XGBOOST_DEVICE bool DefaultLeft() const { return (sindex_ >> 31) != 0; }
118  [[nodiscard]] XGBOOST_DEVICE bool IsLeaf() const { return cleft_ == kInvalidNodeId; }
120  [[nodiscard]] XGBOOST_DEVICE float LeafValue() const { return (this->info_).leaf_value; }
122  [[nodiscard]] XGBOOST_DEVICE SplitCondT SplitCond() const { return (this->info_).split_cond; }
124  [[nodiscard]] XGBOOST_DEVICE int Parent() const { return parent_ & ((1U << 31) - 1); }
126  [[nodiscard]] XGBOOST_DEVICE bool IsLeftChild() const { return (parent_ & (1U << 31)) != 0; }
128  [[nodiscard]] XGBOOST_DEVICE bool IsDeleted() const { return sindex_ == kDeletedNodeMarker; }
130  [[nodiscard]] XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
135  XGBOOST_DEVICE void SetLeftChild(int nid) { this->cleft_ = nid; }
140  XGBOOST_DEVICE void SetRightChild(int nid) { this->cright_ = nid; }
147  XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond,
148  bool default_left = false) {
149  if (default_left) split_index |= (1U << 31);
150  this->sindex_ = split_index;
151  (this->info_).split_cond = split_cond;
152  }
159  XGBOOST_DEVICE void SetLeaf(bst_float value, int right = kInvalidNodeId) {
160  (this->info_).leaf_value = value;
161  this->cleft_ = kInvalidNodeId;
162  this->cright_ = right;
163  }
165  XGBOOST_DEVICE void MarkDelete() { this->sindex_ = kDeletedNodeMarker; }
167  XGBOOST_DEVICE void Reuse() { this->sindex_ = 0; }
168  // set parent
169  XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child = true) {
170  if (is_left_child) pidx |= (1U << 31);
171  this->parent_ = pidx;
172  }
173  bool operator==(const Node& b) const {
174  return parent_ == b.parent_ && cleft_ == b.cleft_ && cright_ == b.cright_ &&
175  sindex_ == b.sindex_ && info_.leaf_value == b.info_.leaf_value;
176  }
177 
178  private:
183  union Info {
184  bst_float leaf_value;
185  SplitCondT split_cond;
186  };
187  // pointer to parent, highest bit is used to
188  // indicate whether it's a left child or not
189  int32_t parent_{kInvalidNodeId};
190  // pointer to left, right
191  int32_t cleft_{kInvalidNodeId}, cright_{kInvalidNodeId};
192  // split feature index, left split or right split depends on the highest bit
193  uint32_t sindex_{0};
194  // extra info
195  Info info_;
196  };
197 
204  void ChangeToLeaf(bst_node_t nidx, float value) {
205  auto& h_nodes = nodes_.HostVector();
206  CHECK(h_nodes[h_nodes[nidx].LeftChild()].IsLeaf());
207  CHECK(h_nodes[h_nodes[nidx].RightChild()].IsLeaf());
208  this->DeleteNode(h_nodes[nidx].LeftChild());
209  this->DeleteNode(h_nodes[nidx].RightChild());
210  h_nodes[nidx].SetLeaf(value);
211  }
218  void CollapseToLeaf(bst_node_t nidx, float value) {
219  auto& h_nodes = nodes_.HostVector();
220  if (h_nodes[nidx].IsLeaf()) return;
221  if (!h_nodes[h_nodes[nidx].LeftChild()].IsLeaf()) {
222  CollapseToLeaf(h_nodes[nidx].LeftChild(), 0.0f);
223  }
224  if (!h_nodes[h_nodes[nidx].RightChild()].IsLeaf()) {
225  CollapseToLeaf(h_nodes[nidx].RightChild(), 0.0f);
226  }
227  this->ChangeToLeaf(nidx, value);
228  }
229 
231  nodes_.HostVector().resize(param_.num_nodes);
232  stats_.HostVector().resize(param_.num_nodes);
233  split_types_.HostVector().resize(param_.num_nodes, FeatureType::kNumerical);
234  split_categories_segments_.HostVector().resize(param_.num_nodes);
235  auto& h_nodes = nodes_.HostVector();
236  for (int i = 0; i < param_.num_nodes; i++) {
237  h_nodes[i].SetLeaf(0.0f);
238  h_nodes[i].SetParent(kInvalidNodeId);
239  }
240  }
244  explicit RegTree(bst_target_t n_targets, bst_feature_t n_features) : RegTree{} {
245  param_.num_feature = n_features;
246  param_.size_leaf_vector = n_targets;
247  if (n_targets > 1) {
248  this->p_mt_tree_.reset(new MultiTargetTree{&param_});
249  }
250  }
251 
253  Node& operator[](bst_node_t nidx) { return nodes_.HostVector()[nidx]; }
254 
255  public:
257  [[nodiscard]] common::Span<Node const> GetNodes(DeviceOrd device) const {
258  CHECK(!this->IsMultiTarget());
259  return device.IsCPU() ? nodes_.ConstHostSpan()
260  : (nodes_.SetDevice(device), nodes_.ConstDeviceSpan());
261  }
262 
265  CHECK(!this->IsMultiTarget());
266  return device.IsCPU() ? stats_.ConstHostSpan()
267  : (stats_.SetDevice(device), stats_.ConstDeviceSpan());
268  }
269 
271  RTreeNodeStat& Stat(int nid) { return stats_.HostVector()[nid]; }
272 
273  void LoadModel(Json const& in) override;
274  void SaveModel(Json* out) const override;
275 
276  bool operator==(const RegTree& b) const {
277  return nodes_.ConstHostVector() == b.nodes_.ConstHostVector() &&
278  stats_.ConstHostVector() == b.stats_.ConstHostVector() &&
279  deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
280  }
287  [[nodiscard]] bool Equal(const RegTree& b) const;
288 
306  void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left,
307  bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight,
308  bst_float loss_change, float sum_hess, float left_sum, float right_sum,
309  bst_node_t leaf_right_child = kInvalidNodeId);
318  void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left,
321  linalg::VectorView<float const> right_weight, float loss_chg, float sum_hess,
322  float left_sum, float right_sum);
333  void SetLeaves(std::vector<bst_node_t> leaves, common::Span<float const> weights);
334 
351  common::Span<const uint32_t> split_cat, bool default_left,
352  bst_float base_weight, bst_float left_leaf_weight,
353  bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
354  float left_sum, float right_sum);
359  common::Span<const uint32_t> split_cat, bool default_left,
362  linalg::VectorView<float const> right_weight, float loss_chg,
363  float sum_hess, float left_sum, float right_sum);
367  [[nodiscard]] bool HasCategoricalSplit() const { return !split_categories_.Empty(); }
371  [[nodiscard]] bool IsMultiTarget() const { return static_cast<bool>(p_mt_tree_); }
375  [[nodiscard]] bst_target_t NumTargets() const { return param_.size_leaf_vector; }
379  [[nodiscard]] auto GetMultiTargetTree() const {
380  CHECK(IsMultiTarget());
381  return p_mt_tree_.get();
382  }
386  [[nodiscard]] bst_feature_t NumFeatures() const noexcept { return param_.num_feature; }
390  [[nodiscard]] bst_node_t NumNodes() const noexcept { return param_.num_nodes; }
394  [[nodiscard]] bst_node_t NumValidNodes() const noexcept {
395  return param_.num_nodes - param_.num_deleted;
396  }
400  [[nodiscard]] bst_node_t NumExtraNodes() const noexcept {
401  return param_.num_nodes - 1 - param_.num_deleted;
402  }
403  /* \brief Count number of leaves in tree. */
404  [[nodiscard]] bst_node_t GetNumLeaves() const;
405  [[nodiscard]] bst_node_t GetNumSplitNodes() const;
406 
410  [[nodiscard]] bst_node_t GetDepth(bst_node_t nidx) const;
417  void SetRoot(linalg::VectorView<float const> weight, float sum_hess) {
418  CHECK(IsMultiTarget());
419  return this->p_mt_tree_->SetRoot(weight, sum_hess);
420  }
424  [[nodiscard]] bst_node_t MaxDepth() const;
425 
430  struct FVec {
435  void Init(size_t size);
440  void Fill(SparsePage::Inst const& inst);
441 
446  void Drop();
451  [[nodiscard]] size_t Size() const;
457  [[nodiscard]] bst_float GetFvalue(size_t i) const;
463  [[nodiscard]] bool IsMissing(size_t i) const;
464  [[nodiscard]] bool HasMissing() const;
465  void HasMissing(bool has_missing) { this->has_missing_ = has_missing; }
466 
467  [[nodiscard]] common::Span<float> Data() { return data_; }
468 
469  private:
475  std::vector<float> data_;
476  bool has_missing_;
477  };
478 
486  [[nodiscard]] std::string DumpModel(const FeatureMap& fmap, bool with_stats,
487  std::string format) const;
492  return device.IsCPU() ? split_types_.ConstHostSpan()
493  : (split_types_.SetDevice(device), split_types_.ConstDeviceSpan());
494  }
496  return device.IsCPU()
497  ? split_categories_.ConstHostSpan()
498  : (split_categories_.SetDevice(device), split_categories_.ConstDeviceSpan());
499  }
500  [[nodiscard]] auto const& GetSplitCategoriesPtr() const {
501  return split_categories_segments_.ConstHostVector();
502  }
503 
512  struct Segment {
513  std::size_t beg{0};
514  std::size_t size{0};
515  };
519  };
520 
523  view.split_type = this->GetSplitTypes(device);
524  view.categories = this->GetSplitCategories(device);
525  if (device.IsCPU()) {
526  view.node_ptr = split_categories_segments_.ConstHostSpan();
527  } else {
528  split_categories_segments_.SetDevice(device);
529  view.node_ptr = split_categories_segments_.ConstDeviceSpan();
530  }
531  return view;
532  }
533 
534  [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const {
535  if (IsMultiTarget()) {
536  return this->p_mt_tree_->LeftChild(nidx);
537  }
538  return nodes_.ConstHostVector()[nidx].LeftChild();
539  }
540  [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const {
541  if (IsMultiTarget()) {
542  return this->p_mt_tree_->RightChild(nidx);
543  }
544  return nodes_.ConstHostVector()[nidx].RightChild();
545  }
546  [[nodiscard]] bst_node_t Size() const {
547  if (IsMultiTarget()) {
548  return this->p_mt_tree_->Size();
549  }
550  return this->nodes_.Size();
551  }
552 
553  [[nodiscard]] RegTree* Copy() const;
554  tree::ScalarTreeView HostScView() const;
555  tree::MultiTargetTreeView HostMtView() const;
556 
557  private:
558  template <bool typed>
559  void LoadCategoricalSplit(Json const& in);
560  void SaveCategoricalSplit(Json* p_out) const;
562  TreeParam param_;
563  // vector of nodes
564  HostDeviceVector<Node> nodes_;
565  // free node space, used during training process
566  std::vector<int> deleted_nodes_;
567  // stats of nodes
569  HostDeviceVector<FeatureType> split_types_;
570 
571  // Categories for each internal node.
572  HostDeviceVector<uint32_t> split_categories_;
573  // Ptr to split categories of each node.
574  HostDeviceVector<CategoricalSplitMatrix::Segment> split_categories_segments_;
575  // ptr to multi-target tree with vector leaf.
576  std::unique_ptr<MultiTargetTree> p_mt_tree_;
577  // allocate a new node,
578  // !!!!!! NOTE: may cause BUG here, nodes.resize
579  bst_node_t AllocNode() {
580  if (param_.num_deleted != 0) {
581  int nid = deleted_nodes_.back();
582  deleted_nodes_.pop_back();
583  nodes_.HostVector()[nid].Reuse();
584  --param_.num_deleted;
585  return nid;
586  }
587  int nd = param_.num_nodes++;
588  CHECK_LT(param_.num_nodes, std::numeric_limits<int>::max())
589  << "number of nodes in the tree exceed 2^31";
590  nodes_.HostVector().resize(param_.num_nodes);
591  stats_.HostVector().resize(param_.num_nodes);
592  split_types_.HostVector().resize(param_.num_nodes, FeatureType::kNumerical);
593  split_categories_segments_.HostVector().resize(param_.num_nodes);
594  return nd;
595  }
596  // delete a tree node, keep the parent field to allow trace back
597  void DeleteNode(int nid) {
598  CHECK_GE(nid, 1);
599  auto pid = (*this)[nid].Parent();
600  if (nid == (*this)[pid].LeftChild()) {
601  (*this)[pid].SetLeftChild(kInvalidNodeId);
602  } else {
603  (*this)[pid].SetRightChild(kInvalidNodeId);
604  }
605 
606  deleted_nodes_.push_back(nid);
607  nodes_.HostVector()[nid].MarkDelete();
608  ++param_.num_deleted;
609  }
610 };
611 
612 inline void RegTree::FVec::Init(size_t size) {
613  data_.resize(size);
614  std::fill(data_.begin(), data_.end(), std::numeric_limits<float>::quiet_NaN());
615  has_missing_ = true;
616 }
617 
618 inline void RegTree::FVec::Fill(SparsePage::Inst const& inst) {
619  auto p_data = inst.data();
620  auto p_out = data_.data();
621 
622  for (std::size_t i = 0, n = inst.size(); i < n; ++i) {
623  auto const& entry = p_data[i];
624  p_out[entry.index] = entry.fvalue;
625  }
626  has_missing_ = data_.size() != inst.size();
627 }
628 
629 inline void RegTree::FVec::Drop() { this->Init(this->Size()); }
630 
631 inline size_t RegTree::FVec::Size() const { return data_.size(); }
632 
633 inline float RegTree::FVec::GetFvalue(size_t i) const { return data_[i]; }
634 
635 inline bool RegTree::FVec::IsMissing(size_t i) const { return std::isnan(data_[i]); }
636 
637 inline bool RegTree::FVec::HasMissing() const { return has_missing_; }
638 
639 // Multi-target tree not yet implemented error
641  return " support for multi-target tree is not yet implemented.";
642 }
643 } // namespace xgboost
644 #endif // XGBOOST_TREE_MODEL_H_
Defines configuration macros and basic types for xgboost.
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:57
Feature map data structure to help text model dump. TODO(tqchen) consider make it even more lightweig...
Definition: feature_map.h:22
Definition: host_device_vector.h:89
bool Empty() const
Definition: host_device_vector.h:104
common::Span< T const > ConstHostSpan() const
Definition: host_device_vector.h:118
std::vector< T > & HostVector()
common::Span< const T > ConstDeviceSpan() const
void SetDevice(DeviceOrd device) const
Data structure representing JSON format.
Definition: json.h:396
Tree structure for multi-target model.
Definition: multi_target_tree_model.h:38
static constexpr bst_node_t InvalidNodeId()
Definition: multi_target_tree_model.h:40
tree node
Definition: tree_model.h:89
XGBOOST_DEVICE int Parent() const
get parent of the node
Definition: tree_model.h:124
XGBOOST_DEVICE void MarkDelete()
mark that this node is deleted
Definition: tree_model.h:165
XGBOOST_DEVICE bool IsRoot() const
whether current node is root
Definition: tree_model.h:130
XGBOOST_DEVICE int RightChild() const
index of right child
Definition: tree_model.h:105
XGBOOST_DEVICE float LeafValue() const
Definition: tree_model.h:120
XGBOOST_DEVICE Node()
Definition: tree_model.h:91
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child=true)
Definition: tree_model.h:169
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=kInvalidNodeId)
set the leaf value of the node
Definition: tree_model.h:159
XGBOOST_DEVICE bool IsLeftChild() const
whether current node is left child
Definition: tree_model.h:126
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
set split condition of current node
Definition: tree_model.h:147
XGBOOST_DEVICE void SetLeftChild(int nid)
set the left child
Definition: tree_model.h:135
XGBOOST_DEVICE bst_feature_t SplitIndex() const
feature index of split condition
Definition: tree_model.h:111
XGBOOST_DEVICE bool IsDeleted() const
whether this node is deleted
Definition: tree_model.h:128
XGBOOST_DEVICE bool IsLeaf() const
whether current node is leaf node
Definition: tree_model.h:118
bool operator==(const Node &b) const
Definition: tree_model.h:173
Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond, bool default_left)
Definition: tree_model.h:95
XGBOOST_DEVICE void Reuse()
Reuse this deleted node.
Definition: tree_model.h:167
XGBOOST_DEVICE void SetRightChild(int nid)
set the right child
Definition: tree_model.h:140
XGBOOST_DEVICE bool DefaultLeft() const
when the split feature value is missing, whether to go to the left child
Definition: tree_model.h:116
XGBOOST_DEVICE int LeftChild() const
index of left child
Definition: tree_model.h:103
XGBOOST_DEVICE int DefaultChild() const
index of default child when feature is missing
Definition: tree_model.h:107
XGBOOST_DEVICE SplitCondT SplitCond() const
Definition: tree_model.h:122
define regression tree to be the most common tree model.
Definition: tree_model.h:81
void SaveModel(Json *out) const override
saves the model config to a JSON object
tree::MultiTargetTreeView HostMtView() const
RegTree * Copy() const
void ChangeToLeaf(bst_node_t nidx, float value)
Change a non-leaf node to a leaf node, deleting its children.
Definition: tree_model.h:204
bst_target_t NumTargets() const
The size of leaf weight.
Definition: tree_model.h:375
bool operator==(const RegTree &b) const
Definition: tree_model.h:276
bst_node_t NumNodes() const noexcept
Get the total number of nodes including deleted ones in this tree.
Definition: tree_model.h:390
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
Expands a leaf node into two additional leaf nodes.
RegTree()
Definition: tree_model.h:230
static constexpr bst_node_t kInvalidNodeId
Definition: tree_model.h:84
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight, float loss_chg, float sum_hess, float left_sum, float right_sum)
Expands a leaf node into two additional leaf nodes for a multi-target tree.
Node & operator[](bst_node_t nidx)
get node given nid
Definition: tree_model.h:253
static constexpr uint32_t kDeletedNodeMarker
Definition: tree_model.h:85
bool IsMultiTarget() const
Whether this is a multi-target tree.
Definition: tree_model.h:371
bst_node_t NumExtraNodes() const noexcept
number of extra nodes besides the root
Definition: tree_model.h:400
bst_node_t MaxDepth() const
Get the maximum depth.
auto GetMultiTargetTree() const
Get the underlying implementation of the multi-target tree.
Definition: tree_model.h:379
void ExpandCategorical(bst_node_t nidx, bst_feature_t split_index, common::Span< const uint32_t > split_cat, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight, float loss_chg, float sum_hess, float left_sum, float right_sum)
Expands a leaf node with categories for a multi-target tree.
bst_node_t LeftChild(bst_node_t nidx) const
Definition: tree_model.h:534
common::Span< RTreeNodeStat const > GetStats(DeviceOrd device) const
Get const reference to stats.
Definition: tree_model.h:264
void SetLeaves(std::vector< bst_node_t > leaves, common::Span< float const > weights)
Set all leaf weights for a multi-target tree.
void CollapseToLeaf(bst_node_t nidx, float value)
Collapse a non-leaf node to a leaf node, deleting its children.
Definition: tree_model.h:218
bst_node_t GetNumLeaves() const
common::Span< Node const > GetNodes(DeviceOrd device) const
Get const reference to nodes.
Definition: tree_model.h:257
RegTree(bst_target_t n_targets, bst_feature_t n_features)
Constructor that initializes the tree model with shape.
Definition: tree_model.h:244
bst_node_t RightChild(bst_node_t nidx) const
Definition: tree_model.h:540
common::Span< FeatureType const > GetSplitTypes(DeviceOrd device) const
Get split types for all nodes.
Definition: tree_model.h:491
RTreeNodeStat & Stat(int nid)
get node statistics given nid
Definition: tree_model.h:271
void SetRoot(linalg::VectorView< float const > weight, float sum_hess)
Set the root weight and statistics for a multi-target tree.
Definition: tree_model.h:417
void ExpandCategorical(bst_node_t nid, bst_feature_t split_index, common::Span< const uint32_t > split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum)
Expands a leaf node with categories.
bool Equal(const RegTree &b) const
Compares whether 2 trees are equal from a user's perspective. The equality compares only non-deleted ...
bst_node_t NumValidNodes() const noexcept
Get the total number of valid nodes in this tree.
Definition: tree_model.h:394
float SplitCondT
Definition: tree_model.h:83
CategoricalSplitMatrix GetCategoriesMatrix(DeviceOrd device) const
Definition: tree_model.h:521
common::Span< uint32_t const > GetSplitCategories(DeviceOrd device) const
Definition: tree_model.h:495
void LoadModel(Json const &in) override
load the model from a JSON object
std::string DumpModel(const FeatureMap &fmap, bool with_stats, std::string format) const
dump the model in the requested format as a text string
bst_feature_t NumFeatures() const noexcept
Get the number of features.
Definition: tree_model.h:386
tree::ScalarTreeView HostScView() const
bool HasCategoricalSplit() const
Whether this tree has categorical split.
Definition: tree_model.h:367
static constexpr bst_node_t kRoot
Definition: tree_model.h:86
bst_node_t GetNumSplitNodes() const
auto const & GetSplitCategoriesPtr() const
Definition: tree_model.h:500
bst_node_t GetDepth(bst_node_t nidx) const
Get the depth of a node.
bst_node_t Size() const
Definition: tree_model.h:546
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:435
constexpr XGBOOST_DEVICE pointer data() const __span_noexcept
Definition: span.h:554
constexpr XGBOOST_DEVICE index_type size() const __span_noexcept
Definition: span.h:559
A tensor view with static type and dimension. It implements indexing and slicing.
Definition: linalg.h:278
The input data structure of xgboost.
Feature map data structure to help visualization and model dump.
A device-and-host vector abstraction layer.
Linear algebra related utilities.
Defines the abstract interface for different components in XGBoost.
Learner interface that integrates objective, gbm and evaluation together. This is the user facing XGB...
Definition: base.h:89
std::int32_t bst_node_t
Type for tree node index and tree depth.
Definition: base.h:111
std::uint32_t bst_target_t
Type for indexing into output targets.
Definition: base.h:119
std::uint32_t bst_feature_t
Type for data column (feature) index.
Definition: base.h:99
float bst_float
float type, used for storing statistics
Definition: base.h:95
StringView MTNotImplemented()
Definition: tree_model.h:640
A type for device ordinal. The type is packed into 32-bit for efficient use in viewing types like lin...
Definition: context.h:40
bool IsCPU() const
Definition: context.h:56
Definition: model.h:14
node statistics used in regression tree
Definition: tree_model.h:57
RTreeNodeStat(float loss_chg, float sum_hess, float weight)
Definition: tree_model.h:68
float sum_hess
sum of hessian values, used to measure coverage of data
Definition: tree_model.h:61
int leaf_child_cnt
number of children known so far to be leaf nodes
Definition: tree_model.h:65
float loss_chg
loss change caused by current split
Definition: tree_model.h:59
float base_weight
weight of current node
Definition: tree_model.h:63
bool operator==(const RTreeNodeStat &b) const
Definition: tree_model.h:70
std::size_t size
Definition: tree_model.h:514
std::size_t beg
Definition: tree_model.h:513
CSR-like matrix for categorical splits.
Definition: tree_model.h:511
common::Span< uint32_t const > categories
Definition: tree_model.h:517
common::Span< Segment const > node_ptr
Definition: tree_model.h:518
common::Span< FeatureType const > split_type
Definition: tree_model.h:516
dense feature vector that can be consumed by RegTree and constructed from a sparse feature vector.
Definition: tree_model.h:430
void HasMissing(bool has_missing)
Definition: tree_model.h:465
void Drop()
drop the trace after fill, must be called after fill.
Definition: tree_model.h:629
bool HasMissing() const
Definition: tree_model.h:637
bool IsMissing(size_t i) const
check whether i-th entry is missing
Definition: tree_model.h:635
size_t Size() const
returns the size of the feature vector
Definition: tree_model.h:631
void Init(size_t size)
initialize the vector with the given size, marking every entry as missing (NaN)
Definition: tree_model.h:612
common::Span< float > Data()
Definition: tree_model.h:467
void Fill(SparsePage::Inst const &inst)
fill the vector with sparse vector
Definition: tree_model.h:618
bst_float GetFvalue(size_t i) const
get ith value
Definition: tree_model.h:633
Definition: string_view.h:16
meta parameters of the tree
Definition: tree_model.h:37
bst_node_t num_deleted
The number of deleted nodes.
Definition: tree_model.h:41
bst_feature_t num_feature
The number of features used for tree construction.
Definition: tree_model.h:43
bool operator==(const TreeParam &b) const
Definition: tree_model.h:47
bst_node_t num_nodes
The number of nodes.
Definition: tree_model.h:39
void ToJson(Json *p_out) const
void FromJson(Json const &in)
bst_target_t size_leaf_vector
leaf vector size. Used by the vector leaf.
Definition: tree_model.h:45