7 #ifndef XGBOOST_TREE_MODEL_H_
8 #define XGBOOST_TREE_MODEL_H_
11 #include <dmlc/parameter.h>
16 #include <xgboost/logging.h>
35 struct TreeParam :
public dmlc::Parameter<TreeParam> {
56 static_assert(
sizeof(
TreeParam) == (31 + 6) *
sizeof(
int),
"TreeParam: 64 bit align");
78 DMLC_DECLARE_FIELD(
num_nodes).set_lower_bound(1).set_default(1);
81 .describe(
"Number of features used in tree construction.");
86 .describe(
"Size of leaf vector, reserved for vector tree");
128 template <
typename T>
130 std::unique_ptr<T> ptr_{
nullptr};
137 ptr_ = std::make_unique<T>(*that);
// Non-owning accessor: returns the raw pointer stored in ptr_ (null when ptr_ is empty).
140 T*
get() const noexcept {
return ptr_.get(); }
// True when ptr_ currently holds an object. Declared explicit, so the conversion
// only applies in boolean contexts or through an explicit cast.
148 explicit operator bool()
const {
return static_cast<bool>(ptr_); }
// Replaces the object managed by ptr_ with `ptr`; forwards to std::unique_ptr::reset,
// so ptr_ takes ownership of `ptr`.
150 void reset(T* ptr) { ptr_.reset(ptr); }
170 static_assert(
sizeof(
Node) == 4 *
sizeof(
int) +
sizeof(Info),
171 "Node: 64 bit align");
173 Node(int32_t cleft, int32_t cright, int32_t parent,
174 uint32_t split_ind,
float split_cond,
bool default_left) :
175 parent_{parent}, cleft_{cleft}, cright_{cright} {
177 this->
SetSplit(split_ind, split_cond, default_left);
190 return sindex_ & ((1U << 31) - 1U);
229 bool default_left =
false) {
230 if (default_left) split_index |= (1U << 31);
231 this->sindex_ = split_index;
232 (this->info_).split_cond = split_cond;
241 (this->info_).leaf_value = value;
243 this->cright_ = right;
255 if (is_left_child) pidx |= (1U << 31);
256 this->parent_ = pidx;
259 return parent_ == b.parent_ && cleft_ == b.cleft_ &&
260 cright_ == b.cright_ && sindex_ == b.sindex_ &&
261 info_.leaf_value == b.info_.leaf_value;
266 dmlc::ByteSwap(&x.parent_,
sizeof(x.parent_), 1);
267 dmlc::ByteSwap(&x.cleft_,
sizeof(x.cleft_), 1);
268 dmlc::ByteSwap(&x.cright_,
sizeof(x.cright_), 1);
269 dmlc::ByteSwap(&x.sindex_,
sizeof(x.sindex_), 1);
270 dmlc::ByteSwap(&x.info_,
sizeof(x.info_), 1);
302 this->DeleteNode(nodes_[rid].
LeftChild());
304 nodes_[rid].SetLeaf(value);
312 if (nodes_[rid].
IsLeaf())
return;
327 split_categories_segments_.resize(param_.
num_nodes);
328 for (
int i = 0; i < param_.
num_nodes; i++) {
329 nodes_[i].SetLeaf(0.0f);
// Read-only access to nodes_: returns a const reference, so no copy is made.
// [[nodiscard]] makes the compiler flag calls that drop the result.
354 [[nodiscard]]
const std::vector<Node>&
GetNodes()
const {
return nodes_; }
// Read-only access to stats_: returns a const reference, so no copy is made.
// [[nodiscard]] makes the compiler flag calls that drop the result.
357 [[nodiscard]]
const std::vector<RTreeNodeStat>&
GetStats()
const {
return stats_; }
377 void Save(dmlc::Stream* fo)
const;
383 return nodes_ == b.nodes_ && stats_ == b.stats_ &&
384 deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
391 template <
typename Func>
void WalkTree(Func func)
const {
392 std::stack<bst_node_t> nodes;
395 while (!nodes.empty()) {
396 auto nidx = nodes.top();
401 auto left =
self[nidx].LeftChild();
402 auto right =
self[nidx].RightChild();
437 bool default_left,
bst_float base_weight,
439 bst_float loss_change,
float sum_hess,
float left_sum,
469 float left_sum,
float right_sum);
// Whether this is a multi-target tree: true exactly when p_mt_tree_ converts to true.
477 [[nodiscard]]
bool IsMultiTarget()
const {
return static_cast<bool>(p_mt_tree_); }
487 return p_mt_tree_.get();
519 return this->p_mt_tree_->Depth(nid);
522 while (!nodes_[nid].
IsRoot()) {
524 nid = nodes_[nid].Parent();
533 return this->p_mt_tree_->SetLeaf(nidx, weight);
541 if (nodes_[nid].
IsLeaf())
return 0;
559 void Init(
size_t size);
575 [[nodiscard]]
size_t Size()
const;
587 [[nodiscard]]
bool IsMissing(
size_t i)
const;
600 std::vector<Entry> data_;
610 std::vector<float>* mean_values,
620 std::string format)
const;
634 return split_categories_;
642 auto segment = node_ptr[nidx];
643 auto node_cats = categories.
subspan(segment.beg, segment.size);
675 return this->p_mt_tree_->SplitIndex(nidx);
677 return (*
this)[nidx].SplitIndex();
681 return this->p_mt_tree_->SplitCond(nidx);
683 return (*
this)[nidx].SplitCond();
687 return this->p_mt_tree_->DefaultLeft(nidx);
689 return (*
this)[nidx].DefaultLeft();
693 return nidx ==
kRoot;
695 return (*
this)[nidx].IsRoot();
699 return this->p_mt_tree_->IsLeaf(nidx);
701 return (*
this)[nidx].IsLeaf();
705 return this->p_mt_tree_->Parent(nidx);
707 return (*
this)[nidx].Parent();
711 return this->p_mt_tree_->LeftChild(nidx);
713 return (*
this)[nidx].LeftChild();
717 return this->p_mt_tree_->RightChild(nidx);
719 return (*
this)[nidx].RightChild();
723 CHECK_NE(nidx,
kRoot);
724 auto p = this->p_mt_tree_->Parent(nidx);
725 return nidx == this->p_mt_tree_->LeftChild(p);
727 return (*
this)[nidx].IsLeftChild();
731 return this->p_mt_tree_->Size();
733 return this->nodes_.size();
737 template <
bool typed>
738 void LoadCategoricalSplit(
Json const& in);
739 void SaveCategoricalSplit(
Json* p_out)
const;
743 std::vector<Node> nodes_;
745 std::vector<int> deleted_nodes_;
747 std::vector<RTreeNodeStat> stats_;
748 std::vector<FeatureType> split_types_;
751 std::vector<uint32_t> split_categories_;
753 std::vector<CategoricalSplitMatrix::Segment> split_categories_segments_;
760 int nid = deleted_nodes_.back();
761 deleted_nodes_.pop_back();
767 CHECK_LT(param_.
num_nodes, std::numeric_limits<int>::max())
768 <<
"number of nodes in the tree exceed 2^31";
772 split_categories_segments_.resize(param_.
num_nodes);
776 void DeleteNode(
int nid) {
778 auto pid = (*this)[nid].Parent();
785 deleted_nodes_.push_back(nid);
786 nodes_[nid].MarkDelete();
792 Entry e; e.flag = -1;
794 std::fill(data_.begin(), data_.end(), e);
799 size_t feature_count = 0;
800 for (
auto const& entry : inst) {
801 if (entry.index >= data_.size()) {
804 data_[entry.index].fvalue = entry.fvalue;
807 has_missing_ = data_.size() != feature_count;
813 std::fill_n(data_.data(), data_.size(), e);
822 return data_[i].fvalue;
826 return data_[i].flag == -1;
835 return " support for multi-target tree is not yet implemented.";
Defines configuration macros and basic types for xgboost.
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:64
Helper for defining copyable data structure that contains unique pointers.
Definition: tree_model.h:129
T const * operator->() const noexcept
Definition: tree_model.h:146
T * get() const noexcept
Definition: tree_model.h:140
bool operator!() const
Definition: tree_model.h:149
CopyUniquePtr(CopyUniquePtr const &that)
Definition: tree_model.h:134
T * operator->() noexcept
Definition: tree_model.h:143
T & operator*()
Definition: tree_model.h:142
T const & operator*() const
Definition: tree_model.h:145
void reset(T *ptr)
Definition: tree_model.h:150
Feature map data structure to help text model dump. TODO(tqchen): consider making it even more lightweight...
Definition: feature_map.h:22
Data structure representing JSON format.
Definition: json.h:357
Tree structure for multi-target model.
Definition: multi_target_tree_model.h:23
static constexpr bst_node_t InvalidNodeId()
Definition: multi_target_tree_model.h:25
tree node
Definition: tree_model.h:166
XGBOOST_DEVICE int Parent() const
get parent of the node
Definition: tree_model.h:201
XGBOOST_DEVICE void MarkDelete()
mark that this node is deleted
Definition: tree_model.h:246
XGBOOST_DEVICE bool IsRoot() const
whether current node is root
Definition: tree_model.h:207
XGBOOST_DEVICE int RightChild() const
index of right child
Definition: tree_model.h:183
XGBOOST_DEVICE float LeafValue() const
Definition: tree_model.h:197
XGBOOST_DEVICE Node()
Definition: tree_model.h:168
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child=true)
Definition: tree_model.h:254
XGBOOST_DEVICE unsigned SplitIndex() const
feature index of split condition
Definition: tree_model.h:189
Node ByteSwap() const
Definition: tree_model.h:264
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=kInvalidNodeId)
set the leaf value of the node
Definition: tree_model.h:240
XGBOOST_DEVICE bool IsLeftChild() const
whether current node is left child
Definition: tree_model.h:203
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
set split condition of current node
Definition: tree_model.h:228
XGBOOST_DEVICE void SetLeftChild(int nid)
set the left child
Definition: tree_model.h:212
XGBOOST_DEVICE bool IsDeleted() const
whether this node is deleted
Definition: tree_model.h:205
XGBOOST_DEVICE bool IsLeaf() const
whether current node is leaf node
Definition: tree_model.h:195
bool operator==(const Node &b) const
Definition: tree_model.h:258
Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond, bool default_left)
Definition: tree_model.h:173
XGBOOST_DEVICE void Reuse()
Reuse this deleted node.
Definition: tree_model.h:250
XGBOOST_DEVICE void SetRightChild(int nid)
set the right child
Definition: tree_model.h:219
XGBOOST_DEVICE bool DefaultLeft() const
when the feature is unknown, whether to go to the left child
Definition: tree_model.h:193
XGBOOST_DEVICE int LeftChild() const
index of left child
Definition: tree_model.h:181
XGBOOST_DEVICE int DefaultChild() const
index of default child when feature is missing
Definition: tree_model.h:185
XGBOOST_DEVICE SplitCondT SplitCond() const
Definition: tree_model.h:199
define regression tree to be the most common tree model.
Definition: tree_model.h:158
int MaxDepth(int nid) const
get maximum depth
Definition: tree_model.h:540
void SaveModel(Json *out) const override
saves the model config to a JSON object
bst_target_t NumTargets() const
The size of leaf weight.
Definition: tree_model.h:481
void WalkTree(Func func) const
Definition: tree_model.h:391
void Save(dmlc::Stream *fo) const
save model to stream
bool IsLeaf(bst_node_t nidx) const
Definition: tree_model.h:697
bool operator==(const RegTree &b) const
Definition: tree_model.h:382
const RTreeNodeStat & Stat(int nid) const
get node statistics given nid
Definition: tree_model.h:364
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight)
Expands a leaf node into two additional leaf nodes for a multi-target tree.
bst_node_t Parent(bst_node_t nidx) const
Definition: tree_model.h:703
bst_node_t NumNodes() const noexcept
Get the total number of nodes including deleted ones in this tree.
Definition: tree_model.h:496
const Node & operator[](int nid) const
get node given nid
Definition: tree_model.h:349
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
Expands a leaf node into two additional leaf nodes.
Node & operator[](int nid)
get node given nid
Definition: tree_model.h:345
RegTree()
Definition: tree_model.h:322
static constexpr bst_node_t kInvalidNodeId
Definition: tree_model.h:161
bst_feature_t SplitIndex(bst_node_t nidx) const
Definition: tree_model.h:673
bool IsRoot(bst_node_t nidx) const
Definition: tree_model.h:691
static constexpr uint32_t kDeletedNodeMarker
Definition: tree_model.h:162
bool IsMultiTarget() const
Whether this is a multi-target tree.
Definition: tree_model.h:477
bst_node_t NumExtraNodes() const noexcept
number of extra nodes besides the root
Definition: tree_model.h:506
bool DefaultLeft(bst_node_t nidx) const
Definition: tree_model.h:685
auto GetMultiTargetTree() const
Get the underlying implementation of multi-target tree.
Definition: tree_model.h:485
void Load(dmlc::Stream *fi)
load model from stream
bst_node_t LeftChild(bst_node_t nidx) const
Definition: tree_model.h:709
bst_node_t GetNumLeaves() const
RegTree(bst_target_t n_targets, bst_feature_t n_features)
Constructor that initializes the tree model with shape.
Definition: tree_model.h:336
bst_node_t RightChild(bst_node_t nidx) const
Definition: tree_model.h:715
common::Span< uint32_t const > NodeCats(bst_node_t nidx) const
Get the bit storage for categories.
Definition: tree_model.h:639
bool IsLeftChild(bst_node_t nidx) const
Definition: tree_model.h:721
CategoricalSplitMatrix GetCategoriesMatrix() const
Definition: tree_model.h:665
RTreeNodeStat & Stat(int nid)
get node statistics given nid
Definition: tree_model.h:360
bst_float SplitCondT
Definition: tree_model.h:160
void ExpandCategorical(bst_node_t nid, bst_feature_t split_index, common::Span< const uint32_t > split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum)
Expands a leaf node with categories.
bool Equal(const RegTree &b) const
Compares whether 2 trees are equal from a user's perspective. The equality compares only non-deleted ...
std::vector< FeatureType > const & GetSplitTypes() const
Get split types for all nodes.
Definition: tree_model.h:630
void CollapseToLeaf(int rid, bst_float value)
collapse a non leaf node to a leaf node, delete its children
Definition: tree_model.h:311
bst_node_t NumValidNodes() const noexcept
Get the total number of valid nodes in this tree.
Definition: tree_model.h:500
void ChangeToLeaf(int rid, bst_float value)
change a non leaf node to a leaf node, delete its children
Definition: tree_model.h:299
const std::vector< RTreeNodeStat > & GetStats() const
get const reference to stats
Definition: tree_model.h:357
void SetLeaf(bst_node_t nidx, linalg::VectorView< float const > weight)
Set the leaf weight for a multi-target tree.
Definition: tree_model.h:531
const std::vector< Node > & GetNodes() const
get const reference to nodes
Definition: tree_model.h:354
void CalculateContributionsApprox(const RegTree::FVec &feat, std::vector< float > *mean_values, bst_float *out_contribs) const
calculate the approximate feature contributions for the given root
void LoadModel(Json const &in) override
load the model from a JSON object
std::string DumpModel(const FeatureMap &fmap, bool with_stats, std::string format) const
dump the model in the requested format as a text string
FeatureType NodeSplitType(bst_node_t nidx) const
Get split type for a node.
Definition: tree_model.h:626
bst_feature_t NumFeatures() const noexcept
Get the number of features.
Definition: tree_model.h:492
common::Span< uint32_t const > GetSplitCategories() const
Definition: tree_model.h:633
bool HasCategoricalSplit() const
Whether this tree has categorical split.
Definition: tree_model.h:473
std::int32_t GetDepth(bst_node_t nid) const
get current depth
Definition: tree_model.h:517
static constexpr bst_node_t kRoot
Definition: tree_model.h:163
bst_node_t GetNumSplitNodes() const
auto const & GetSplitCategoriesPtr() const
Definition: tree_model.h:646
float SplitCond(bst_node_t nidx) const
Definition: tree_model.h:679
int MaxDepth()
get maximum depth
Definition: tree_model.h:548
bst_node_t Size() const
Definition: tree_model.h:729
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:424
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
Definition: span.h:596
A tensor view with static type and dimension. It implements indexing and slicing.
Definition: linalg.h:293
The input data structure of xgboost.
Feature map data structure to help visualization and model dump.
Linear algebra related utilities.
Defines the abstract interface for different components in XGBoost.
namespace of xgboost
Definition: base.h:90
std::vector< std::pair< std::string, std::string > > Args
Definition: base.h:316
uint32_t bst_feature_t
Type for data column (feature) index.
Definition: base.h:101
std::int32_t bst_node_t
Type for tree node index.
Definition: base.h:112
FeatureType
Definition: data.h:41
std::uint32_t bst_target_t
Type for indexing into output targets.
Definition: base.h:118
float bst_float
float type, used for storing statistics
Definition: base.h:97
StringView MTNotImplemented()
Definition: tree_model.h:834
node statistics used in regression tree
Definition: tree_model.h:96
RTreeNodeStat(float loss_chg, float sum_hess, float weight)
Definition: tree_model.h:107
bst_float loss_chg
loss change caused by current split
Definition: tree_model.h:98
int leaf_child_cnt
number of children that are leaf nodes, known up to now
Definition: tree_model.h:104
bst_float sum_hess
sum of hessian values, used to measure coverage of data
Definition: tree_model.h:100
bool operator==(const RTreeNodeStat &b) const
Definition: tree_model.h:109
RTreeNodeStat ByteSwap() const
Definition: tree_model.h:115
bst_float base_weight
weight of current node
Definition: tree_model.h:102
Definition: tree_model.h:656
std::size_t size
Definition: tree_model.h:658
std::size_t beg
Definition: tree_model.h:657
CSR-like matrix for categorical splits.
Definition: tree_model.h:655
common::Span< uint32_t const > categories
Definition: tree_model.h:661
common::Span< Segment const > node_ptr
Definition: tree_model.h:662
common::Span< FeatureType const > split_type
Definition: tree_model.h:660
dense feature vector that can be taken by RegTree and can be constructed from a sparse feature vector.
Definition: tree_model.h:554
void Drop()
drop the trace after fill, must be called after fill.
Definition: tree_model.h:810
bool HasMissing() const
Definition: tree_model.h:829
void Fill(const SparsePage::Inst &inst)
fill the vector with sparse vector
Definition: tree_model.h:798
bool IsMissing(size_t i) const
check whether i-th entry is missing
Definition: tree_model.h:825
size_t Size() const
returns the size of the feature vector
Definition: tree_model.h:817
bst_float GetFvalue(size_t i) const
get ith value
Definition: tree_model.h:821
void Init(size_t size)
initialize the vector with the given size
Definition: tree_model.h:791
Definition: string_view.h:15
meta parameters of the tree
Definition: tree_model.h:35
bst_feature_t num_feature
number of features used for tree construction
Definition: tree_model.h:45
int num_nodes
total number of nodes
Definition: tree_model.h:39
int num_deleted
number of deleted nodes
Definition: tree_model.h:41
bool operator==(const TreeParam &b) const
Definition: tree_model.h:89
int reserved[31]
reserved part, make sure alignment works for 64bit
Definition: tree_model.h:52
TreeParam ByteSwap() const
Definition: tree_model.h:62
TreeParam()
constructor
Definition: tree_model.h:54
DMLC_DECLARE_PARAMETER(TreeParam)
Definition: tree_model.h:75
bst_target_t size_leaf_vector
leaf vector size, used for vector tree; used to store more than one-dimensional information in the tree
Definition: tree_model.h:50
int deprecated_num_roots
(Deprecated) number of start root
Definition: tree_model.h:37
int deprecated_max_depth
maximum depth; this is a statistic of the tree
Definition: tree_model.h:43