xgboost
tree_model.h
Go to the documentation of this file.
1 
7 #ifndef XGBOOST_TREE_MODEL_H_
8 #define XGBOOST_TREE_MODEL_H_
9 
10 #include <dmlc/io.h>
11 #include <dmlc/parameter.h>
12 
13 #include <xgboost/base.h>
14 #include <xgboost/data.h>
15 #include <xgboost/logging.h>
16 #include <xgboost/feature_map.h>
17 #include <xgboost/model.h>
18 
19 #include <limits>
20 #include <vector>
21 #include <string>
22 #include <cstring>
23 #include <algorithm>
24 #include <tuple>
25 #include <stack>
26 
27 namespace xgboost {
28 
29 struct PathElement; // forward declaration
30 
31 class Json;
32 // FIXME(trivialfis): Once binary IO is gone, make this parameter internal as it should
33 // not be configured by users.
35 struct TreeParam : public dmlc::Parameter<TreeParam> {
39  int num_nodes;
52  int reserved[31];
55  // assert compact alignment
56  static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int),
57  "TreeParam: 64 bit align");
58  std::memset(this, 0, sizeof(TreeParam));
59  num_nodes = 1;
61  }
62 
63  // Swap byte order for all fields. Useful for transporting models between machines with different
64  // endianness (big endian vs little endian)
65  inline TreeParam ByteSwap() const {
66  TreeParam x = *this;
67  dmlc::ByteSwap(&x.deprecated_num_roots, sizeof(x.deprecated_num_roots), 1);
68  dmlc::ByteSwap(&x.num_nodes, sizeof(x.num_nodes), 1);
69  dmlc::ByteSwap(&x.num_deleted, sizeof(x.num_deleted), 1);
70  dmlc::ByteSwap(&x.deprecated_max_depth, sizeof(x.deprecated_max_depth), 1);
71  dmlc::ByteSwap(&x.num_feature, sizeof(x.num_feature), 1);
72  dmlc::ByteSwap(&x.size_leaf_vector, sizeof(x.size_leaf_vector), 1);
73  dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0]));
74  return x;
75  }
76 
77  // declare the parameters
79  // only declare the parameters that can be set by the user.
80  // other arguments are set by the algorithm.
81  DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1);
82  DMLC_DECLARE_FIELD(num_feature)
83  .describe("Number of features used in tree construction.");
84  DMLC_DECLARE_FIELD(num_deleted);
85  DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0)
86  .describe("Size of leaf vector, reserved for vector tree");
87  }
88 
89  bool operator==(const TreeParam& b) const {
90  return num_nodes == b.num_nodes &&
91  num_deleted == b.num_deleted &&
92  num_feature == b.num_feature &&
94  }
95 };
96 
98 struct RTreeNodeStat {
106  int leaf_child_cnt {0};
107 
108  RTreeNodeStat() = default;
109  RTreeNodeStat(float loss_chg, float sum_hess, float weight) :
111  bool operator==(const RTreeNodeStat& b) const {
112  return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
114  }
115  // Swap byte order for all fields. Useful for transporting models between machines with different
116  // endianness (big endian vs little endian)
117  inline RTreeNodeStat ByteSwap() const {
118  RTreeNodeStat x = *this;
119  dmlc::ByteSwap(&x.loss_chg, sizeof(x.loss_chg), 1);
120  dmlc::ByteSwap(&x.sum_hess, sizeof(x.sum_hess), 1);
121  dmlc::ByteSwap(&x.base_weight, sizeof(x.base_weight), 1);
122  dmlc::ByteSwap(&x.leaf_child_cnt, sizeof(x.leaf_child_cnt), 1);
123  return x;
124  }
125 };
126 
131 class RegTree : public Model {
132  public:
134  static constexpr bst_node_t kInvalidNodeId {-1};
135  static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
136  static constexpr bst_node_t kRoot { 0 };
137 
139  class Node {
140  public:
142  // assert compact alignment
143  static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
144  "Node: 64 bit align");
145  }
146  Node(int32_t cleft, int32_t cright, int32_t parent,
147  uint32_t split_ind, float split_cond, bool default_left) :
148  parent_{parent}, cleft_{cleft}, cright_{cright} {
149  this->SetParent(parent_);
150  this->SetSplit(split_ind, split_cond, default_left);
151  }
152 
154  XGBOOST_DEVICE int LeftChild() const {
155  return this->cleft_;
156  }
159  return this->cright_;
160  }
163  return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
164  }
166  XGBOOST_DEVICE unsigned SplitIndex() const {
167  return sindex_ & ((1U << 31) - 1U);
168  }
171  return (sindex_ >> 31) != 0;
172  }
174  XGBOOST_DEVICE bool IsLeaf() const {
175  return cleft_ == kInvalidNodeId;
176  }
179  return (this->info_).leaf_value;
180  }
183  return (this->info_).split_cond;
184  }
186  XGBOOST_DEVICE int Parent() const {
187  return parent_ & ((1U << 31) - 1);
188  }
191  return (parent_ & (1U << 31)) != 0;
192  }
194  XGBOOST_DEVICE bool IsDeleted() const {
195  return sindex_ == kDeletedNodeMarker;
196  }
198  XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
203  XGBOOST_DEVICE void SetLeftChild(int nid) {
204  this->cleft_ = nid;
205  }
211  this->cright_ = nid;
212  }
219  XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond,
220  bool default_left = false) {
221  if (default_left) split_index |= (1U << 31);
222  this->sindex_ = split_index;
223  (this->info_).split_cond = split_cond;
224  }
231  XGBOOST_DEVICE void SetLeaf(bst_float value, int right = kInvalidNodeId) {
232  (this->info_).leaf_value = value;
233  this->cleft_ = kInvalidNodeId;
234  this->cright_ = right;
235  }
238  this->sindex_ = kDeletedNodeMarker;
239  }
242  this->sindex_ = 0;
243  }
244  // set parent
245  XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child = true) {
246  if (is_left_child) pidx |= (1U << 31);
247  this->parent_ = pidx;
248  }
249  bool operator==(const Node& b) const {
250  return parent_ == b.parent_ && cleft_ == b.cleft_ &&
251  cright_ == b.cright_ && sindex_ == b.sindex_ &&
252  info_.leaf_value == b.info_.leaf_value;
253  }
254 
255  inline Node ByteSwap() const {
256  Node x = *this;
257  dmlc::ByteSwap(&x.parent_, sizeof(x.parent_), 1);
258  dmlc::ByteSwap(&x.cleft_, sizeof(x.cleft_), 1);
259  dmlc::ByteSwap(&x.cright_, sizeof(x.cright_), 1);
260  dmlc::ByteSwap(&x.sindex_, sizeof(x.sindex_), 1);
261  dmlc::ByteSwap(&x.info_, sizeof(x.info_), 1);
262  return x;
263  }
264 
265  private:
270  union Info{
271  bst_float leaf_value;
272  SplitCondT split_cond;
273  };
274  // pointer to parent, highest bit is used to
275  // indicate whether it's a left child or not
276  int32_t parent_{kInvalidNodeId};
277  // pointer to left, right
278  int32_t cleft_{kInvalidNodeId}, cright_{kInvalidNodeId};
279  // split feature index, left split or right split depends on the highest bit
280  uint32_t sindex_{0};
281  // extra info
282  Info info_;
283  };
284 
290  void ChangeToLeaf(int rid, bst_float value) {
291  CHECK(nodes_[nodes_[rid].LeftChild() ].IsLeaf());
292  CHECK(nodes_[nodes_[rid].RightChild()].IsLeaf());
293  this->DeleteNode(nodes_[rid].LeftChild());
294  this->DeleteNode(nodes_[rid].RightChild());
295  nodes_[rid].SetLeaf(value);
296  }
302  void CollapseToLeaf(int rid, bst_float value) {
303  if (nodes_[rid].IsLeaf()) return;
304  if (!nodes_[nodes_[rid].LeftChild() ].IsLeaf()) {
305  CollapseToLeaf(nodes_[rid].LeftChild(), 0.0f);
306  }
307  if (!nodes_[nodes_[rid].RightChild() ].IsLeaf()) {
308  CollapseToLeaf(nodes_[rid].RightChild(), 0.0f);
309  }
310  this->ChangeToLeaf(rid, value);
311  }
312 
317  param.num_nodes = 1;
318  param.num_deleted = 0;
319  nodes_.resize(param.num_nodes);
320  stats_.resize(param.num_nodes);
321  split_types_.resize(param.num_nodes, FeatureType::kNumerical);
322  split_categories_segments_.resize(param.num_nodes);
323  for (int i = 0; i < param.num_nodes; i ++) {
324  nodes_[i].SetLeaf(0.0f);
325  nodes_[i].SetParent(kInvalidNodeId);
326  }
327  }
329  Node& operator[](int nid) {
330  return nodes_[nid];
331  }
333  const Node& operator[](int nid) const {
334  return nodes_[nid];
335  }
336 
338  const std::vector<Node>& GetNodes() const { return nodes_; }
339 
341  const std::vector<RTreeNodeStat>& GetStats() const { return stats_; }
342 
344  RTreeNodeStat& Stat(int nid) {
345  return stats_[nid];
346  }
348  const RTreeNodeStat& Stat(int nid) const {
349  return stats_[nid];
350  }
351 
356  void Load(dmlc::Stream* fi);
361  void Save(dmlc::Stream* fo) const;
362 
363  void LoadModel(Json const& in) override;
364  void SaveModel(Json* out) const override;
365 
366  bool operator==(const RegTree& b) const {
367  return nodes_ == b.nodes_ && stats_ == b.stats_ &&
368  deleted_nodes_ == b.deleted_nodes_ && param == b.param;
369  }
370  /* \brief Iterate through all nodes in this tree.
371  *
372  * \param Function that accepts a node index, and returns false when iteration should
373  * stop, otherwise returns true.
374  */
375  template <typename Func> void WalkTree(Func func) const {
376  std::stack<bst_node_t> nodes;
377  nodes.push(kRoot);
378  auto &self = *this;
379  while (!nodes.empty()) {
380  auto nidx = nodes.top();
381  nodes.pop();
382  if (!func(nidx)) {
383  return;
384  }
385  auto left = self[nidx].LeftChild();
386  auto right = self[nidx].RightChild();
387  if (left != RegTree::kInvalidNodeId) {
388  nodes.push(left);
389  }
390  if (right != RegTree::kInvalidNodeId) {
391  nodes.push(right);
392  }
393  }
394  }
401  bool Equal(const RegTree& b) const;
402 
420  void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value,
421  bool default_left, bst_float base_weight,
422  bst_float left_leaf_weight, bst_float right_leaf_weight,
423  bst_float loss_change, float sum_hess, float left_sum,
424  float right_sum,
425  bst_node_t leaf_right_child = kInvalidNodeId);
426 
442  void ExpandCategorical(bst_node_t nid, unsigned split_index,
443  common::Span<uint32_t> split_cat, bool default_left,
444  bst_float base_weight, bst_float left_leaf_weight,
445  bst_float right_leaf_weight, bst_float loss_change,
446  float sum_hess, float left_sum, float right_sum);
447 
448  bool HasCategoricalSplit() const {
449  return !split_categories_.empty();
450  }
451 
456  int GetDepth(int nid) const {
457  int depth = 0;
458  while (!nodes_[nid].IsRoot()) {
459  ++depth;
460  nid = nodes_[nid].Parent();
461  }
462  return depth;
463  }
464 
469  int MaxDepth(int nid) const {
470  if (nodes_[nid].IsLeaf()) return 0;
471  return std::max(MaxDepth(nodes_[nid].LeftChild())+1,
472  MaxDepth(nodes_[nid].RightChild())+1);
473  }
474 
478  int MaxDepth() {
479  return MaxDepth(0);
480  }
481 
483  int NumExtraNodes() const {
484  return param.num_nodes - 1 - param.num_deleted;
485  }
486 
487  /* \brief Count number of leaves in tree. */
488  bst_node_t GetNumLeaves() const;
490 
495  struct FVec {
500  void Init(size_t size);
505  void Fill(const SparsePage::Inst& inst);
506 
511  void Drop(const SparsePage::Inst& inst);
516  size_t Size() const;
522  bst_float GetFvalue(size_t i) const;
528  bool IsMissing(size_t i) const;
529  bool HasMissing() const;
530 
531 
532  private:
537  union Entry {
538  bst_float fvalue;
539  int flag;
540  };
541  std::vector<Entry> data_;
542  bool has_missing_;
543  };
544 
552  void CalculateContributions(const RegTree::FVec& feat,
553  std::vector<float>* mean_values,
554  bst_float* out_contribs, int condition = 0,
555  unsigned condition_feature = 0) const;
570  void TreeShap(const RegTree::FVec& feat, bst_float* phi, bst_node_t node_index,
571  unsigned unique_depth, PathElement* parent_unique_path,
572  bst_float parent_zero_fraction, bst_float parent_one_fraction,
573  int parent_feature_index, int condition,
574  unsigned condition_feature, bst_float condition_fraction) const;
575 
582  std::vector<float>* mean_values,
583  bst_float* out_contribs) const;
591  std::string DumpModel(const FeatureMap& fmap,
592  bool with_stats,
593  std::string format) const;
600  return split_types_.at(nidx);
601  }
605  std::vector<FeatureType> const &GetSplitTypes() const { return split_types_; }
606  common::Span<uint32_t const> GetSplitCategories() const { return split_categories_; }
611  auto node_ptr = GetCategoriesMatrix().node_ptr;
612  auto categories = GetCategoriesMatrix().categories;
613  auto segment = node_ptr[nidx];
614  auto node_cats = categories.subspan(segment.beg, segment.size);
615  return node_cats;
616  }
617  auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
618 
// For node i, split_categories_segments_[i] describes the range
// split_categories_[beg : beg + size], which stores the bitset of the
// categories matched by node i.
struct Segment {
  size_t beg{0};   // offset of the first element in split_categories_
  size_t size{0};  // number of elements in the range
};
626 
631  };
632 
636  view.categories = this->GetSplitCategories();
637  view.node_ptr = common::Span<Segment const>(split_categories_segments_);
638  return view;
639  }
640 
641  private:
642  template <bool typed>
643  void LoadCategoricalSplit(Json const& in);
644  void SaveCategoricalSplit(Json* p_out) const;
645  // vector of nodes
646  std::vector<Node> nodes_;
647  // free node space, used during training process
648  std::vector<int> deleted_nodes_;
649  // stats of nodes
650  std::vector<RTreeNodeStat> stats_;
651  std::vector<FeatureType> split_types_;
652 
653  // Categories for each internal node.
654  std::vector<uint32_t> split_categories_;
655  // Ptr to split categories of each node.
656  std::vector<Segment> split_categories_segments_;
657 
658  // allocate a new node,
659  // !!!!!! NOTE: may cause BUG here, nodes.resize
660  bst_node_t AllocNode() {
661  if (param.num_deleted != 0) {
662  int nid = deleted_nodes_.back();
663  deleted_nodes_.pop_back();
664  nodes_[nid].Reuse();
665  --param.num_deleted;
666  return nid;
667  }
668  int nd = param.num_nodes++;
669  CHECK_LT(param.num_nodes, std::numeric_limits<int>::max())
670  << "number of nodes in the tree exceed 2^31";
671  nodes_.resize(param.num_nodes);
672  stats_.resize(param.num_nodes);
673  split_types_.resize(param.num_nodes, FeatureType::kNumerical);
674  split_categories_segments_.resize(param.num_nodes);
675  return nd;
676  }
677  // delete a tree node, keep the parent field to allow trace back
678  void DeleteNode(int nid) {
679  CHECK_GE(nid, 1);
680  auto pid = (*this)[nid].Parent();
681  if (nid == (*this)[pid].LeftChild()) {
682  (*this)[pid].SetLeftChild(kInvalidNodeId);
683  } else {
684  (*this)[pid].SetRightChild(kInvalidNodeId);
685  }
686 
687  deleted_nodes_.push_back(nid);
688  nodes_[nid].MarkDelete();
689  ++param.num_deleted;
690  }
691 };
692 
693 inline void RegTree::FVec::Init(size_t size) {
694  Entry e; e.flag = -1;
695  data_.resize(size);
696  std::fill(data_.begin(), data_.end(), e);
697  has_missing_ = true;
698 }
699 
700 inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
701  size_t feature_count = 0;
702  for (auto const& entry : inst) {
703  if (entry.index >= data_.size()) {
704  continue;
705  }
706  data_[entry.index].fvalue = entry.fvalue;
707  ++feature_count;
708  }
709  has_missing_ = data_.size() != feature_count;
710 }
711 
712 inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
713  for (auto const& entry : inst) {
714  if (entry.index >= data_.size()) {
715  continue;
716  }
717  data_[entry.index].flag = -1;
718  }
719  has_missing_ = true;
720 }
721 
722 inline size_t RegTree::FVec::Size() const {
723  return data_.size();
724 }
725 
726 inline bst_float RegTree::FVec::GetFvalue(size_t i) const {
727  return data_[i].fvalue;
728 }
729 
730 inline bool RegTree::FVec::IsMissing(size_t i) const {
731  return data_[i].flag == -1;
732 }
733 
734 inline bool RegTree::FVec::HasMissing() const {
735  return has_missing_;
736 }
737 } // namespace xgboost
738 #endif // XGBOOST_TREE_MODEL_H_
xgboost::RegTree::Segment::beg
size_t beg
Definition: tree_model.h:623
xgboost::RegTree::GetCategoriesMatrix
CategoricalSplitMatrix GetCategoriesMatrix() const
Definition: tree_model.h:633
xgboost::RegTree::param
TreeParam param
model parameter
Definition: tree_model.h:314
xgboost::RegTree::Node
tree node
Definition: tree_model.h:139
xgboost::RegTree::Node::IsLeftChild
XGBOOST_DEVICE bool IsLeftChild() const
whether current node is left child
Definition: tree_model.h:190
xgboost::RegTree::FVec::Size
size_t Size() const
returns the size of the feature vector
Definition: tree_model.h:722
xgboost::TreeParam::deprecated_max_depth
int deprecated_max_depth
maximum depth; this is a statistic of the tree
Definition: tree_model.h:43
xgboost::TreeParam::DMLC_DECLARE_PARAMETER
DMLC_DECLARE_PARAMETER(TreeParam)
Definition: tree_model.h:78
xgboost::RegTree::Load
void Load(dmlc::Stream *fi)
load model from stream
xgboost::RegTree::kRoot
static constexpr bst_node_t kRoot
Definition: tree_model.h:136
xgboost::TreeParam::size_leaf_vector
int size_leaf_vector
leaf vector size, used by vector trees to store more than one dimension of information in the tree
Definition: tree_model.h:50
model.h
Defines the abstract interface for different components in XGBoost.
xgboost::TreeParam::operator==
bool operator==(const TreeParam &b) const
Definition: tree_model.h:89
xgboost::RegTree::Stat
const RTreeNodeStat & Stat(int nid) const
get node statistics given nid
Definition: tree_model.h:348
xgboost::RegTree::FVec
dense feature vector that can be taken by RegTree and can be constructed from a sparse feature vector.
Definition: tree_model.h:495
xgboost::RTreeNodeStat::loss_chg
bst_float loss_chg
loss change caused by current split
Definition: tree_model.h:100
xgboost::RegTree::HasCategoricalSplit
bool HasCategoricalSplit() const
Definition: tree_model.h:448
xgboost::TreeParam::reserved
int reserved[31]
reserved part, make sure alignment works for 64bit
Definition: tree_model.h:52
xgboost::RegTree::Node::SetSplit
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
set split condition of current node
Definition: tree_model.h:219
xgboost::RegTree::LoadModel
void LoadModel(Json const &in) override
load the model from a JSON object
xgboost::RegTree::DumpModel
std::string DumpModel(const FeatureMap &fmap, bool with_stats, std::string format) const
dump the model in the requested format as a text string
xgboost::RegTree::Node::IsLeaf
XGBOOST_DEVICE bool IsLeaf() const
whether current node is leaf node
Definition: tree_model.h:174
xgboost::RegTree::FVec::HasMissing
bool HasMissing() const
Definition: tree_model.h:734
xgboost::RegTree::Node::SetLeftChild
XGBOOST_DEVICE void SetLeftChild(int nid)
set the left child
Definition: tree_model.h:203
xgboost::RegTree::WalkTree
void WalkTree(Func func) const
Definition: tree_model.h:375
xgboost::RTreeNodeStat::operator==
bool operator==(const RTreeNodeStat &b) const
Definition: tree_model.h:111
xgboost::TreeParam::ByteSwap
TreeParam ByteSwap() const
Definition: tree_model.h:65
xgboost::RegTree::CategoricalSplitMatrix
Definition: tree_model.h:627
xgboost::RegTree::GetSplitTypes
const std::vector< FeatureType > & GetSplitTypes() const
Get split types for all nodes.
Definition: tree_model.h:605
xgboost::RegTree::operator[]
Node & operator[](int nid)
get node given nid
Definition: tree_model.h:329
xgboost::RegTree::Node::IsRoot
XGBOOST_DEVICE bool IsRoot() const
whether current node is root
Definition: tree_model.h:198
xgboost::RegTree::Node::MarkDelete
XGBOOST_DEVICE void MarkDelete()
mark that this node is deleted
Definition: tree_model.h:237
xgboost::RegTree::CategoricalSplitMatrix::categories
common::Span< uint32_t const > categories
Definition: tree_model.h:629
xgboost::RegTree::FVec::Fill
void Fill(const SparsePage::Inst &inst)
fill the vector with sparse vector
Definition: tree_model.h:700
base.h
defines configuration macros of xgboost.
xgboost::RegTree::Node::Reuse
XGBOOST_DEVICE void Reuse()
Reuse this deleted node.
Definition: tree_model.h:241
feature_map.h
Feature map data structure to help visualization and model dump.
xgboost::RegTree::Node::SetParent
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child=true)
Definition: tree_model.h:245
xgboost::RegTree::Node::Parent
XGBOOST_DEVICE int Parent() const
get parent of the node
Definition: tree_model.h:186
xgboost::RegTree::Node::DefaultChild
XGBOOST_DEVICE int DefaultChild() const
index of default child when feature is missing
Definition: tree_model.h:162
xgboost::RegTree::Node::RightChild
XGBOOST_DEVICE int RightChild() const
index of right child
Definition: tree_model.h:158
xgboost::RegTree::NodeCats
common::Span< uint32_t const > NodeCats(bst_node_t nidx) const
Get the bit storage for categories.
Definition: tree_model.h:610
xgboost::RegTree::MaxDepth
int MaxDepth(int nid) const
get maximum depth
Definition: tree_model.h:469
xgboost::FeatureMap
Feature map data structure to help text model dump. TODO(tqchen) consider make it even more lightweig...
Definition: feature_map.h:22
xgboost::RegTree::FVec::Drop
void Drop(const SparsePage::Inst &inst)
drop the trace after fill, must be called after fill.
Definition: tree_model.h:712
xgboost::RegTree::RegTree
RegTree()
constructor
Definition: tree_model.h:316
xgboost::bst_feature_t
uint32_t bst_feature_t
Type for data column (feature) index.
Definition: base.h:123
xgboost::RegTree::operator[]
const Node & operator[](int nid) const
get node given nid
Definition: tree_model.h:333
xgboost::RegTree::Segment::size
size_t size
Definition: tree_model.h:624
xgboost::RegTree::CalculateContributions
void CalculateContributions(const RegTree::FVec &feat, std::vector< float > *mean_values, bst_float *out_contribs, int condition=0, unsigned condition_feature=0) const
calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree
xgboost::Model
Definition: model.h:17
xgboost::RegTree::kInvalidNodeId
static constexpr bst_node_t kInvalidNodeId
Definition: tree_model.h:134
xgboost::TreeParam::num_feature
bst_feature_t num_feature
number of features used for tree construction
Definition: tree_model.h:45
xgboost::RegTree::Node::SetLeaf
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=kInvalidNodeId)
set the leaf value of the node
Definition: tree_model.h:231
xgboost::RegTree::FVec::GetFvalue
bst_float GetFvalue(size_t i) const
get ith value
Definition: tree_model.h:726
xgboost::RegTree::MaxDepth
int MaxDepth()
get maximum depth
Definition: tree_model.h:478
xgboost::RTreeNodeStat
node statistics used in regression tree
Definition: tree_model.h:98
xgboost::RegTree::GetDepth
int GetDepth(int nid) const
get current depth
Definition: tree_model.h:456
xgboost::RegTree::Segment
Definition: tree_model.h:622
xgboost::TreeParam::deprecated_num_roots
int deprecated_num_roots
(Deprecated) number of start root
Definition: tree_model.h:37
xgboost::RegTree::SaveModel
void SaveModel(Json *out) const override
saves the model config to a JSON object
xgboost::RegTree::Node::ByteSwap
Node ByteSwap() const
Definition: tree_model.h:255
xgboost::RegTree::GetNumSplitNodes
bst_node_t GetNumSplitNodes() const
xgboost::RegTree::Node::LeafValue
XGBOOST_DEVICE bst_float LeafValue() const
Definition: tree_model.h:178
xgboost::RTreeNodeStat::RTreeNodeStat
RTreeNodeStat(float loss_chg, float sum_hess, float weight)
Definition: tree_model.h:109
xgboost::bst_node_t
int32_t bst_node_t
Type for tree node index.
Definition: base.h:132
xgboost::common::Span::subspan
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
Definition: span.h:595
xgboost::RegTree::Node::Node
Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond, bool default_left)
Definition: tree_model.h:146
xgboost::RegTree
define regression tree to be the most common tree model. This is the data structure used in xgboost's...
Definition: tree_model.h:131
xgboost::FeatureType::kNumerical
@ kNumerical
xgboost::RTreeNodeStat::sum_hess
bst_float sum_hess
sum of hessian values, used to measure coverage of data
Definition: tree_model.h:102
xgboost::TreeParam
meta parameters of the tree
Definition: tree_model.h:35
xgboost::RTreeNodeStat::base_weight
bst_float base_weight
weight of current node
Definition: tree_model.h:104
xgboost::RegTree::CollapseToLeaf
void CollapseToLeaf(int rid, bst_float value)
collapse a non leaf node to a leaf node, delete its children
Definition: tree_model.h:302
xgboost::RegTree::Node::LeftChild
XGBOOST_DEVICE int LeftChild() const
index of left child
Definition: tree_model.h:154
xgboost::TreeParam::num_deleted
int num_deleted
number of deleted nodes
Definition: tree_model.h:41
xgboost::FeatureType
FeatureType
Definition: data.h:41
xgboost::RegTree::GetSplitCategoriesPtr
auto const & GetSplitCategoriesPtr() const
Definition: tree_model.h:617
xgboost::RegTree::Node::Node
XGBOOST_DEVICE Node()
Definition: tree_model.h:141
xgboost::RegTree::CategoricalSplitMatrix::split_type
common::Span< FeatureType const > split_type
Definition: tree_model.h:628
xgboost::RegTree::kDeletedNodeMarker
static constexpr uint32_t kDeletedNodeMarker
Definition: tree_model.h:135
xgboost::RegTree::SplitCondT
bst_float SplitCondT
Definition: tree_model.h:133
xgboost::common::Span
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:148
xgboost::RegTree::CategoricalSplitMatrix::node_ptr
common::Span< Segment const > node_ptr
Definition: tree_model.h:630
xgboost::RegTree::Stat
RTreeNodeStat & Stat(int nid)
get node statistics given nid
Definition: tree_model.h:344
data.h
The input data structure of xgboost.
xgboost::RegTree::Node::SplitCond
XGBOOST_DEVICE SplitCondT SplitCond() const
Definition: tree_model.h:182
xgboost::RTreeNodeStat::RTreeNodeStat
RTreeNodeStat()=default
xgboost::RTreeNodeStat::ByteSwap
RTreeNodeStat ByteSwap() const
Definition: tree_model.h:117
xgboost::RegTree::GetNumLeaves
bst_node_t GetNumLeaves() const
xgboost::RegTree::NodeSplitType
FeatureType NodeSplitType(bst_node_t nidx) const
Get split type for a node.
Definition: tree_model.h:599
xgboost::RegTree::GetSplitCategories
common::Span< uint32_t const > GetSplitCategories() const
Definition: tree_model.h:606
xgboost::RegTree::GetNodes
const std::vector< Node > & GetNodes() const
get const reference to nodes
Definition: tree_model.h:338
xgboost::RegTree::FVec::IsMissing
bool IsMissing(size_t i) const
check whether i-th entry is missing
Definition: tree_model.h:730
xgboost::RegTree::Node::DefaultLeft
XGBOOST_DEVICE bool DefaultLeft() const
when the feature is unknown (missing), whether the sample goes to the left child
Definition: tree_model.h:170
xgboost::RegTree::NumExtraNodes
int NumExtraNodes() const
number of extra nodes besides the root
Definition: tree_model.h:483
xgboost::RegTree::ChangeToLeaf
void ChangeToLeaf(int rid, bst_float value)
change a non leaf node to a leaf node, delete its children
Definition: tree_model.h:290
xgboost::RegTree::Node::SplitIndex
XGBOOST_DEVICE unsigned SplitIndex() const
feature index of split condition
Definition: tree_model.h:166
xgboost::RegTree::ExpandCategorical
void ExpandCategorical(bst_node_t nid, unsigned split_index, common::Span< uint32_t > split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum)
Expands a leaf node with categories.
xgboost::TreeParam::TreeParam
TreeParam()
constructor
Definition: tree_model.h:54
xgboost::RegTree::Node::SetRightChild
XGBOOST_DEVICE void SetRightChild(int nid)
set the right child
Definition: tree_model.h:210
xgboost::RegTree::CalculateContributionsApprox
void CalculateContributionsApprox(const RegTree::FVec &feat, std::vector< float > *mean_values, bst_float *out_contribs) const
calculate the approximate feature contributions for the given root
xgboost::Json
Data structure representing JSON format.
Definition: json.h:352
xgboost::RegTree::Save
void Save(dmlc::Stream *fo) const
save model to stream
xgboost::RegTree::GetStats
const std::vector< RTreeNodeStat > & GetStats() const
get const reference to stats
Definition: tree_model.h:341
xgboost::RegTree::FVec::Init
void Init(size_t size)
initialize the vector with the given size
Definition: tree_model.h:693
xgboost::RTreeNodeStat::leaf_child_cnt
int leaf_child_cnt
number of children that are leaf nodes known up to now
Definition: tree_model.h:106
xgboost::RegTree::Equal
bool Equal(const RegTree &b) const
Compares whether 2 trees are equal from a user's perspective. The equality compares only non-deleted ...
xgboost::RegTree::TreeShap
void TreeShap(const RegTree::FVec &feat, bst_float *phi, bst_node_t node_index, unsigned unique_depth, PathElement *parent_unique_path, bst_float parent_zero_fraction, bst_float parent_one_fraction, int parent_feature_index, int condition, unsigned condition_feature, bst_float condition_fraction) const
Recursive function that computes the feature attributions for a single tree.
xgboost::RegTree::ExpandNode
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
Expands a leaf node into two additional leaf nodes.
XGBOOST_DEVICE
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:84
xgboost::RegTree::operator==
bool operator==(const RegTree &b) const
Definition: tree_model.h:366
xgboost::TreeParam::num_nodes
int num_nodes
total number of nodes
Definition: tree_model.h:39
xgboost::RegTree::Node::operator==
bool operator==(const Node &b) const
Definition: tree_model.h:249
xgboost::RegTree::Node::IsDeleted
XGBOOST_DEVICE bool IsDeleted() const
whether this node is deleted
Definition: tree_model.h:194
xgboost
namespace of xgboost
Definition: base.h:110
xgboost::bst_float
float bst_float
float type, used for storing statistics
Definition: base.h:119