xgboost
tree_model.h
Go to the documentation of this file.
1 
7 #ifndef XGBOOST_TREE_MODEL_H_
8 #define XGBOOST_TREE_MODEL_H_
9 
10 #include <dmlc/io.h>
11 #include <dmlc/parameter.h>
12 
13 #include <xgboost/base.h>
14 #include <xgboost/data.h>
15 #include <xgboost/logging.h>
16 #include <xgboost/feature_map.h>
17 #include <xgboost/model.h>
18 
19 #include <limits>
20 #include <vector>
21 #include <string>
22 #include <cstring>
23 #include <algorithm>
24 #include <tuple>
25 
26 namespace xgboost {
27 
28 struct PathElement; // forward declaration
29 
30 class Json;
31 // FIXME(trivialfis): Once binary IO is gone, make this parameter internal as it should
32 // not be configured by users.
34 struct TreeParam : public dmlc::Parameter<TreeParam> {
38  int num_nodes;
51  int reserved[31];
54  // assert compact alignment
55  static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int),
56  "TreeParam: 64 bit align");
57  std::memset(this, 0, sizeof(TreeParam));
58  num_nodes = 1;
59  deprecated_num_roots = 1;
60  }
61  // declare the parameters
63  // only declare the parameters that can be set by the user.
64  // other arguments are set by the algorithm.
65  DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1);
66  DMLC_DECLARE_FIELD(num_feature)
67  .describe("Number of features used in tree construction.");
68  DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0)
69  .describe("Size of leaf vector, reserved for vector tree");
70  }
71 
72  bool operator==(const TreeParam& b) const {
73  return num_nodes == b.num_nodes &&
74  num_deleted == b.num_deleted &&
75  num_feature == b.num_feature &&
76  size_leaf_vector == b.size_leaf_vector;
77  }
78 };
79 
81 struct RTreeNodeStat {
89  int leaf_child_cnt {0};
90  bool operator==(const RTreeNodeStat& b) const {
91  return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
92  base_weight == b.base_weight && leaf_child_cnt == b.leaf_child_cnt;
93  }
94 };
95 
100 class RegTree : public Model {
101  public:
104  static constexpr int32_t kInvalidNodeId {-1};
106  class Node {
107  public:
108  Node() {
109  // assert compact alignment
110  static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
111  "Node: 64 bit align");
112  }
113  Node(int32_t cleft, int32_t cright, int32_t parent,
114  uint32_t split_ind, float split_cond, bool default_left) :
115  parent_{parent}, cleft_{cleft}, cright_{cright} {
116  this->SetSplit(split_ind, split_cond, default_left);
117  }
118 
120  XGBOOST_DEVICE int LeftChild() const {
121  return this->cleft_;
122  }
125  return this->cright_;
126  }
129  return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
130  }
132  XGBOOST_DEVICE unsigned SplitIndex() const {
133  return sindex_ & ((1U << 31) - 1U);
134  }
137  return (sindex_ >> 31) != 0;
138  }
140  XGBOOST_DEVICE bool IsLeaf() const {
141  return cleft_ == kInvalidNodeId;
142  }
145  return (this->info_).leaf_value;
146  }
149  return (this->info_).split_cond;
150  }
152  XGBOOST_DEVICE int Parent() const {
153  return parent_ & ((1U << 31) - 1);
154  }
157  return (parent_ & (1U << 31)) != 0;
158  }
160  XGBOOST_DEVICE bool IsDeleted() const {
161  return sindex_ == std::numeric_limits<unsigned>::max();
162  }
164  XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
169  XGBOOST_DEVICE void SetLeftChild(int nid) {
170  this->cleft_ = nid;
171  }
177  this->cright_ = nid;
178  }
185  XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond,
186  bool default_left = false) {
187  if (default_left) split_index |= (1U << 31);
188  this->sindex_ = split_index;
189  (this->info_).split_cond = split_cond;
190  }
197  XGBOOST_DEVICE void SetLeaf(bst_float value, int right = kInvalidNodeId) {
198  (this->info_).leaf_value = value;
199  this->cleft_ = kInvalidNodeId;
200  this->cright_ = right;
201  }
204  this->sindex_ = std::numeric_limits<unsigned>::max();
205  }
208  this->sindex_ = 0;
209  }
210  // set parent
211  XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child = true) {
212  if (is_left_child) pidx |= (1U << 31);
213  this->parent_ = pidx;
214  }
215  bool operator==(const Node& b) const {
216  return parent_ == b.parent_ && cleft_ == b.cleft_ &&
217  cright_ == b.cright_ && sindex_ == b.sindex_ &&
218  info_.leaf_value == b.info_.leaf_value;
219  }
220 
221  private:
226  union Info{
227  bst_float leaf_value;
228  SplitCondT split_cond;
229  };
230  // pointer to parent, highest bit is used to
231  // indicate whether it's a left child or not
232  int32_t parent_{kInvalidNodeId};
233  // pointer to left, right
234  int32_t cleft_{kInvalidNodeId}, cright_{kInvalidNodeId};
235  // split feature index, left split or right split depends on the highest bit
236  uint32_t sindex_{0};
237  // extra info
238  Info info_;
239  };
240 
246  void ChangeToLeaf(int rid, bst_float value) {
247  CHECK(nodes_[nodes_[rid].LeftChild() ].IsLeaf());
248  CHECK(nodes_[nodes_[rid].RightChild()].IsLeaf());
249  this->DeleteNode(nodes_[rid].LeftChild());
250  this->DeleteNode(nodes_[rid].RightChild());
251  nodes_[rid].SetLeaf(value);
252  }
258  void CollapseToLeaf(int rid, bst_float value) {
259  if (nodes_[rid].IsLeaf()) return;
260  if (!nodes_[nodes_[rid].LeftChild() ].IsLeaf()) {
261  CollapseToLeaf(nodes_[rid].LeftChild(), 0.0f);
262  }
263  if (!nodes_[nodes_[rid].RightChild() ].IsLeaf()) {
264  CollapseToLeaf(nodes_[rid].RightChild(), 0.0f);
265  }
266  this->ChangeToLeaf(rid, value);
267  }
268 
273  param.num_nodes = 1;
274  param.num_deleted = 0;
275  nodes_.resize(param.num_nodes);
276  stats_.resize(param.num_nodes);
277  for (int i = 0; i < param.num_nodes; i ++) {
278  nodes_[i].SetLeaf(0.0f);
279  nodes_[i].SetParent(kInvalidNodeId);
280  }
281  }
283  Node& operator[](int nid) {
284  return nodes_[nid];
285  }
287  const Node& operator[](int nid) const {
288  return nodes_[nid];
289  }
290 
292  const std::vector<Node>& GetNodes() const { return nodes_; }
293 
295  RTreeNodeStat& Stat(int nid) {
296  return stats_[nid];
297  }
299  const RTreeNodeStat& Stat(int nid) const {
300  return stats_[nid];
301  }
302 
307  void Load(dmlc::Stream* fi);
312  void Save(dmlc::Stream* fo) const;
313 
314  void LoadModel(Json const& in) override;
315  void SaveModel(Json* out) const override;
316 
317  bool operator==(const RegTree& b) const {
318  return nodes_ == b.nodes_ && stats_ == b.stats_ &&
319  deleted_nodes_ == b.deleted_nodes_ && param == b.param;
320  }
321 
337  void ExpandNode(int nid, unsigned split_index, bst_float split_value,
338  bool default_left, bst_float base_weight,
339  bst_float left_leaf_weight, bst_float right_leaf_weight,
340  bst_float loss_change, float sum_hess,
341  bst_node_t leaf_right_child = kInvalidNodeId) {
342  int pleft = this->AllocNode();
343  int pright = this->AllocNode();
344  auto &node = nodes_[nid];
345  CHECK(node.IsLeaf());
346  node.SetLeftChild(pleft);
347  node.SetRightChild(pright);
348  nodes_[node.LeftChild()].SetParent(nid, true);
349  nodes_[node.RightChild()].SetParent(nid, false);
350  node.SetSplit(split_index, split_value,
351  default_left);
352 
353  nodes_[pleft].SetLeaf(left_leaf_weight, leaf_right_child);
354  nodes_[pright].SetLeaf(right_leaf_weight, leaf_right_child);
355 
356  this->Stat(nid).loss_chg = loss_change;
357  this->Stat(nid).base_weight = base_weight;
358  this->Stat(nid).sum_hess = sum_hess;
359  }
360 
365  int GetDepth(int nid) const {
366  int depth = 0;
367  while (!nodes_[nid].IsRoot()) {
368  ++depth;
369  nid = nodes_[nid].Parent();
370  }
371  return depth;
372  }
377  int MaxDepth(int nid) const {
378  if (nodes_[nid].IsLeaf()) return 0;
379  return std::max(MaxDepth(nodes_[nid].LeftChild())+1,
380  MaxDepth(nodes_[nid].RightChild())+1);
381  }
382 
386  int MaxDepth() {
387  return MaxDepth(0);
388  }
389 
391  int NumExtraNodes() const {
392  return param.num_nodes - 1 - param.num_deleted;
393  }
394 
399  struct FVec {
404  void Init(size_t size);
409  void Fill(const SparsePage::Inst& inst);
414  void Drop(const SparsePage::Inst& inst);
419  size_t Size() const;
425  bst_float Fvalue(size_t i) const;
431  bool IsMissing(size_t i) const;
432 
433  private:
438  union Entry {
439  bst_float fvalue;
440  int flag;
441  };
442  std::vector<Entry> data_;
443  };
449  int GetLeafIndex(const FVec& feat) const;
457  void CalculateContributions(const RegTree::FVec& feat,
458  bst_float* out_contribs, int condition = 0,
459  unsigned condition_feature = 0) const;
474  void TreeShap(const RegTree::FVec& feat, bst_float* phi, unsigned node_index,
475  unsigned unique_depth, PathElement* parent_unique_path,
476  bst_float parent_zero_fraction, bst_float parent_one_fraction,
477  int parent_feature_index, int condition,
478  unsigned condition_feature, bst_float condition_fraction) const;
479 
485  void CalculateContributionsApprox(const RegTree::FVec& feat,
486  bst_float* out_contribs) const;
493  inline int GetNext(int pid, bst_float fvalue, bool is_unknown) const;
501  std::string DumpModel(const FeatureMap& fmap,
502  bool with_stats,
503  std::string format) const;
507  void FillNodeMeanValues();
508 
509  private:
510  // vector of nodes
511  std::vector<Node> nodes_;
512  // free node space, used during training process
513  std::vector<int> deleted_nodes_;
514  // stats of nodes
515  std::vector<RTreeNodeStat> stats_;
516  std::vector<bst_float> node_mean_values_;
517  // allocate a new node,
518  // !!!!!! NOTE: may cause BUG here, nodes.resize
519  int AllocNode() {
520  if (param.num_deleted != 0) {
521  int nid = deleted_nodes_.back();
522  deleted_nodes_.pop_back();
523  nodes_[nid].Reuse();
524  --param.num_deleted;
525  return nid;
526  }
527  int nd = param.num_nodes++;
528  CHECK_LT(param.num_nodes, std::numeric_limits<int>::max())
529  << "number of nodes in the tree exceed 2^31";
530  nodes_.resize(param.num_nodes);
531  stats_.resize(param.num_nodes);
532  return nd;
533  }
534  // delete a tree node, keep the parent field to allow trace back
535  void DeleteNode(int nid) {
536  CHECK_GE(nid, 1);
537  deleted_nodes_.push_back(nid);
538  nodes_[nid].MarkDelete();
539  ++param.num_deleted;
540  }
541  bst_float FillNodeMeanValue(int nid);
542 };
543 
544 inline void RegTree::FVec::Init(size_t size) {
545  Entry e; e.flag = -1;
546  data_.resize(size);
547  std::fill(data_.begin(), data_.end(), e);
548 }
549 
550 inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
551  for (bst_uint i = 0; i < inst.size(); ++i) {
552  if (inst[i].index >= data_.size()) continue;
553  data_[inst[i].index].fvalue = inst[i].fvalue;
554  }
555 }
556 
557 inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
558  for (bst_uint i = 0; i < inst.size(); ++i) {
559  if (inst[i].index >= data_.size()) continue;
560  data_[inst[i].index].flag = -1;
561  }
562 }
563 
564 inline size_t RegTree::FVec::Size() const {
565  return data_.size();
566 }
567 
568 inline bst_float RegTree::FVec::Fvalue(size_t i) const {
569  return data_[i].fvalue;
570 }
571 
572 inline bool RegTree::FVec::IsMissing(size_t i) const {
573  return data_[i].flag == -1;
574 }
575 
576 inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
577  bst_node_t nid = 0;
578  while (!(*this)[nid].IsLeaf()) {
579  unsigned split_index = (*this)[nid].SplitIndex();
580  nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
581  }
582  return nid;
583 }
584 
586 inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const {
587  bst_float split_value = (*this)[pid].SplitCond();
588  if (is_unknown) {
589  return (*this)[pid].DefaultChild();
590  } else {
591  if (fvalue < split_value) {
592  return (*this)[pid].LeftChild();
593  } else {
594  return (*this)[pid].RightChild();
595  }
596  }
597 }
598 } // namespace xgboost
599 #endif // XGBOOST_TREE_MODEL_H_
int NumExtraNodes() const
number of extra nodes besides the root
Definition: tree_model.h:391
int deprecated_num_roots
(Deprecated) number of start root
Definition: tree_model.h:36
int GetNext(int pid, bst_float fvalue, bool is_unknown) const
get next position of the tree given current pid
Definition: tree_model.h:586
Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond, bool default_left)
Definition: tree_model.h:113
XGBOOST_DEVICE int RightChild() const
index of right child
Definition: tree_model.h:124
float bst_float
float type, used for storing statistics
Definition: base.h:111
XGBOOST_DEVICE constexpr index_type size() const __span_noexcept
Definition: span.h:521
int deprecated_max_depth
maximum depth, this is a statistics of the tree
Definition: tree_model.h:42
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=kInvalidNodeId)
set the leaf value of the node
Definition: tree_model.h:197
XGBOOST_DEVICE bst_float LeafValue() const
Definition: tree_model.h:144
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child=true)
Definition: tree_model.h:211
The input data structure of xgboost.
void Fill(const SparsePage::Inst &inst)
fill the vector with sparse vector
Definition: tree_model.h:550
const Node & operator[](int nid) const
get node given nid
Definition: tree_model.h:287
Defines the abstract interface for different components in XGBoost.
XGBOOST_DEVICE void SetRightChild(int nid)
set the right child
Definition: tree_model.h:176
XGBOOST_DEVICE unsigned SplitIndex() const
feature index of split condition
Definition: tree_model.h:132
int leaf_child_cnt
number of children that are leaf nodes, as known up to now
Definition: tree_model.h:89
bst_float Fvalue(size_t i) const
get ith value
Definition: tree_model.h:568
int GetDepth(int nid) const
get current depth
Definition: tree_model.h:365
Feature map data structure to help text model dump. TODO(tqchen) consider make it even more lightweig...
Definition: feature_map.h:22
XGBOOST_DEVICE bool DefaultLeft() const
when feature is unknown, whether goes to left child
Definition: tree_model.h:136
bst_float base_weight
weight of current node
Definition: tree_model.h:87
define regression tree to be the most common tree model. This is the data structure used in xgboost's...
Definition: tree_model.h:100
TreeParam()
constructor
Definition: tree_model.h:53
int size_leaf_vector
leaf vector size, used for vector tree used to store more than one dimensional information in tree ...
Definition: tree_model.h:49
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:115
bool operator==(const RTreeNodeStat &b) const
Definition: tree_model.h:90
node statistics used in regression tree
Definition: tree_model.h:81
bst_float sum_hess
sum of hessian values, used to measure coverage of data
Definition: tree_model.h:85
void ExpandNode(int nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, bst_node_t leaf_right_child=kInvalidNodeId)
Expands a leaf node into two additional leaf nodes.
Definition: tree_model.h:337
void ChangeToLeaf(int rid, bst_float value)
change a non leaf node to a leaf node, delete its children
Definition: tree_model.h:246
Definition: model.h:17
meta parameters of the tree
Definition: tree_model.h:34
bool operator==(const RegTree &b) const
Definition: tree_model.h:317
int num_deleted
number of deleted nodes
Definition: tree_model.h:40
int MaxDepth(int nid) const
get maximum depth
Definition: tree_model.h:377
int num_nodes
total number of nodes
Definition: tree_model.h:38
XGBOOST_DEVICE void SetLeftChild(int nid)
set the left child
Definition: tree_model.h:169
int32_t bst_node_t
Type for tree node index.
Definition: base.h:123
TreeParam param
model parameter
Definition: tree_model.h:270
bool operator==(const Node &b) const
Definition: tree_model.h:215
int GetLeafIndex(const FVec &feat) const
get the leaf index
Definition: tree_model.h:576
bool operator==(const TreeParam &b) const
Definition: tree_model.h:72
const RTreeNodeStat & Stat(int nid) const
get node statistics given nid
Definition: tree_model.h:299
void Init(size_t size)
initialize the vector with the given size
Definition: tree_model.h:544
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:84
Feature map data structure to help visualization and model dump.
XGBOOST_DEVICE bool IsRoot() const
whether current node is root
Definition: tree_model.h:164
const std::vector< Node > & GetNodes() const
get const reference to nodes
Definition: tree_model.h:292
namespace of xgboost
Definition: base.h:102
tree node
Definition: tree_model.h:106
int MaxDepth()
get maximum depth
Definition: tree_model.h:386
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
set split condition of current node
Definition: tree_model.h:185
XGBOOST_DEVICE int DefaultChild() const
index of default child when feature is missing
Definition: tree_model.h:128
defines configuration macros of xgboost.
XGBOOST_DEVICE void MarkDelete()
mark that this node is deleted
Definition: tree_model.h:203
XGBOOST_DEVICE bool IsLeftChild() const
whether current node is left child
Definition: tree_model.h:156
Node()
Definition: tree_model.h:108
size_t Size() const
returns the size of the feature vector
Definition: tree_model.h:564
XGBOOST_DEVICE bool IsLeaf() const
whether current node is leaf node
Definition: tree_model.h:140
RegTree()
constructor
Definition: tree_model.h:272
XGBOOST_DEVICE bool IsDeleted() const
whether this node is deleted
Definition: tree_model.h:160
uint32_t bst_uint
unsigned integer type used for feature index.
Definition: base.h:105
int num_feature
number of features used for tree construction
Definition: tree_model.h:44
void Drop(const SparsePage::Inst &inst)
drop the trace after fill, must be called after fill.
Definition: tree_model.h:557
Data structure representing JSON format.
Definition: json.h:325
int reserved[31]
reserved part, make sure alignment works for 64bit
Definition: tree_model.h:51
void CollapseToLeaf(int rid, bst_float value)
collapse a non leaf node to a leaf node, delete its children
Definition: tree_model.h:258
Node & operator[](int nid)
get node given nid
Definition: tree_model.h:283
bst_float SplitCondT
auxiliary statistics of node to help tree building
Definition: tree_model.h:103
XGBOOST_DEVICE int LeftChild() const
index of left child
Definition: tree_model.h:120
RTreeNodeStat & Stat(int nid)
get node statistics given nid
Definition: tree_model.h:295
XGBOOST_DEVICE void Reuse()
Reuse this deleted node.
Definition: tree_model.h:207
bst_float loss_chg
loss change caused by current split
Definition: tree_model.h:83
XGBOOST_DEVICE int Parent() const
get parent of the node
Definition: tree_model.h:152
XGBOOST_DEVICE SplitCondT SplitCond() const
Definition: tree_model.h:148
bool IsMissing(size_t i) const
check whether i-th entry is missing
Definition: tree_model.h:572
dense feature vector that can be taken by RegTree and can be constructed from a sparse feature vector...
Definition: tree_model.h:399
DMLC_DECLARE_PARAMETER(TreeParam)
Definition: tree_model.h:62