xgboost
tree_model.h
Go to the documentation of this file.
1 
7 #ifndef XGBOOST_TREE_MODEL_H_
8 #define XGBOOST_TREE_MODEL_H_
9 
10 #include <dmlc/io.h>
11 #include <dmlc/parameter.h>
12 
13 #include <xgboost/base.h>
14 #include <xgboost/data.h>
15 #include <xgboost/logging.h>
16 #include <xgboost/feature_map.h>
17 #include <xgboost/model.h>
18 
19 #include <limits>
20 #include <vector>
21 #include <string>
22 #include <cstring>
23 #include <algorithm>
24 #include <tuple>
25 #include <stack>
26 
27 namespace xgboost {
28 
29 struct PathElement; // forward declaration
30 
31 class Json;
32 // FIXME(trivialfis): Once binary IO is gone, make this parameter internal as it should
33 // not be configured by users.
35 struct TreeParam : public dmlc::Parameter<TreeParam> {
39  int num_nodes;
52  int reserved[31];
55  // assert compact alignment
56  static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int),
57  "TreeParam: 64 bit align");
58  std::memset(this, 0, sizeof(TreeParam));
59  num_nodes = 1;
60  deprecated_num_roots = 1;
61  }
62  // declare the parameters
64  // only declare the parameters that can be set by the user.
65  // other arguments are set by the algorithm.
66  DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1);
67  DMLC_DECLARE_FIELD(num_feature)
68  .describe("Number of features used in tree construction.");
69  DMLC_DECLARE_FIELD(num_deleted);
70  DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0)
71  .describe("Size of leaf vector, reserved for vector tree");
72  }
73 
74  bool operator==(const TreeParam& b) const {
75  return num_nodes == b.num_nodes &&
76  num_deleted == b.num_deleted &&
77  num_feature == b.num_feature &&
78  size_leaf_vector == b.size_leaf_vector;
79  }
80 };
81 
83 struct RTreeNodeStat {
91  int leaf_child_cnt {0};
92 
93  RTreeNodeStat() = default;
94  RTreeNodeStat(float loss_chg, float sum_hess, float weight) :
95  loss_chg{loss_chg}, sum_hess{sum_hess}, base_weight{weight} {}
96  bool operator==(const RTreeNodeStat& b) const {
97  return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
98  base_weight == b.base_weight && leaf_child_cnt == b.leaf_child_cnt;
99  }
100 };
101 
106 class RegTree : public Model {
107  public:
109  static constexpr bst_node_t kInvalidNodeId {-1};
110  static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
111  static constexpr bst_node_t kRoot { 0 };
112 
114  class Node {
115  public:
117  // assert compact alignment
118  static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
119  "Node: 64 bit align");
120  }
121  Node(int32_t cleft, int32_t cright, int32_t parent,
122  uint32_t split_ind, float split_cond, bool default_left) :
123  parent_{parent}, cleft_{cleft}, cright_{cright} {
124  this->SetParent(parent_);
125  this->SetSplit(split_ind, split_cond, default_left);
126  }
127 
129  XGBOOST_DEVICE int LeftChild() const {
130  return this->cleft_;
131  }
134  return this->cright_;
135  }
138  return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
139  }
141  XGBOOST_DEVICE unsigned SplitIndex() const {
142  return sindex_ & ((1U << 31) - 1U);
143  }
146  return (sindex_ >> 31) != 0;
147  }
149  XGBOOST_DEVICE bool IsLeaf() const {
150  return cleft_ == kInvalidNodeId;
151  }
154  return (this->info_).leaf_value;
155  }
158  return (this->info_).split_cond;
159  }
161  XGBOOST_DEVICE int Parent() const {
162  return parent_ & ((1U << 31) - 1);
163  }
166  return (parent_ & (1U << 31)) != 0;
167  }
169  XGBOOST_DEVICE bool IsDeleted() const {
170  return sindex_ == kDeletedNodeMarker;
171  }
173  XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
178  XGBOOST_DEVICE void SetLeftChild(int nid) {
179  this->cleft_ = nid;
180  }
186  this->cright_ = nid;
187  }
194  XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond,
195  bool default_left = false) {
196  if (default_left) split_index |= (1U << 31);
197  this->sindex_ = split_index;
198  (this->info_).split_cond = split_cond;
199  }
206  XGBOOST_DEVICE void SetLeaf(bst_float value, int right = kInvalidNodeId) {
207  (this->info_).leaf_value = value;
208  this->cleft_ = kInvalidNodeId;
209  this->cright_ = right;
210  }
213  this->sindex_ = kDeletedNodeMarker;
214  }
217  this->sindex_ = 0;
218  }
219  // set parent
220  XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child = true) {
221  if (is_left_child) pidx |= (1U << 31);
222  this->parent_ = pidx;
223  }
224  bool operator==(const Node& b) const {
225  return parent_ == b.parent_ && cleft_ == b.cleft_ &&
226  cright_ == b.cright_ && sindex_ == b.sindex_ &&
227  info_.leaf_value == b.info_.leaf_value;
228  }
229 
230  private:
235  union Info{
236  bst_float leaf_value;
237  SplitCondT split_cond;
238  };
239  // pointer to parent, highest bit is used to
240  // indicate whether it's a left child or not
241  int32_t parent_{kInvalidNodeId};
242  // pointer to left, right
243  int32_t cleft_{kInvalidNodeId}, cright_{kInvalidNodeId};
244  // split feature index, left split or right split depends on the highest bit
245  uint32_t sindex_{0};
246  // extra info
247  Info info_;
248  };
249 
255  void ChangeToLeaf(int rid, bst_float value) {
256  CHECK(nodes_[nodes_[rid].LeftChild() ].IsLeaf());
257  CHECK(nodes_[nodes_[rid].RightChild()].IsLeaf());
258  this->DeleteNode(nodes_[rid].LeftChild());
259  this->DeleteNode(nodes_[rid].RightChild());
260  nodes_[rid].SetLeaf(value);
261  }
267  void CollapseToLeaf(int rid, bst_float value) {
268  if (nodes_[rid].IsLeaf()) return;
269  if (!nodes_[nodes_[rid].LeftChild() ].IsLeaf()) {
270  CollapseToLeaf(nodes_[rid].LeftChild(), 0.0f);
271  }
272  if (!nodes_[nodes_[rid].RightChild() ].IsLeaf()) {
273  CollapseToLeaf(nodes_[rid].RightChild(), 0.0f);
274  }
275  this->ChangeToLeaf(rid, value);
276  }
277 
282  param.num_nodes = 1;
283  param.num_deleted = 0;
284  nodes_.resize(param.num_nodes);
285  stats_.resize(param.num_nodes);
286  for (int i = 0; i < param.num_nodes; i ++) {
287  nodes_[i].SetLeaf(0.0f);
288  nodes_[i].SetParent(kInvalidNodeId);
289  }
290  }
292  Node& operator[](int nid) {
293  return nodes_[nid];
294  }
296  const Node& operator[](int nid) const {
297  return nodes_[nid];
298  }
299 
301  const std::vector<Node>& GetNodes() const { return nodes_; }
302 
304  RTreeNodeStat& Stat(int nid) {
305  return stats_[nid];
306  }
308  const RTreeNodeStat& Stat(int nid) const {
309  return stats_[nid];
310  }
311 
316  void Load(dmlc::Stream* fi);
321  void Save(dmlc::Stream* fo) const;
322 
323  void LoadModel(Json const& in) override;
324  void SaveModel(Json* out) const override;
325 
326  bool operator==(const RegTree& b) const {
327  return nodes_ == b.nodes_ && stats_ == b.stats_ &&
328  deleted_nodes_ == b.deleted_nodes_ && param == b.param;
329  }
330  /* \brief Iterate through all nodes in this tree.
331  *
332  * \param Function that accepts a node index, and returns false when iteration should
333  * stop, otherwise returns true.
334  */
335  template <typename Func> void WalkTree(Func func) const {
336  std::stack<bst_node_t> nodes;
337  nodes.push(kRoot);
338  auto &self = *this;
339  while (!nodes.empty()) {
340  auto nidx = nodes.top();
341  nodes.pop();
342  if (!func(nidx)) {
343  return;
344  }
345  auto left = self[nidx].LeftChild();
346  auto right = self[nidx].RightChild();
347  if (left != RegTree::kInvalidNodeId) {
348  nodes.push(left);
349  }
350  if (right != RegTree::kInvalidNodeId) {
351  nodes.push(right);
352  }
353  }
354  }
361  bool Equal(const RegTree& b) const;
362 
380  void ExpandNode(int nid, unsigned split_index, bst_float split_value,
381  bool default_left, bst_float base_weight,
382  bst_float left_leaf_weight, bst_float right_leaf_weight,
383  bst_float loss_change, float sum_hess, float left_sum,
384  float right_sum,
385  bst_node_t leaf_right_child = kInvalidNodeId) {
386  int pleft = this->AllocNode();
387  int pright = this->AllocNode();
388  auto &node = nodes_[nid];
389  CHECK(node.IsLeaf());
390  node.SetLeftChild(pleft);
391  node.SetRightChild(pright);
392  nodes_[node.LeftChild()].SetParent(nid, true);
393  nodes_[node.RightChild()].SetParent(nid, false);
394  node.SetSplit(split_index, split_value,
395  default_left);
396 
397  nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child);
398  nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child);
399 
400  this->Stat(nid) = {loss_change, sum_hess, base_weight};
401  this->Stat(pleft) = {0.0f, left_sum, left_leaf_weight};
402  this->Stat(pright) = {0.0f, right_sum, right_leaf_weight};
403  }
404 
409  int GetDepth(int nid) const {
410  int depth = 0;
411  while (!nodes_[nid].IsRoot()) {
412  ++depth;
413  nid = nodes_[nid].Parent();
414  }
415  return depth;
416  }
421  int MaxDepth(int nid) const {
422  if (nodes_[nid].IsLeaf()) return 0;
423  return std::max(MaxDepth(nodes_[nid].LeftChild())+1,
424  MaxDepth(nodes_[nid].RightChild())+1);
425  }
426 
430  int MaxDepth() {
431  return MaxDepth(0);
432  }
433 
435  int NumExtraNodes() const {
436  return param.num_nodes - 1 - param.num_deleted;
437  }
438 
439  /* \brief Count number of leaves in tree. */
440  bst_node_t GetNumLeaves() const;
441  bst_node_t GetNumSplitNodes() const;
442 
447  struct FVec {
452  void Init(size_t size);
457  void Fill(const SparsePage::Inst& inst);
462  void Drop(const SparsePage::Inst& inst);
467  size_t Size() const;
473  bst_float GetFvalue(size_t i) const;
479  bool IsMissing(size_t i) const;
480 
481  private:
486  union Entry {
487  bst_float fvalue;
488  int flag;
489  };
490  std::vector<Entry> data_;
491  };
497  int GetLeafIndex(const FVec& feat) const;
505  void CalculateContributions(const RegTree::FVec& feat,
506  bst_float* out_contribs, int condition = 0,
507  unsigned condition_feature = 0) const;
522  void TreeShap(const RegTree::FVec& feat, bst_float* phi, unsigned node_index,
523  unsigned unique_depth, PathElement* parent_unique_path,
524  bst_float parent_zero_fraction, bst_float parent_one_fraction,
525  int parent_feature_index, int condition,
526  unsigned condition_feature, bst_float condition_fraction) const;
527 
533  void CalculateContributionsApprox(const RegTree::FVec& feat,
534  bst_float* out_contribs) const;
541  inline int GetNext(int pid, bst_float fvalue, bool is_unknown) const;
549  std::string DumpModel(const FeatureMap& fmap,
550  bool with_stats,
551  std::string format) const;
555  void FillNodeMeanValues();
556 
557  private:
558  // vector of nodes
559  std::vector<Node> nodes_;
560  // free node space, used during training process
561  std::vector<int> deleted_nodes_;
562  // stats of nodes
563  std::vector<RTreeNodeStat> stats_;
564  std::vector<bst_float> node_mean_values_;
565  // allocate a new node,
566  // !!!!!! NOTE: may cause BUG here, nodes.resize
567  int AllocNode() {
568  if (param.num_deleted != 0) {
569  int nid = deleted_nodes_.back();
570  deleted_nodes_.pop_back();
571  nodes_[nid].Reuse();
572  --param.num_deleted;
573  return nid;
574  }
575  int nd = param.num_nodes++;
576  CHECK_LT(param.num_nodes, std::numeric_limits<int>::max())
577  << "number of nodes in the tree exceed 2^31";
578  nodes_.resize(param.num_nodes);
579  stats_.resize(param.num_nodes);
580  return nd;
581  }
582  // delete a tree node, keep the parent field to allow trace back
583  void DeleteNode(int nid) {
584  CHECK_GE(nid, 1);
585  auto pid = (*this)[nid].Parent();
586  if (nid == (*this)[pid].LeftChild()) {
587  (*this)[pid].SetLeftChild(kInvalidNodeId);
588  } else {
589  (*this)[pid].SetRightChild(kInvalidNodeId);
590  }
591 
592  deleted_nodes_.push_back(nid);
593  nodes_[nid].MarkDelete();
594  ++param.num_deleted;
595  }
596  bst_float FillNodeMeanValue(int nid);
597 };
598 
599 inline void RegTree::FVec::Init(size_t size) {
600  Entry e; e.flag = -1;
601  data_.resize(size);
602  std::fill(data_.begin(), data_.end(), e);
603 }
604 
605 inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
606  for (auto const& entry : inst) {
607  if (entry.index >= data_.size()) {
608  continue;
609  }
610  data_[entry.index].fvalue = entry.fvalue;
611  }
612 }
613 
614 inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
615  for (auto const& entry : inst) {
616  if (entry.index >= data_.size()) {
617  continue;
618  }
619  data_[entry.index].flag = -1;
620  }
621 }
622 
623 inline size_t RegTree::FVec::Size() const {
624  return data_.size();
625 }
626 
627 inline bst_float RegTree::FVec::GetFvalue(size_t i) const {
628  return data_[i].fvalue;
629 }
630 
631 inline bool RegTree::FVec::IsMissing(size_t i) const {
632  return data_[i].flag == -1;
633 }
634 
635 inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
636  bst_node_t nid = 0;
637  while (!(*this)[nid].IsLeaf()) {
638  unsigned split_index = (*this)[nid].SplitIndex();
639  nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
640  }
641  return nid;
642 }
643 
645 inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const {
646  bst_float split_value = (*this)[pid].SplitCond();
647  if (is_unknown) {
648  return (*this)[pid].DefaultChild();
649  } else {
650  if (fvalue < split_value) {
651  return (*this)[pid].LeftChild();
652  } else {
653  return (*this)[pid].RightChild();
654  }
655  }
656 }
657 } // namespace xgboost
658 #endif // XGBOOST_TREE_MODEL_H_
int NumExtraNodes() const
number of extra nodes besides the root
Definition: tree_model.h:435
RTreeNodeStat(float loss_chg, float sum_hess, float weight)
Definition: tree_model.h:94
int deprecated_num_roots
(Deprecated) number of start roots
Definition: tree_model.h:37
int GetNext(int pid, bst_float fvalue, bool is_unknown) const
get next position of the tree given current pid
Definition: tree_model.h:645
Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond, bool default_left)
Definition: tree_model.h:121
XGBOOST_DEVICE int RightChild() const
index of right child
Definition: tree_model.h:133
void ExpandNode(int nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
Expands a leaf node into two additional leaf nodes.
Definition: tree_model.h:380
float bst_float
float type, used for storing statistics
Definition: base.h:111
int deprecated_max_depth
maximum depth; this is a statistic of the tree
Definition: tree_model.h:43
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=kInvalidNodeId)
set the leaf value of the node
Definition: tree_model.h:206
XGBOOST_DEVICE bst_float LeafValue() const
Definition: tree_model.h:153
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child=true)
Definition: tree_model.h:220
The input data structure of xgboost.
void Fill(const SparsePage::Inst &inst)
fill the vector with sparse vector
Definition: tree_model.h:605
const Node & operator[](int nid) const
get node given nid
Definition: tree_model.h:296
Defines the abstract interface for different components in XGBoost.
bst_float GetFvalue(size_t i) const
get ith value
Definition: tree_model.h:627
XGBOOST_DEVICE void SetRightChild(int nid)
set the right child
Definition: tree_model.h:185
XGBOOST_DEVICE unsigned SplitIndex() const
feature index of split condition
Definition: tree_model.h:141
int leaf_child_cnt
number of children that are leaf nodes, known up to now
Definition: tree_model.h:91
int GetDepth(int nid) const
get current depth
Definition: tree_model.h:409
Feature map data structure to help text model dump. TODO(tqchen) consider making it even more lightweight...
Definition: feature_map.h:22
XGBOOST_DEVICE bool DefaultLeft() const
when feature is unknown, whether goes to left child
Definition: tree_model.h:145
bst_float base_weight
weight of current node
Definition: tree_model.h:89
Defines the regression tree, the most common tree model. This is the data structure used in xgboost's...
Definition: tree_model.h:106
TreeParam()
constructor
Definition: tree_model.h:54
int size_leaf_vector
leaf vector size, used by vector trees to store more than one dimension of information in the tree ...
Definition: tree_model.h:50
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:126
bool operator==(const RTreeNodeStat &b) const
Definition: tree_model.h:96
node statistics used in regression tree
Definition: tree_model.h:83
bst_float sum_hess
sum of hessian values, used to measure coverage of data
Definition: tree_model.h:87
void ChangeToLeaf(int rid, bst_float value)
change a non-leaf node to a leaf node, deleting its children
Definition: tree_model.h:255
Definition: model.h:17
XGBOOST_DEVICE Node()
Definition: tree_model.h:116
meta parameters of the tree
Definition: tree_model.h:35
bool operator==(const RegTree &b) const
Definition: tree_model.h:326
int num_deleted
number of deleted nodes
Definition: tree_model.h:41
int MaxDepth(int nid) const
get maximum depth
Definition: tree_model.h:421
int num_nodes
total number of nodes
Definition: tree_model.h:39
XGBOOST_DEVICE void SetLeftChild(int nid)
set the left child
Definition: tree_model.h:178
int32_t bst_node_t
Type for tree node index.
Definition: base.h:123
TreeParam param
model parameter
Definition: tree_model.h:279
bool operator==(const Node &b) const
Definition: tree_model.h:224
int GetLeafIndex(const FVec &feat) const
get the leaf index
Definition: tree_model.h:635
bool operator==(const TreeParam &b) const
Definition: tree_model.h:74
const RTreeNodeStat & Stat(int nid) const
get node statistics given nid
Definition: tree_model.h:308
void Init(size_t size)
initialize the vector with size vector
Definition: tree_model.h:599
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:84
Feature map data structure to help visualization and model dump.
XGBOOST_DEVICE bool IsRoot() const
whether current node is root
Definition: tree_model.h:173
const std::vector< Node > & GetNodes() const
get const reference to nodes
Definition: tree_model.h:301
namespace of xgboost
Definition: base.h:102
static constexpr bst_node_t kInvalidNodeId
Definition: tree_model.h:109
tree node
Definition: tree_model.h:114
int MaxDepth()
get maximum depth
Definition: tree_model.h:430
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
set split condition of current node
Definition: tree_model.h:194
XGBOOST_DEVICE int DefaultChild() const
index of default child when feature is missing
Definition: tree_model.h:137
defines configuration macros of xgboost.
XGBOOST_DEVICE void MarkDelete()
mark that this node is deleted
Definition: tree_model.h:212
XGBOOST_DEVICE bool IsLeftChild() const
whether current node is left child
Definition: tree_model.h:165
void WalkTree(Func func) const
Definition: tree_model.h:335
size_t Size() const
returns the size of the feature vector
Definition: tree_model.h:623
XGBOOST_DEVICE bool IsLeaf() const
whether current node is leaf node
Definition: tree_model.h:149
RegTree()
constructor
Definition: tree_model.h:281
XGBOOST_DEVICE bool IsDeleted() const
whether this node is deleted
Definition: tree_model.h:169
int num_feature
number of features used for tree construction
Definition: tree_model.h:45
void Drop(const SparsePage::Inst &inst)
drop the trace after fill; must be called after fill.
Definition: tree_model.h:614
Data structure representing JSON format.
Definition: json.h:326
int reserved[31]
reserved part, make sure alignment works for 64bit
Definition: tree_model.h:52
void CollapseToLeaf(int rid, bst_float value)
collapse a non-leaf node to a leaf node, deleting its children
Definition: tree_model.h:267
Node & operator[](int nid)
get node given nid
Definition: tree_model.h:292
bst_float SplitCondT
Definition: tree_model.h:108
XGBOOST_DEVICE int LeftChild() const
index of left child
Definition: tree_model.h:129
RTreeNodeStat & Stat(int nid)
get node statistics given nid
Definition: tree_model.h:304
XGBOOST_DEVICE void Reuse()
Reuse this deleted node.
Definition: tree_model.h:216
bst_float loss_chg
loss change caused by current split
Definition: tree_model.h:85
XGBOOST_DEVICE int Parent() const
get parent of the node
Definition: tree_model.h:161
XGBOOST_DEVICE SplitCondT SplitCond() const
Definition: tree_model.h:157
bool IsMissing(size_t i) const
check whether i-th entry is missing
Definition: tree_model.h:631
dense feature vector that can be taken by RegTree and can be constructed from a sparse feature vector...
Definition: tree_model.h:447
DMLC_DECLARE_PARAMETER(TreeParam)
Definition: tree_model.h:63