xgboost
tree_model.h
Go to the documentation of this file.
1 
7 #ifndef XGBOOST_TREE_MODEL_H_
8 #define XGBOOST_TREE_MODEL_H_
9 
10 #include <dmlc/io.h>
11 #include <dmlc/parameter.h>
12 #include <limits>
13 #include <vector>
14 #include <string>
15 #include <cstring>
16 #include <algorithm>
17 #include <tuple>
18 #include "./base.h"
19 #include "./data.h"
20 #include "./logging.h"
21 #include "./feature_map.h"
22 
23 namespace xgboost {
24 
25 struct PathElement; // forward declaration
26 
28 struct TreeParam : public dmlc::Parameter<TreeParam> {
30  int num_roots;
32  int num_nodes;
36  int max_depth;
45  int reserved[31];
48  // assert compact alignment
49  static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int),
50  "TreeParam: 64 bit align");
51  std::memset(this, 0, sizeof(TreeParam));
52  num_nodes = num_roots = 1;
53  }
54  // declare the parameters
56  // only declare the parameters that can be set by the user.
57  // other arguments are set by the algorithm.
58  DMLC_DECLARE_FIELD(num_roots).set_lower_bound(1).set_default(1)
59  .describe("Number of start root of trees.");
60  DMLC_DECLARE_FIELD(num_feature)
61  .describe("Number of features used in tree construction.");
62  DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0)
63  .describe("Size of leaf vector, reserved for vector tree");
64  }
65 
66  bool operator==(const TreeParam& b) const {
67  return num_roots == b.num_roots && num_nodes == b.num_nodes &&
68  num_deleted == b.num_deleted && max_depth == b.max_depth &&
69  num_feature == b.num_feature &&
70  size_leaf_vector == b.size_leaf_vector;
71  }
72 };
73 
75 struct RTreeNodeStat {
84  bool operator==(const RTreeNodeStat& b) const {
85  return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
86  base_weight == b.base_weight && leaf_child_cnt == b.leaf_child_cnt;
87  }
88 };
89 
94 class RegTree {
95  public:
99  class Node {
100  public:
101  Node() {
102  // assert compact alignment
103  static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
104  "Node: 64 bit align");
105  }
107  XGBOOST_DEVICE int LeftChild() const {
108  return this->cleft_;
109  }
112  return this->cright_;
113  }
116  return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
117  }
119  XGBOOST_DEVICE unsigned SplitIndex() const {
120  return sindex_ & ((1U << 31) - 1U);
121  }
124  return (sindex_ >> 31) != 0;
125  }
127  XGBOOST_DEVICE bool IsLeaf() const {
128  return cleft_ == -1;
129  }
132  return (this->info_).leaf_value;
133  }
136  return (this->info_).split_cond;
137  }
139  XGBOOST_DEVICE int Parent() const {
140  return parent_ & ((1U << 31) - 1);
141  }
144  return (parent_ & (1U << 31)) != 0;
145  }
147  XGBOOST_DEVICE bool IsDeleted() const {
148  return sindex_ == std::numeric_limits<unsigned>::max();
149  }
151  XGBOOST_DEVICE bool IsRoot() const { return parent_ == -1; }
156  XGBOOST_DEVICE void SetLeftChild(int nid) {
157  this->cleft_ = nid;
158  }
164  this->cright_ = nid;
165  }
172  XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond,
173  bool default_left = false) {
174  if (default_left) split_index |= (1U << 31);
175  this->sindex_ = split_index;
176  (this->info_).split_cond = split_cond;
177  }
184  XGBOOST_DEVICE void SetLeaf(bst_float value, int right = -1) {
185  (this->info_).leaf_value = value;
186  this->cleft_ = -1;
187  this->cright_ = right;
188  }
191  this->sindex_ = std::numeric_limits<unsigned>::max();
192  }
195  this->sindex_ = 0;
196  }
197  // set parent
198  XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child = true) {
199  if (is_left_child) pidx |= (1U << 31);
200  this->parent_ = pidx;
201  }
202  bool operator==(const Node& b) const {
203  return parent_ == b.parent_ && cleft_ == b.cleft_ &&
204  cright_ == b.cright_ && sindex_ == b.sindex_ &&
205  info_.leaf_value == b.info_.leaf_value;
206  }
207 
208  private:
213  union Info{
214  bst_float leaf_value;
215  SplitCondT split_cond;
216  };
217  // pointer to parent, highest bit is used to
218  // indicate whether it's a left child or not
219  int parent_;
220  // pointer to left, right
221  int cleft_, cright_;
222  // split feature index, left split or right split depends on the highest bit
223  unsigned sindex_{0};
224  // extra info
225  Info info_;
226  };
227 
233  void ChangeToLeaf(int rid, bst_float value) {
234  CHECK(nodes_[nodes_[rid].LeftChild() ].IsLeaf());
235  CHECK(nodes_[nodes_[rid].RightChild()].IsLeaf());
236  this->DeleteNode(nodes_[rid].LeftChild());
237  this->DeleteNode(nodes_[rid].RightChild());
238  nodes_[rid].SetLeaf(value);
239  }
245  void CollapseToLeaf(int rid, bst_float value) {
246  if (nodes_[rid].IsLeaf()) return;
247  if (!nodes_[nodes_[rid].LeftChild() ].IsLeaf()) {
248  CollapseToLeaf(nodes_[rid].LeftChild(), 0.0f);
249  }
250  if (!nodes_[nodes_[rid].RightChild() ].IsLeaf()) {
251  CollapseToLeaf(nodes_[rid].RightChild(), 0.0f);
252  }
253  this->ChangeToLeaf(rid, value);
254  }
255 
260  param.num_nodes = 1;
261  param.num_roots = 1;
262  param.num_deleted = 0;
263  nodes_.resize(param.num_nodes);
264  stats_.resize(param.num_nodes);
265  for (int i = 0; i < param.num_nodes; i ++) {
266  nodes_[i].SetLeaf(0.0f);
267  nodes_[i].SetParent(-1);
268  }
269  }
/*! \brief get node given nid; the index is not range-checked here */
271  Node& operator[](int nid) {
272  return nodes_[nid];
273  }
/*! \brief get node given nid (const access); the index is not range-checked here */
275  const Node& operator[](int nid) const {
276  return nodes_[nid];
277  }
278 
/*! \brief get const reference to nodes */
280  const std::vector<Node>& GetNodes() const { return nodes_; }
281 
/*! \brief get node statistics given nid; the index is not range-checked here */
283  RTreeNodeStat& Stat(int nid) {
284  return stats_[nid];
285  }
/*! \brief get node statistics given nid (const access); the index is not range-checked here */
287  const RTreeNodeStat& Stat(int nid) const {
288  return stats_[nid];
289  }
294  void Load(dmlc::Stream* fi) {
295  CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam));
296  nodes_.resize(param.num_nodes);
297  stats_.resize(param.num_nodes);
298  CHECK_NE(param.num_nodes, 0);
299  CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size()),
300  sizeof(Node) * nodes_.size());
301  CHECK_EQ(fi->Read(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * stats_.size()),
302  sizeof(RTreeNodeStat) * stats_.size());
303  // chg deleted nodes
304  deleted_nodes_.resize(0);
305  for (int i = param.num_roots; i < param.num_nodes; ++i) {
306  if (nodes_[i].IsDeleted()) deleted_nodes_.push_back(i);
307  }
308  CHECK_EQ(static_cast<int>(deleted_nodes_.size()), param.num_deleted);
309  }
314  void Save(dmlc::Stream* fo) const {
315  CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size()));
316  CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
317  fo->Write(&param, sizeof(TreeParam));
318  CHECK_NE(param.num_nodes, 0);
319  fo->Write(dmlc::BeginPtr(nodes_), sizeof(Node) * nodes_.size());
320  fo->Write(dmlc::BeginPtr(stats_), sizeof(RTreeNodeStat) * nodes_.size());
321  }
322 
323  bool operator==(const RegTree& b) const {
324  return nodes_ == b.nodes_ && stats_ == b.stats_ &&
325  deleted_nodes_ == b.deleted_nodes_ && param == b.param;
326  }
327 
341  void ExpandNode(int nid, unsigned split_index, bst_float split_value,
342  bool default_left, bst_float base_weight,
343  bst_float left_leaf_weight, bst_float right_leaf_weight,
344  bst_float loss_change, float sum_hess) {
345  int pleft = this->AllocNode();
346  int pright = this->AllocNode();
347  auto &node = nodes_[nid];
348  CHECK(node.IsLeaf());
349  node.SetLeftChild(pleft);
350  node.SetRightChild(pright);
351  nodes_[node.LeftChild()].SetParent(nid, true);
352  nodes_[node.RightChild()].SetParent(nid, false);
353  node.SetSplit(split_index, split_value,
354  default_left);
355  // mark right child as 0, to indicate fresh leaf
356  nodes_[pleft].SetLeaf(left_leaf_weight, 0);
357  nodes_[pright].SetLeaf(right_leaf_weight, 0);
358 
359  this->Stat(nid).loss_chg = loss_change;
360  this->Stat(nid).base_weight = base_weight;
361  this->Stat(nid).sum_hess = sum_hess;
362  }
363 
368  int GetDepth(int nid) const {
369  int depth = 0;
370  while (!nodes_[nid].IsRoot()) {
371  ++depth;
372  nid = nodes_[nid].Parent();
373  }
374  return depth;
375  }
380  int MaxDepth(int nid) const {
381  if (nodes_[nid].IsLeaf()) return 0;
382  return std::max(MaxDepth(nodes_[nid].LeftChild())+1,
383  MaxDepth(nodes_[nid].RightChild())+1);
384  }
385 
389  int MaxDepth() {
390  int maxd = 0;
391  for (int i = 0; i < param.num_roots; ++i) {
392  maxd = std::max(maxd, MaxDepth(i));
393  }
394  return maxd;
395  }
396 
398  int NumExtraNodes() const {
399  return param.num_nodes - param.num_roots - param.num_deleted;
400  }
401 
/*!
 * \brief dense feature vector that can be taken by RegTree
 *        and can be constructed from a sparse feature vector
 */
