7 #ifndef XGBOOST_TREE_MODEL_H_
8 #define XGBOOST_TREE_MODEL_H_
11 #include <dmlc/parameter.h>
15 #include <xgboost/logging.h>
35 struct TreeParam :
public dmlc::Parameter<TreeParam> {
56 static_assert(
sizeof(
TreeParam) == (31 + 6) *
sizeof(
int),
57 "TreeParam: 64 bit align");
81 DMLC_DECLARE_FIELD(
num_nodes).set_lower_bound(1).set_default(1);
83 .describe(
"Number of features used in tree construction.");
86 .describe(
"Size of leaf vector, reserved for vector tree");
143 static_assert(
sizeof(
Node) == 4 *
sizeof(
int) +
sizeof(Info),
144 "Node: 64 bit align");
146 Node(int32_t cleft, int32_t cright, int32_t parent,
147 uint32_t split_ind,
float split_cond,
bool default_left) :
148 parent_{parent}, cleft_{cleft}, cright_{cright} {
150 this->
SetSplit(split_ind, split_cond, default_left);
159 return this->cright_;
167 return sindex_ & ((1U << 31) - 1U);
171 return (sindex_ >> 31) != 0;
179 return (this->info_).leaf_value;
183 return (this->info_).split_cond;
187 return parent_ & ((1U << 31) - 1);
191 return (parent_ & (1U << 31)) != 0;
220 bool default_left =
false) {
221 if (default_left) split_index |= (1U << 31);
222 this->sindex_ = split_index;
223 (this->info_).split_cond = split_cond;
232 (this->info_).leaf_value = value;
234 this->cright_ = right;
246 if (is_left_child) pidx |= (1U << 31);
247 this->parent_ = pidx;
250 return parent_ == b.parent_ && cleft_ == b.cleft_ &&
251 cright_ == b.cright_ && sindex_ == b.sindex_ &&
252 info_.leaf_value == b.info_.leaf_value;
257 dmlc::ByteSwap(&x.parent_,
sizeof(x.parent_), 1);
258 dmlc::ByteSwap(&x.cleft_,
sizeof(x.cleft_), 1);
259 dmlc::ByteSwap(&x.cright_,
sizeof(x.cright_), 1);
260 dmlc::ByteSwap(&x.sindex_,
sizeof(x.sindex_), 1);
261 dmlc::ByteSwap(&x.info_,
sizeof(x.info_), 1);
291 CHECK(nodes_[nodes_[rid].LeftChild() ].IsLeaf());
292 CHECK(nodes_[nodes_[rid].RightChild()].IsLeaf());
293 this->DeleteNode(nodes_[rid].LeftChild());
294 this->DeleteNode(nodes_[rid].RightChild());
295 nodes_[rid].SetLeaf(value);
303 if (nodes_[rid].IsLeaf())
return;
304 if (!nodes_[nodes_[rid].LeftChild() ].IsLeaf()) {
307 if (!nodes_[nodes_[rid].RightChild() ].IsLeaf()) {
324 nodes_[i].SetLeaf(0.0f);
// Read-only accessor: returns a const reference to the nodes_ member, so no
// copy of the vector is made and callers cannot modify it through this handle.
338 const std::vector<Node>&
GetNodes()
const {
return nodes_; }
// Read-only accessor: returns a const reference to the stats_ member (the
// vector of RTreeNodeStat); no copy is made.
341 const std::vector<RTreeNodeStat>&
GetStats()
const {
return stats_; }
361 void Save(dmlc::Stream* fo)
const;
367 return nodes_ == b.nodes_ && stats_ == b.stats_ &&
368 deleted_nodes_ == b.deleted_nodes_ &&
param == b.
param;
375 template <
typename Func>
void WalkTree(Func func)
const {
376 std::stack<bst_node_t> nodes;
379 while (!nodes.empty()) {
380 auto nidx = nodes.top();
385 auto left =
self[nidx].LeftChild();
386 auto right =
self[nidx].RightChild();
421 bool default_left,
bst_float base_weight,
423 bst_float loss_change,
float sum_hess,
float left_sum,
446 float left_sum,
float right_sum);
449 return !split_categories_.empty();
458 while (!nodes_[nid].IsRoot()) {
460 nid = nodes_[nid].Parent();
470 if (nodes_[nid].IsLeaf())
return 0;
471 return std::max(
MaxDepth(nodes_[nid].LeftChild())+1,
472 MaxDepth(nodes_[nid].RightChild())+1);
500 void Init(
size_t size);
541 std::vector<Entry> data_;
553 std::vector<float>* mean_values,
554 bst_float* out_contribs,
int condition = 0,
555 unsigned condition_feature = 0)
const;
571 unsigned unique_depth, PathElement* parent_unique_path,
573 int parent_feature_index,
int condition,
574 unsigned condition_feature,
bst_float condition_fraction)
const;
582 std::vector<float>* mean_values,
593 std::string format)
const;
600 return split_types_.at(nidx);
// Read-only accessor: returns a const reference to split_types_, the vector of
// FeatureType values; no copy is made.
605 std::vector<FeatureType>
const &
GetSplitTypes()
const {
return split_types_; }
613 auto segment = node_ptr[nidx];
614 auto node_cats = categories.
subspan(segment.beg, segment.size);
642 template <
bool typed>
643 void LoadCategoricalSplit(
Json const& in);
644 void SaveCategoricalSplit(
Json* p_out)
const;
646 std::vector<Node> nodes_;
648 std::vector<int> deleted_nodes_;
650 std::vector<RTreeNodeStat> stats_;
651 std::vector<FeatureType> split_types_;
654 std::vector<uint32_t> split_categories_;
656 std::vector<Segment> split_categories_segments_;
662 int nid = deleted_nodes_.back();
663 deleted_nodes_.pop_back();
670 <<
"number of nodes in the tree exceed 2^31";
678 void DeleteNode(
int nid) {
680 auto pid = (*this)[nid].Parent();
681 if (nid == (*
this)[pid].LeftChild()) {
687 deleted_nodes_.push_back(nid);
688 nodes_[nid].MarkDelete();
694 Entry e; e.flag = -1;
696 std::fill(data_.begin(), data_.end(), e);
701 size_t feature_count = 0;
702 for (
auto const& entry : inst) {
703 if (entry.index >= data_.size()) {
706 data_[entry.index].fvalue = entry.fvalue;
709 has_missing_ = data_.size() != feature_count;
713 for (
auto const& entry : inst) {
714 if (entry.index >= data_.size()) {
717 data_[entry.index].flag = -1;
727 return data_[i].fvalue;
731 return data_[i].flag == -1;
defines configuration macros of xgboost.
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:84
Feature map data structure to help text model dump. TODO(tqchen): consider making it even more lightweight.
Definition: feature_map.h:22
Data structure representing JSON format.
Definition: json.h:356
tree node
Definition: tree_model.h:139
XGBOOST_DEVICE bst_float LeafValue() const
Definition: tree_model.h:178
XGBOOST_DEVICE int Parent() const
get parent of the node
Definition: tree_model.h:186
XGBOOST_DEVICE void MarkDelete()
mark that this node is deleted
Definition: tree_model.h:237
XGBOOST_DEVICE bool IsRoot() const
whether current node is root
Definition: tree_model.h:198
XGBOOST_DEVICE int RightChild() const
index of right child
Definition: tree_model.h:158
XGBOOST_DEVICE Node()
Definition: tree_model.h:141
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child=true)
Definition: tree_model.h:245
XGBOOST_DEVICE unsigned SplitIndex() const
feature index of split condition
Definition: tree_model.h:166
Node ByteSwap() const
Definition: tree_model.h:255
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=kInvalidNodeId)
set the leaf value of the node
Definition: tree_model.h:231
XGBOOST_DEVICE bool IsLeftChild() const
whether current node is left child
Definition: tree_model.h:190
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
set split condition of current node
Definition: tree_model.h:219
XGBOOST_DEVICE void SetLeftChild(int nid)
set the left child
Definition: tree_model.h:203
XGBOOST_DEVICE bool IsDeleted() const
whether this node is deleted
Definition: tree_model.h:194
XGBOOST_DEVICE bool IsLeaf() const
whether current node is leaf node
Definition: tree_model.h:174
bool operator==(const Node &b) const
Definition: tree_model.h:249
Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond, bool default_left)
Definition: tree_model.h:146
XGBOOST_DEVICE void Reuse()
Reuse this deleted node.
Definition: tree_model.h:241
XGBOOST_DEVICE void SetRightChild(int nid)
set the right child
Definition: tree_model.h:210
XGBOOST_DEVICE bool DefaultLeft() const
when the feature value is unknown (missing), whether to go to the left child
Definition: tree_model.h:170
XGBOOST_DEVICE int LeftChild() const
index of left child
Definition: tree_model.h:154
XGBOOST_DEVICE int DefaultChild() const
index of default child when feature is missing
Definition: tree_model.h:162
XGBOOST_DEVICE SplitCondT SplitCond() const
Definition: tree_model.h:182
Defines the regression tree, the most common tree model. This is the data structure used in xgboost's major tree models.
Definition: tree_model.h:131
int MaxDepth(int nid) const
get maximum depth
Definition: tree_model.h:469
void SaveModel(Json *out) const override
saves the model config to a JSON object
void WalkTree(Func func) const
Definition: tree_model.h:375
void ExpandCategorical(bst_node_t nid, unsigned split_index, common::Span< const uint32_t > split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum)
Expands a leaf node with categories.
void Save(dmlc::Stream *fo) const
save model to stream
bool operator==(const RegTree &b) const
Definition: tree_model.h:366
const RTreeNodeStat & Stat(int nid) const
get node statistics given nid
Definition: tree_model.h:348
const Node & operator[](int nid) const
get node given nid
Definition: tree_model.h:333
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
Expands a leaf node into two additional leaf nodes.
Node & operator[](int nid)
get node given nid
Definition: tree_model.h:329
RegTree()
constructor
Definition: tree_model.h:316
static constexpr bst_node_t kInvalidNodeId
Definition: tree_model.h:134
static constexpr uint32_t kDeletedNodeMarker
Definition: tree_model.h:135
int GetDepth(int nid) const
get current depth
Definition: tree_model.h:456
void Load(dmlc::Stream *fi)
load model from stream
bst_node_t GetNumLeaves() const
void CalculateContributions(const RegTree::FVec &feat, std::vector< float > *mean_values, bst_float *out_contribs, int condition=0, unsigned condition_feature=0) const
calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree
common::Span< uint32_t const > NodeCats(bst_node_t nidx) const
Get the bit storage for categories.
Definition: tree_model.h:610
CategoricalSplitMatrix GetCategoriesMatrix() const
Definition: tree_model.h:633
RTreeNodeStat & Stat(int nid)
get node statistics given nid
Definition: tree_model.h:344
bst_float SplitCondT
Definition: tree_model.h:133
bool Equal(const RegTree &b) const
Compares whether two trees are equal from a user's perspective. The equality check compares only non-deleted nodes.
std::vector< FeatureType > const & GetSplitTypes() const
Get split types for all nodes.
Definition: tree_model.h:605
TreeParam param
model parameter
Definition: tree_model.h:314
void CollapseToLeaf(int rid, bst_float value)
collapse a non leaf node to a leaf node, delete its children
Definition: tree_model.h:302
int NumExtraNodes() const
number of extra nodes besides the root
Definition: tree_model.h:483
void ChangeToLeaf(int rid, bst_float value)
change a non leaf node to a leaf node, delete its children
Definition: tree_model.h:290
const std::vector< RTreeNodeStat > & GetStats() const
get const reference to stats
Definition: tree_model.h:341
const std::vector< Node > & GetNodes() const
get const reference to nodes
Definition: tree_model.h:338
void CalculateContributionsApprox(const RegTree::FVec &feat, std::vector< float > *mean_values, bst_float *out_contribs) const
calculate the approximate feature contributions for the given root
void LoadModel(Json const &in) override
load the model from a JSON object
std::string DumpModel(const FeatureMap &fmap, bool with_stats, std::string format) const
dump the model in the requested format as a text string
FeatureType NodeSplitType(bst_node_t nidx) const
Get split type for a node.
Definition: tree_model.h:599
void TreeShap(const RegTree::FVec &feat, bst_float *phi, bst_node_t node_index, unsigned unique_depth, PathElement *parent_unique_path, bst_float parent_zero_fraction, bst_float parent_one_fraction, int parent_feature_index, int condition, unsigned condition_feature, bst_float condition_fraction) const
Recursive function that computes the feature attributions for a single tree.
common::Span< uint32_t const > GetSplitCategories() const
Definition: tree_model.h:606
bool HasCategoricalSplit() const
Definition: tree_model.h:448
static constexpr bst_node_t kRoot
Definition: tree_model.h:136
bst_node_t GetNumSplitNodes() const
auto const & GetSplitCategoriesPtr() const
Definition: tree_model.h:617
int MaxDepth()
get maximum depth
Definition: tree_model.h:478
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:423
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
Definition: span.h:595
The input data structure of xgboost.
Feature map data structure to help visualization and model dump.
Defines the abstract interface for different components in XGBoost.
namespace of xgboost
Definition: base.h:110
uint32_t bst_feature_t
Type for data column (feature) index.
Definition: base.h:123
FeatureType
Definition: data.h:41
int32_t bst_node_t
Type for tree node index.
Definition: base.h:134
float bst_float
float type, used for storing statistics
Definition: base.h:119
node statistics used in regression tree
Definition: tree_model.h:98
RTreeNodeStat(float loss_chg, float sum_hess, float weight)
Definition: tree_model.h:109
bst_float loss_chg
loss change caused by current split
Definition: tree_model.h:100
int leaf_child_cnt
number of children that are leaf nodes, known up to now
Definition: tree_model.h:106
bst_float sum_hess
sum of hessian values, used to measure coverage of data
Definition: tree_model.h:102
bool operator==(const RTreeNodeStat &b) const
Definition: tree_model.h:111
RTreeNodeStat ByteSwap() const
Definition: tree_model.h:117
bst_float base_weight
weight of current node
Definition: tree_model.h:104
Definition: tree_model.h:627
common::Span< uint32_t const > categories
Definition: tree_model.h:629
common::Span< Segment const > node_ptr
Definition: tree_model.h:630
common::Span< FeatureType const > split_type
Definition: tree_model.h:628
dense feature vector that can be taken by RegTree and can be constructed from a sparse feature vector.
Definition: tree_model.h:495
bool HasMissing() const
Definition: tree_model.h:734
void Fill(const SparsePage::Inst &inst)
fill the vector with sparse vector
Definition: tree_model.h:700
bool IsMissing(size_t i) const
check whether i-th entry is missing
Definition: tree_model.h:730
size_t Size() const
returns the size of the feature vector
Definition: tree_model.h:722
bst_float GetFvalue(size_t i) const
get ith value
Definition: tree_model.h:726
void Init(size_t size)
initialize the vector with the given size
Definition: tree_model.h:693
void Drop(const SparsePage::Inst &inst)
drop the trace after fill, must be called after fill.
Definition: tree_model.h:712
Definition: tree_model.h:622
size_t size
Definition: tree_model.h:624
size_t beg
Definition: tree_model.h:623
meta parameters of the tree
Definition: tree_model.h:35
bst_feature_t num_feature
number of features used for tree construction
Definition: tree_model.h:45
int num_nodes
total number of nodes
Definition: tree_model.h:39
int num_deleted
number of deleted nodes
Definition: tree_model.h:41
bool operator==(const TreeParam &b) const
Definition: tree_model.h:89
int reserved[31]
reserved part, make sure alignment works for 64bit
Definition: tree_model.h:52
TreeParam ByteSwap() const
Definition: tree_model.h:65
TreeParam()
constructor
Definition: tree_model.h:54
DMLC_DECLARE_PARAMETER(TreeParam)
Definition: tree_model.h:78
int size_leaf_vector
leaf vector size, used for vector trees, which store more than one-dimensional information in the tree
Definition: tree_model.h:50
int deprecated_num_roots
(Deprecated) number of start root
Definition: tree_model.h:37
int deprecated_max_depth
maximum depth; this is a statistic of the tree
Definition: tree_model.h:43