7 #ifndef XGBOOST_TREE_MODEL_H_ 8 #define XGBOOST_TREE_MODEL_H_ 11 #include <dmlc/parameter.h> 20 #include "./logging.h" 28 struct TreeParam :
public dmlc::Parameter<TreeParam> {
49 static_assert(
sizeof(
TreeParam) == (31 + 6) *
sizeof(
int),
50 "TreeParam: 64 bit align");
52 num_nodes = num_roots = 1;
58 DMLC_DECLARE_FIELD(num_roots).set_lower_bound(1).set_default(1)
59 .describe(
"Number of start root of trees.");
60 DMLC_DECLARE_FIELD(num_feature)
61 .describe(
"Number of features used in tree construction.");
62 DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0)
63 .describe(
"Size of leaf vector, reserved for vector tree");
103 static_assert(
sizeof(
Node) == 4 *
sizeof(
int) +
sizeof(Info),
104 "Node: 64 bit align");
112 return this->cright_;
116 return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
120 return sindex_ & ((1U << 31) - 1U);
124 return (sindex_ >> 31) != 0;
132 return (this->info_).leaf_value;
136 return (this->info_).split_cond;
140 return parent_ & ((1U << 31) - 1);
144 return (parent_ & (1U << 31)) != 0;
148 return sindex_ == std::numeric_limits<unsigned>::max();
173 bool default_left =
false) {
174 if (default_left) split_index |= (1U << 31);
175 this->sindex_ = split_index;
176 (this->info_).split_cond = split_cond;
185 (this->info_).leaf_value = value;
187 this->cright_ = right;
191 this->sindex_ = std::numeric_limits<unsigned>::max();
199 if (is_left_child) pidx |= (1U << 31);
200 this->parent_ = pidx;
203 return parent_ == b.parent_ && cleft_ == b.cleft_ &&
204 cright_ == b.cright_ && sindex_ == b.sindex_ &&
205 info_.leaf_value == b.info_.leaf_value;
234 CHECK(nodes_[nodes_[rid].LeftChild() ].IsLeaf());
235 CHECK(nodes_[nodes_[rid].RightChild()].IsLeaf());
236 this->DeleteNode(nodes_[rid].LeftChild());
237 this->DeleteNode(nodes_[rid].RightChild());
238 nodes_[rid].SetLeaf(value);
246 if (nodes_[rid].IsLeaf())
return;
247 if (!nodes_[nodes_[rid].LeftChild() ].IsLeaf()) {
248 CollapseToLeaf(nodes_[rid].LeftChild(), 0.0f);
250 if (!nodes_[nodes_[rid].RightChild() ].IsLeaf()) {
251 CollapseToLeaf(nodes_[rid].RightChild(), 0.0f);
253 this->ChangeToLeaf(rid, value);
265 for (
int i = 0; i < param.
num_nodes; i ++) {
266 nodes_[i].SetLeaf(0.0f);
267 nodes_[i].SetParent(-1);
280 const std::vector<Node>&
GetNodes()
const {
return nodes_; }
299 CHECK_EQ(fi->Read(dmlc::BeginPtr(nodes_),
sizeof(
Node) * nodes_.size()),
300 sizeof(
Node) * nodes_.size());
301 CHECK_EQ(fi->Read(dmlc::BeginPtr(stats_),
sizeof(
RTreeNodeStat) * stats_.size()),
304 deleted_nodes_.resize(0);
306 if (nodes_[i].IsDeleted()) deleted_nodes_.push_back(i);
308 CHECK_EQ(static_cast<int>(deleted_nodes_.size()), param.
num_deleted);
314 void Save(dmlc::Stream* fo)
const {
315 CHECK_EQ(param.
num_nodes, static_cast<int>(nodes_.size()));
316 CHECK_EQ(param.
num_nodes, static_cast<int>(stats_.size()));
319 fo->Write(dmlc::BeginPtr(nodes_),
sizeof(
Node) * nodes_.size());
320 fo->Write(dmlc::BeginPtr(stats_),
sizeof(
RTreeNodeStat) * nodes_.size());
324 return nodes_ == b.nodes_ && stats_ == b.stats_ &&
325 deleted_nodes_ == b.deleted_nodes_ && param == b.
param;
342 bool default_left,
bst_float base_weight,
345 int pleft = this->AllocNode();
346 int pright = this->AllocNode();
347 auto &node = nodes_[nid];
348 CHECK(node.IsLeaf());
349 node.SetLeftChild(pleft);
350 node.SetRightChild(pright);
351 nodes_[node.LeftChild()].SetParent(nid,
true);
352 nodes_[node.RightChild()].SetParent(nid,
false);
353 node.SetSplit(split_index, split_value,
356 nodes_[pleft].SetLeaf(left_leaf_weight, 0);
357 nodes_[pright].SetLeaf(right_leaf_weight, 0);
359 this->Stat(nid).loss_chg = loss_change;
360 this->Stat(nid).base_weight = base_weight;
361 this->Stat(nid).sum_hess = sum_hess;
370 while (!nodes_[nid].IsRoot()) {
372 nid = nodes_[nid].Parent();
381 if (nodes_[nid].IsLeaf())
return 0;
382 return std::max(MaxDepth(nodes_[nid].LeftChild())+1,
383 MaxDepth(nodes_[nid].RightChild())+1);
391 for (
int i = 0; i < param.
num_roots; ++i) {
392 maxd = std::max(maxd, MaxDepth(i));
411 void Init(
size_t size);
438 bool IsMissing(
size_t i)
const;
449 std::vector<Entry> data_;
457 int GetLeafIndex(
const FVec& feat,
unsigned root_id = 0)
const;
466 void CalculateContributions(
const RegTree::FVec& feat,
unsigned root_id,
467 bst_float* out_contribs,
int condition = 0,
468 unsigned condition_feature = 0)
const;
484 unsigned unique_depth, PathElement* parent_unique_path,
486 int parent_feature_index,
int condition,
487 unsigned condition_feature,
bst_float condition_fraction)
const;
495 void CalculateContributionsApprox(
const RegTree::FVec& feat,
unsigned root_id,
503 inline int GetNext(
int pid,
bst_float fvalue,
bool is_unknown)
const;
513 std::string format)
const;
517 void FillNodeMeanValues();
521 std::vector<Node> nodes_;
523 std::vector<int> deleted_nodes_;
525 std::vector<RTreeNodeStat> stats_;
526 std::vector<bst_float> node_mean_values_;
531 int nid = deleted_nodes_.back();
532 deleted_nodes_.pop_back();
538 CHECK_LT(param.
num_nodes, std::numeric_limits<int>::max())
539 <<
"number of nodes in the tree exceed 2^31";
545 void DeleteNode(
int nid) {
547 deleted_nodes_.push_back(nid);
548 nodes_[nid].MarkDelete();
555 Entry e; e.flag = -1;
557 std::fill(data_.begin(), data_.end(), e);
562 if (inst[i].index >= data_.
size())
continue;
563 data_[inst[i].index].fvalue = inst[i].fvalue;
569 if (inst[i].index >= data_.
size())
continue;
570 data_[inst[i].index].flag = -1;
579 return data_[i].fvalue;
583 return data_[i].flag == -1;
587 unsigned root_id)
const {
588 auto pid =
static_cast<int>(root_id);
589 while (!(*
this)[pid].IsLeaf()) {
590 unsigned split_index = (*this)[pid].SplitIndex();
591 pid = this->GetNext(pid, feat.
Fvalue(split_index), feat.
IsMissing(split_index));
598 bst_float split_value = (*this)[pid].SplitCond();
600 return (*
this)[pid].DefaultChild();
602 if (fvalue < split_value) {
603 return (*
this)[pid].LeftChild();
605 return (*
this)[pid].RightChild();
610 #endif // XGBOOST_TREE_MODEL_H_
int NumExtraNodes() const
number of extra nodes besides the root
Definition: tree_model.h:398
int GetNext(int pid, bst_float fvalue, bool is_unknown) const
get next position of the tree given current pid
Definition: tree_model.h:597
XGBOOST_DEVICE int RightChild() const
index of right child
Definition: tree_model.h:111
float bst_float
float type, used for storing statistics
Definition: base.h:89
XGBOOST_DEVICE constexpr index_type size() const __span_noexcept
Definition: span.h:502
XGBOOST_DEVICE bst_float LeafValue() const
Definition: tree_model.h:131
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child=true)
Definition: tree_model.h:198
The input data structure of xgboost.
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=-1)
set the leaf value of the node
Definition: tree_model.h:184
void Fill(const SparsePage::Inst &inst)
fill the vector with sparse vector
Definition: tree_model.h:560
const Node & operator[](int nid) const
get node given nid
Definition: tree_model.h:275
XGBOOST_DEVICE void SetRightChild(int nid)
set the right child
Definition: tree_model.h:163
XGBOOST_DEVICE unsigned SplitIndex() const
feature index of split condition
Definition: tree_model.h:119
int GetLeafIndex(const FVec &feat, unsigned root_id=0) const
get the leaf index
Definition: tree_model.h:586
int leaf_child_cnt
number of children that are leaf nodes, known up to now
Definition: tree_model.h:83
bst_float Fvalue(size_t i) const
get ith value
Definition: tree_model.h:578
int GetDepth(int nid) const
get current depth
Definition: tree_model.h:368
Feature map data structure to help text model dump. TODO(tqchen) consider making it even more lightweight...
Definition: feature_map.h:20
XGBOOST_DEVICE bool DefaultLeft() const
whether to go to the left child when the feature is unknown
Definition: tree_model.h:123
bst_float base_weight
weight of current node
Definition: tree_model.h:81
define regression tree to be the most common tree model. This is the data structure used in xgboost's...
Definition: tree_model.h:94
TreeParam()
constructor
Definition: tree_model.h:47
int size_leaf_vector
leaf vector size, used for vector tree; used to store more than one-dimensional information in tree ...
Definition: tree_model.h:43
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:109
bool operator==(const RTreeNodeStat &b) const
Definition: tree_model.h:84
node statistics used in regression tree
Definition: tree_model.h:75
bst_float sum_hess
sum of hessian values, used to measure coverage of data
Definition: tree_model.h:79
void ChangeToLeaf(int rid, bst_float value)
change a non leaf node to a leaf node, delete its children
Definition: tree_model.h:233
meta parameters of the tree
Definition: tree_model.h:28
bool operator==(const RegTree &b) const
Definition: tree_model.h:323
int num_deleted
number of deleted nodes
Definition: tree_model.h:34
int MaxDepth(int nid) const
get maximum depth
Definition: tree_model.h:380
int num_nodes
total number of nodes
Definition: tree_model.h:32
XGBOOST_DEVICE void SetLeftChild(int nid)
set the left child
Definition: tree_model.h:156
TreeParam param
model parameter
Definition: tree_model.h:257
bool operator==(const Node &b) const
Definition: tree_model.h:202
bool operator==(const TreeParam &b) const
Definition: tree_model.h:66
const RTreeNodeStat & Stat(int nid) const
get node statistics given nid
Definition: tree_model.h:287
void Init(size_t size)
initialize the vector with the given size
Definition: tree_model.h:554
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:75
Feature map data structure to help visualization and model dump.
XGBOOST_DEVICE bool IsRoot() const
whether current node is root
Definition: tree_model.h:151
const std::vector< Node > & GetNodes() const
get const reference to nodes
Definition: tree_model.h:280
namespace of xgboost
Definition: base.h:79
void ExpandNode(int nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess)
Expands a leaf node into two additional leaf nodes.
Definition: tree_model.h:341
tree node
Definition: tree_model.h:99
int MaxDepth()
get maximum depth
Definition: tree_model.h:389
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
set split condition of current node
Definition: tree_model.h:172
XGBOOST_DEVICE int DefaultChild() const
index of default child when feature is missing
Definition: tree_model.h:115
defines configuration macros of xgboost.
XGBOOST_DEVICE void MarkDelete()
mark that this node is deleted
Definition: tree_model.h:190
XGBOOST_DEVICE bool IsLeftChild() const
whether current node is left child
Definition: tree_model.h:143
Node()
Definition: tree_model.h:101
size_t Size() const
returns the size of the feature vector
Definition: tree_model.h:574
void Save(dmlc::Stream *fo) const
save model to stream
Definition: tree_model.h:314
int num_roots
number of start root
Definition: tree_model.h:30
XGBOOST_DEVICE bool IsLeaf() const
whether current node is leaf node
Definition: tree_model.h:127
RegTree()
constructor
Definition: tree_model.h:259
XGBOOST_DEVICE bool IsDeleted() const
whether this node is deleted
Definition: tree_model.h:147
uint32_t bst_uint
unsigned integer type used in boost, used for feature index and row index.
Definition: base.h:84
void Load(dmlc::Stream *fi)
load model from stream
Definition: tree_model.h:294
int num_feature
number of features used for tree construction
Definition: tree_model.h:38
void Drop(const SparsePage::Inst &inst)
drop the trace after fill, must be called after fill.
Definition: tree_model.h:567
int reserved[31]
reserved part, make sure alignment works for 64bit
Definition: tree_model.h:45
void CollapseToLeaf(int rid, bst_float value)
collapse a non leaf node to a leaf node, delete its children
Definition: tree_model.h:245
Node & operator[](int nid)
get node given nid
Definition: tree_model.h:271
bst_float SplitCondT
auxiliary statistics of node to help tree building
Definition: tree_model.h:97
XGBOOST_DEVICE int LeftChild() const
index of left child
Definition: tree_model.h:107
RTreeNodeStat & Stat(int nid)
get node statistics given nid
Definition: tree_model.h:283
XGBOOST_DEVICE void Reuse()
Reuse this deleted node.
Definition: tree_model.h:194
bst_float loss_chg
loss change caused by current split
Definition: tree_model.h:77
XGBOOST_DEVICE int Parent() const
get parent of the node
Definition: tree_model.h:139
XGBOOST_DEVICE SplitCondT SplitCond() const
Definition: tree_model.h:135
bool IsMissing(size_t i) const
check whether i-th entry is missing
Definition: tree_model.h:582
dense feature vector that can be taken by RegTree and can be constructed from sparse feature vector...
Definition: tree_model.h:406
int max_depth
maximum depth; this is a statistic of the tree
Definition: tree_model.h:36
DMLC_DECLARE_PARAMETER(TreeParam)
Definition: tree_model.h:55