406  struct FVec {
/*!
 * \brief initialize the vector to the given size; every entry starts as missing
 * \param size number of entries
 */
411  void Init(size_t size);
/*!
 * \brief fill the vector with sparse vector; entries whose index is not
 *        below Size() are skipped
 * \param inst the sparse instance
 */
416  void Fill(const SparsePage::Inst& inst);
/*!
 * \brief drop the trace after fill, must be called after fill.
 *        Marks the entries named by inst as missing again.
 * \param inst the sparse instance
 */
421  void Drop(const SparsePage::Inst& inst);
/*! \brief returns the size of the feature vector */
426  size_t Size() const;
/*!
 * \brief get ith value
 * \param i feature index
 */
432  bst_float Fvalue(size_t i) const;
/*!
 * \brief check whether i-th entry is missing
 * \param i feature index
 */
438  bool IsMissing(size_t i) const;
439 
440  private:
// one slot per feature: fvalue and flag share storage, and flag == -1 marks a missing entry
445  union Entry {
446  bst_float fvalue;
447  int flag;
448  };
// dense storage, indexed by feature index
449  std::vector<Entry> data_;
450  };
457  int GetLeafIndex(const FVec& feat, unsigned root_id = 0) const;
466  void CalculateContributions(const RegTree::FVec& feat, unsigned root_id,
467  bst_float* out_contribs, int condition = 0,
468  unsigned condition_feature = 0) const;
483  void TreeShap(const RegTree::FVec& feat, bst_float* phi, unsigned node_index,
484  unsigned unique_depth, PathElement* parent_unique_path,
485  bst_float parent_zero_fraction, bst_float parent_one_fraction,
486  int parent_feature_index, int condition,
487  unsigned condition_feature, bst_float condition_fraction) const;
488 
495  void CalculateContributionsApprox(const RegTree::FVec& feat, unsigned root_id,
496  bst_float* out_contribs) const;
503  inline int GetNext(int pid, bst_float fvalue, bool is_unknown) const;
511  std::string DumpModel(const FeatureMap& fmap,
512  bool with_stats,
513  std::string format) const;
517  void FillNodeMeanValues();
518 
519  private:
520  // vector of nodes
521  std::vector<Node> nodes_;
522  // free node space, used during training process
523  std::vector<int> deleted_nodes_;
524  // stats of nodes
525  std::vector<RTreeNodeStat> stats_;
526  std::vector<bst_float> node_mean_values_;
527  // allocate a new node,
528  // !!!!!! NOTE: may cause BUG here, nodes.resize
529  int AllocNode() {
530  if (param.num_deleted != 0) {
531  int nid = deleted_nodes_.back();
532  deleted_nodes_.pop_back();
533  nodes_[nid].Reuse();
534  --param.num_deleted;
535  return nid;
536  }
537  int nd = param.num_nodes++;
538  CHECK_LT(param.num_nodes, std::numeric_limits<int>::max())
539  << "number of nodes in the tree exceed 2^31";
540  nodes_.resize(param.num_nodes);
541  stats_.resize(param.num_nodes);
542  return nd;
543  }
544  // delete a tree node, keep the parent field to allow trace back
545  void DeleteNode(int nid) {
546  CHECK_GE(nid, param.num_roots);
547  deleted_nodes_.push_back(nid);
548  nodes_[nid].MarkDelete();
549  ++param.num_deleted;
550  }
551  bst_float FillNodeMeanValue(int nid);
552 };
553 
554 inline void RegTree::FVec::Init(size_t size) {
555  Entry e; e.flag = -1;
556  data_.resize(size);
557  std::fill(data_.begin(), data_.end(), e);
558 }
559 
560 inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
561  for (bst_uint i = 0; i < inst.size(); ++i) {
562  if (inst[i].index >= data_.size()) continue;
563  data_[inst[i].index].fvalue = inst[i].fvalue;
564  }
565 }
566 
567 inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
568  for (bst_uint i = 0; i < inst.size(); ++i) {
569  if (inst[i].index >= data_.size()) continue;
570  data_[inst[i].index].flag = -1;
571  }
572 }
573 
/*! \brief returns the size of the feature vector */
574 inline size_t RegTree::FVec::Size() const {
575  return data_.size();
576 }
577 
/*! \brief get ith value; the index is not range-checked and missing entries are not detected here */
578 inline bst_float RegTree::FVec::Fvalue(size_t i) const {
579  return data_[i].fvalue;
580 }
581 
/*! \brief check whether i-th entry is missing, i.e. its flag equals -1 */
582 inline bool RegTree::FVec::IsMissing(size_t i) const {
583  return data_[i].flag == -1;
584 }
585 
586 inline int RegTree::GetLeafIndex(const RegTree::FVec& feat,
587  unsigned root_id) const {
588  auto pid = static_cast<int>(root_id);
589  while (!(*this)[pid].IsLeaf()) {
590  unsigned split_index = (*this)[pid].SplitIndex();
591  pid = this->GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index));
592  }
593  return pid;
594 }
595 
597 inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const {
598  bst_float split_value = (*this)[pid].SplitCond();
599  if (is_unknown) {
600  return (*this)[pid].DefaultChild();
601  } else {
602  if (fvalue < split_value) {
603  return (*this)[pid].LeftChild();
604  } else {
605  return (*this)[pid].RightChild();
606  }
607  }
608 }
609 } // namespace xgboost
610 #endif // XGBOOST_TREE_MODEL_H_
int NumExtraNodes() const
number of extra nodes besides the root
Definition: tree_model.h:398
int GetNext(int pid, bst_float fvalue, bool is_unknown) const
get next position of the tree given current pid
Definition: tree_model.h:597
XGBOOST_DEVICE int RightChild() const
index of right child
Definition: tree_model.h:111
float bst_float
float type, used for storing statistics
Definition: base.h:89
XGBOOST_DEVICE constexpr index_type size() const __span_noexcept
Definition: span.h:502
XGBOOST_DEVICE bst_float LeafValue() const
Definition: tree_model.h:131
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child=true)
Definition: tree_model.h:198
The input data structure of xgboost.
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=-1)
set the leaf value of the node
Definition: tree_model.h:184
void Fill(const SparsePage::Inst &inst)
fill the vector with sparse vector
Definition: tree_model.h:560
const Node & operator[](int nid) const
get node given nid
Definition: tree_model.h:275
XGBOOST_DEVICE void SetRightChild(int nid)
set the right child
Definition: tree_model.h:163
XGBOOST_DEVICE unsigned SplitIndex() const
feature index of split condition
Definition: tree_model.h:119
int GetLeafIndex(const FVec &feat, unsigned root_id=0) const
get the leaf index
Definition: tree_model.h:586
int leaf_child_cnt
number of children that are leaf nodes, known up to now
Definition: tree_model.h:83
bst_float Fvalue(size_t i) const
get ith value
Definition: tree_model.h:578
int GetDepth(int nid) const
get current depth
Definition: tree_model.h:368
Feature map data structure to help text model dump. TODO(tqchen) consider make it even more lightweig...
Definition: feature_map.h:20
XGBOOST_DEVICE bool DefaultLeft() const
when feature is unknown, whether goes to left child
Definition: tree_model.h:123
bst_float base_weight
weight of current node
Definition: tree_model.h:81
define regression tree to be the most common tree model. This is the data structure used in xgboost's...
Definition: tree_model.h:94
TreeParam()
constructor
Definition: tree_model.h:47
int size_leaf_vector
leaf vector size, used for vector tree used to store more than one dimensional information in tree ...
Definition: tree_model.h:43
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:109
bool operator==(const RTreeNodeStat &b) const
Definition: tree_model.h:84
node statistics used in regression tree
Definition: tree_model.h:75
bst_float sum_hess
sum of hessian values, used to measure coverage of data
Definition: tree_model.h:79
void ChangeToLeaf(int rid, bst_float value)
change a non leaf node to a leaf node, delete its children
Definition: tree_model.h:233
meta parameters of the tree
Definition: tree_model.h:28
bool operator==(const RegTree &b) const
Definition: tree_model.h:323
int num_deleted
number of deleted nodes
Definition: tree_model.h:34
int MaxDepth(int nid) const
get maximum depth
Definition: tree_model.h:380
int num_nodes
total number of nodes
Definition: tree_model.h:32
XGBOOST_DEVICE void SetLeftChild(int nid)
set the left child
Definition: tree_model.h:156
TreeParam param
model parameter
Definition: tree_model.h:257
bool operator==(const Node &b) const
Definition: tree_model.h:202
bool operator==(const TreeParam &b) const
Definition: tree_model.h:66
const RTreeNodeStat & Stat(int nid) const
get node statistics given nid
Definition: tree_model.h:287
void Init(size_t size)
initialize the vector with size vector
Definition: tree_model.h:554
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:75
Feature map data structure to help visualization and model dump.
XGBOOST_DEVICE bool IsRoot() const
whether current node is root
Definition: tree_model.h:151
const std::vector< Node > & GetNodes() const
get const reference to nodes
Definition: tree_model.h:280
namespace of xgboost
Definition: base.h:79
void ExpandNode(int nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess)
Expands a leaf node into two additional leaf nodes.
Definition: tree_model.h:341
tree node
Definition: tree_model.h:99
int MaxDepth()
get maximum depth
Definition: tree_model.h:389
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
set split condition of current node
Definition: tree_model.h:172
XGBOOST_DEVICE int DefaultChild() const
index of default child when feature is missing
Definition: tree_model.h:115
defines configuration macros of xgboost.
XGBOOST_DEVICE void MarkDelete()
mark that this node is deleted
Definition: tree_model.h:190
XGBOOST_DEVICE bool IsLeftChild() const
whether current node is left child
Definition: tree_model.h:143
Node()
Definition: tree_model.h:101
size_t Size() const
returns the size of the feature vector
Definition: tree_model.h:574
void Save(dmlc::Stream *fo) const
save model to stream
Definition: tree_model.h:314
int num_roots
number of start root
Definition: tree_model.h:30
XGBOOST_DEVICE bool IsLeaf() const
whether current node is leaf node
Definition: tree_model.h:127
RegTree()
constructor
Definition: tree_model.h:259
XGBOOST_DEVICE bool IsDeleted() const
whether this node is deleted
Definition: tree_model.h:147
uint32_t bst_uint
unsigned integer type used in boost, used for feature index and row index.
Definition: base.h:84
void Load(dmlc::Stream *fi)
load model from stream
Definition: tree_model.h:294
int num_feature
number of features used for tree construction
Definition: tree_model.h:38
void Drop(const SparsePage::Inst &inst)
drop the trace after fill, must be called after fill.
Definition: tree_model.h:567
int reserved[31]
reserved part, make sure alignment works for 64bit
Definition: tree_model.h:45
void CollapseToLeaf(int rid, bst_float value)
collapse a non leaf node to a leaf node, delete its children
Definition: tree_model.h:245
Node & operator[](int nid)
get node given nid
Definition: tree_model.h:271
bst_float SplitCondT
auxiliary statistics of node to help tree building
Definition: tree_model.h:97
XGBOOST_DEVICE int LeftChild() const
index of left child
Definition: tree_model.h:107
RTreeNodeStat & Stat(int nid)
get node statistics given nid
Definition: tree_model.h:283
XGBOOST_DEVICE void Reuse()
Reuse this deleted node.
Definition: tree_model.h:194
bst_float loss_chg
loss change caused by current split
Definition: tree_model.h:77
XGBOOST_DEVICE int Parent() const
get parent of the node
Definition: tree_model.h:139
XGBOOST_DEVICE SplitCondT SplitCond() const
Definition: tree_model.h:135
bool IsMissing(size_t i) const
check whether i-th entry is missing
Definition: tree_model.h:582
dense feature vector that can be taken by RegTree and can be construct from sparse feature vector...
Definition: tree_model.h:406
int max_depth
maximum depth, this is a statistics of the tree
Definition: tree_model.h:36
DMLC_DECLARE_PARAMETER(TreeParam)
Definition: tree_model.h:55