7 #ifndef XGBOOST_TREE_MODEL_H_
8 #define XGBOOST_TREE_MODEL_H_
11 #include <dmlc/parameter.h>
16 #include <xgboost/logging.h>
34 struct TreeParam :
public dmlc::Parameter<TreeParam> {
55 static_assert(
sizeof(
TreeParam) == (31 + 6) *
sizeof(
int),
"TreeParam: 64 bit align");
77 DMLC_DECLARE_FIELD(
num_nodes).set_lower_bound(1).set_default(1);
80 .describe(
"Number of features used in tree construction.");
85 .describe(
"Size of leaf vector, reserved for vector tree");
127 template <
typename T>
129 std::unique_ptr<T> ptr_{
nullptr};
136 ptr_ = std::make_unique<T>(*that);
139 T*
get() const noexcept {
return ptr_.get(); }
147 explicit operator bool()
const {
return static_cast<bool>(ptr_); }
149 void reset(T* ptr) { ptr_.reset(ptr); }
169 static_assert(
sizeof(
Node) == 4 *
sizeof(
int) +
sizeof(Info),
170 "Node: 64 bit align");
172 Node(int32_t cleft, int32_t cright, int32_t parent,
173 uint32_t split_ind,
float split_cond,
bool default_left) :
174 parent_{parent}, cleft_{cleft}, cright_{cright} {
176 this->
SetSplit(split_ind, split_cond, default_left);
189 static_assert(!std::is_signed_v<bst_feature_t>);
190 return sindex_ & ((1U << 31) - 1U);
229 bool default_left =
false) {
230 if (default_left) split_index |= (1U << 31);
231 this->sindex_ = split_index;
232 (this->info_).split_cond = split_cond;
241 (this->info_).leaf_value = value;
243 this->cright_ = right;
255 if (is_left_child) pidx |= (1U << 31);
256 this->parent_ = pidx;
259 return parent_ == b.parent_ && cleft_ == b.cleft_ &&
260 cright_ == b.cright_ && sindex_ == b.sindex_ &&
261 info_.leaf_value == b.info_.leaf_value;
266 dmlc::ByteSwap(&x.parent_,
sizeof(x.parent_), 1);
267 dmlc::ByteSwap(&x.cleft_,
sizeof(x.cleft_), 1);
268 dmlc::ByteSwap(&x.cright_,
sizeof(x.cright_), 1);
269 dmlc::ByteSwap(&x.sindex_,
sizeof(x.sindex_), 1);
270 dmlc::ByteSwap(&x.info_,
sizeof(x.info_), 1);
302 this->DeleteNode(nodes_[rid].
LeftChild());
304 nodes_[rid].SetLeaf(value);
312 if (nodes_[rid].
IsLeaf())
return;
327 split_categories_segments_.resize(param_.
num_nodes);
328 for (
int i = 0; i < param_.
num_nodes; i++) {
329 nodes_[i].SetLeaf(0.0f);
354 [[nodiscard]]
const std::vector<Node>&
GetNodes()
const {
return nodes_; }
357 [[nodiscard]]
const std::vector<RTreeNodeStat>&
GetStats()
const {
return stats_; }
377 void Save(dmlc::Stream* fo)
const;
383 return nodes_ == b.nodes_ && stats_ == b.stats_ &&
384 deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
391 template <
typename Func>
void WalkTree(Func func)
const {
392 std::stack<bst_node_t> nodes;
395 while (!nodes.empty()) {
396 auto nidx = nodes.top();
401 auto left =
self.LeftChild(nidx);
402 auto right =
self.RightChild(nidx);
437 bool default_left,
bst_float base_weight,
439 bst_float loss_change,
float sum_hess,
float left_sum,
469 float left_sum,
float right_sum);
477 [[nodiscard]]
bool IsMultiTarget()
const {
return static_cast<bool>(p_mt_tree_); }
487 return p_mt_tree_.get();
519 return this->p_mt_tree_->Depth(nid);
522 while (!nodes_[nid].
IsRoot()) {
524 nid = nodes_[nid].Parent();
533 return this->p_mt_tree_->SetLeaf(nidx, weight);
541 if (nodes_[nid].
IsLeaf())
return 0;
559 void Init(
size_t size);
575 [[nodiscard]]
size_t Size()
const;
587 [[nodiscard]]
bool IsMissing(
size_t i)
const;
589 void HasMissing(
bool has_missing) { this->has_missing_ = has_missing; }
599 std::vector<float> data_;
609 std::vector<float>* mean_values,
619 std::string format)
const;
633 return split_categories_;
641 auto segment = node_ptr[nidx];
642 auto node_cats = categories.
subspan(segment.beg, segment.size);
674 return this->p_mt_tree_->SplitIndex(nidx);
676 return (*
this)[nidx].SplitIndex();
680 return this->p_mt_tree_->SplitCond(nidx);
682 return (*
this)[nidx].SplitCond();
686 return this->p_mt_tree_->DefaultLeft(nidx);
688 return (*
this)[nidx].DefaultLeft();
695 return nidx ==
kRoot;
697 return (*
this)[nidx].IsRoot();
701 return this->p_mt_tree_->IsLeaf(nidx);
703 return (*
this)[nidx].IsLeaf();
707 return this->p_mt_tree_->Parent(nidx);
709 return (*
this)[nidx].Parent();
713 return this->p_mt_tree_->LeftChild(nidx);
715 return (*
this)[nidx].LeftChild();
719 return this->p_mt_tree_->RightChild(nidx);
721 return (*
this)[nidx].RightChild();
725 CHECK_NE(nidx,
kRoot);
726 auto p = this->p_mt_tree_->Parent(nidx);
727 return nidx == this->p_mt_tree_->LeftChild(p);
729 return (*
this)[nidx].IsLeftChild();
733 return this->p_mt_tree_->Size();
735 return this->nodes_.size();
739 template <
bool typed>
740 void LoadCategoricalSplit(
Json const& in);
741 void SaveCategoricalSplit(
Json* p_out)
const;
745 std::vector<Node> nodes_;
747 std::vector<int> deleted_nodes_;
749 std::vector<RTreeNodeStat> stats_;
750 std::vector<FeatureType> split_types_;
753 std::vector<uint32_t> split_categories_;
755 std::vector<CategoricalSplitMatrix::Segment> split_categories_segments_;
762 int nid = deleted_nodes_.back();
763 deleted_nodes_.pop_back();
769 CHECK_LT(param_.
num_nodes, std::numeric_limits<int>::max())
770 <<
"number of nodes in the tree exceed 2^31";
774 split_categories_segments_.resize(param_.
num_nodes);
778 void DeleteNode(
int nid) {
780 auto pid = (*this)[nid].Parent();
787 deleted_nodes_.push_back(nid);
788 nodes_[nid].MarkDelete();
795 std::fill(data_.begin(), data_.end(), std::numeric_limits<float>::quiet_NaN());
800 auto p_data = inst.
data();
801 auto p_out = data_.data();
803 for (std::size_t i = 0, n = inst.
size(); i < n; ++i) {
804 auto const& entry = p_data[i];
805 p_out[entry.index] = entry.fvalue;
807 has_missing_ = data_.size() != inst.
size();
826 return " support for multi-target tree is not yet implemented.";
Defines configuration macros and basic types for xgboost.
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:64
Helper for defining copyable data structure that contains unique pointers.
Definition: tree_model.h:128
T const * operator->() const noexcept
Definition: tree_model.h:145
T * get() const noexcept
Definition: tree_model.h:139
bool operator!() const
Definition: tree_model.h:148
CopyUniquePtr(CopyUniquePtr const &that)
Definition: tree_model.h:133
T * operator->() noexcept
Definition: tree_model.h:142
T & operator*()
Definition: tree_model.h:141
T const & operator*() const
Definition: tree_model.h:144
void reset(T *ptr)
Definition: tree_model.h:149
Feature map data structure to help text model dump. TODO(tqchen) consider making it even more lightweight...
Definition: feature_map.h:22
Data structure representing JSON format.
Definition: json.h:378
Tree structure for multi-target model.
Definition: multi_target_tree_model.h:23
static constexpr bst_node_t InvalidNodeId()
Definition: multi_target_tree_model.h:25
tree node
Definition: tree_model.h:165
XGBOOST_DEVICE int Parent() const
get parent of the node
Definition: tree_model.h:201
XGBOOST_DEVICE void MarkDelete()
mark that this node is deleted
Definition: tree_model.h:246
XGBOOST_DEVICE bool IsRoot() const
whether current node is root
Definition: tree_model.h:207
XGBOOST_DEVICE int RightChild() const
index of right child
Definition: tree_model.h:182
XGBOOST_DEVICE float LeafValue() const
Definition: tree_model.h:197
XGBOOST_DEVICE Node()
Definition: tree_model.h:167
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child=true)
Definition: tree_model.h:254
Node ByteSwap() const
Definition: tree_model.h:264
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=kInvalidNodeId)
set the leaf value of the node
Definition: tree_model.h:240
XGBOOST_DEVICE bool IsLeftChild() const
whether current node is left child
Definition: tree_model.h:203
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
set split condition of current node
Definition: tree_model.h:228
XGBOOST_DEVICE void SetLeftChild(int nid)
set the left child
Definition: tree_model.h:212
XGBOOST_DEVICE bst_feature_t SplitIndex() const
feature index of split condition
Definition: tree_model.h:188
XGBOOST_DEVICE bool IsDeleted() const
whether this node is deleted
Definition: tree_model.h:205
XGBOOST_DEVICE bool IsLeaf() const
whether current node is leaf node
Definition: tree_model.h:195
bool operator==(const Node &b) const
Definition: tree_model.h:258
Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond, bool default_left)
Definition: tree_model.h:172
XGBOOST_DEVICE void Reuse()
Reuse this deleted node.
Definition: tree_model.h:250
XGBOOST_DEVICE void SetRightChild(int nid)
set the right child
Definition: tree_model.h:219
XGBOOST_DEVICE bool DefaultLeft() const
when feature is unknown, whether goes to left child
Definition: tree_model.h:193
XGBOOST_DEVICE int LeftChild() const
index of left child
Definition: tree_model.h:180
XGBOOST_DEVICE int DefaultChild() const
index of default child when feature is missing
Definition: tree_model.h:184
XGBOOST_DEVICE SplitCondT SplitCond() const
Definition: tree_model.h:199
define regression tree to be the most common tree model.
Definition: tree_model.h:157
int MaxDepth(int nid) const
get maximum depth
Definition: tree_model.h:540
void SaveModel(Json *out) const override
saves the model config to a JSON object
bst_target_t NumTargets() const
The size of leaf weight.
Definition: tree_model.h:481
void WalkTree(Func func) const
Definition: tree_model.h:391
void Save(dmlc::Stream *fo) const
save model to stream
bool IsLeaf(bst_node_t nidx) const
Definition: tree_model.h:699
bool operator==(const RegTree &b) const
Definition: tree_model.h:382
const RTreeNodeStat & Stat(int nid) const
get node statistics given nid
Definition: tree_model.h:364
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight)
Expands a leaf node into two additional leaf nodes for a multi-target tree.
bst_node_t Parent(bst_node_t nidx) const
Definition: tree_model.h:705
bst_node_t NumNodes() const noexcept
Get the total number of nodes including deleted ones in this tree.
Definition: tree_model.h:496
const Node & operator[](int nid) const
get node given nid
Definition: tree_model.h:349
bst_node_t DefaultChild(bst_node_t nidx) const
Definition: tree_model.h:690
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
Expands a leaf node into two additional leaf nodes.
Node & operator[](int nid)
get node given nid
Definition: tree_model.h:345
RegTree()
Definition: tree_model.h:322
static constexpr bst_node_t kInvalidNodeId
Definition: tree_model.h:160
bst_feature_t SplitIndex(bst_node_t nidx) const
Definition: tree_model.h:672
bool IsRoot(bst_node_t nidx) const
Definition: tree_model.h:693
static constexpr uint32_t kDeletedNodeMarker
Definition: tree_model.h:161
bool IsMultiTarget() const
Whether this is a multi-target tree.
Definition: tree_model.h:477
bst_node_t NumExtraNodes() const noexcept
number of extra nodes besides the root
Definition: tree_model.h:506
bool DefaultLeft(bst_node_t nidx) const
Definition: tree_model.h:684
auto GetMultiTargetTree() const
Get the underlying implementation of the multi-target tree.
Definition: tree_model.h:485
void Load(dmlc::Stream *fi)
load model from stream
bst_node_t LeftChild(bst_node_t nidx) const
Definition: tree_model.h:711
bst_node_t GetNumLeaves() const
RegTree(bst_target_t n_targets, bst_feature_t n_features)
Constructor that initializes the tree model with shape.
Definition: tree_model.h:336
bst_node_t RightChild(bst_node_t nidx) const
Definition: tree_model.h:717
common::Span< uint32_t const > NodeCats(bst_node_t nidx) const
Get the bit storage for categories.
Definition: tree_model.h:638
bool IsLeftChild(bst_node_t nidx) const
Definition: tree_model.h:723
CategoricalSplitMatrix GetCategoriesMatrix() const
Definition: tree_model.h:664
RTreeNodeStat & Stat(int nid)
get node statistics given nid
Definition: tree_model.h:360
bst_float SplitCondT
Definition: tree_model.h:159
void ExpandCategorical(bst_node_t nid, bst_feature_t split_index, common::Span< const uint32_t > split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum)
Expands a leaf node with categories.
bool Equal(const RegTree &b) const
Compares whether 2 trees are equal from a user's perspective. The equality compares only non-deleted ...
std::vector< FeatureType > const & GetSplitTypes() const
Get split types for all nodes.
Definition: tree_model.h:629
void CollapseToLeaf(int rid, bst_float value)
collapse a non leaf node to a leaf node, delete its children
Definition: tree_model.h:311
bst_node_t NumValidNodes() const noexcept
Get the total number of valid nodes in this tree.
Definition: tree_model.h:500
void ChangeToLeaf(int rid, bst_float value)
change a non leaf node to a leaf node, delete its children
Definition: tree_model.h:299
const std::vector< RTreeNodeStat > & GetStats() const
get const reference to stats
Definition: tree_model.h:357
void SetLeaf(bst_node_t nidx, linalg::VectorView< float const > weight)
Set the leaf weight for a multi-target tree.
Definition: tree_model.h:531
const std::vector< Node > & GetNodes() const
get const reference to nodes
Definition: tree_model.h:354
void CalculateContributionsApprox(const RegTree::FVec &feat, std::vector< float > *mean_values, bst_float *out_contribs) const
calculate the approximate feature contributions for the given root
void LoadModel(Json const &in) override
load the model from a JSON object
std::string DumpModel(const FeatureMap &fmap, bool with_stats, std::string format) const
dump the model in the requested format as a text string
FeatureType NodeSplitType(bst_node_t nidx) const
Get split type for a node.
Definition: tree_model.h:625
bst_feature_t NumFeatures() const noexcept
Get the number of features.
Definition: tree_model.h:492
common::Span< uint32_t const > GetSplitCategories() const
Definition: tree_model.h:632
bool HasCategoricalSplit() const
Whether this tree has categorical split.
Definition: tree_model.h:473
std::int32_t GetDepth(bst_node_t nid) const
get current depth
Definition: tree_model.h:517
static constexpr bst_node_t kRoot
Definition: tree_model.h:162
bst_node_t GetNumSplitNodes() const
auto const & GetSplitCategoriesPtr() const
Definition: tree_model.h:645
float SplitCond(bst_node_t nidx) const
Definition: tree_model.h:678
int MaxDepth()
get maximum depth
Definition: tree_model.h:548
bst_node_t Size() const
Definition: tree_model.h:731
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:431
constexpr XGBOOST_DEVICE pointer data() const __span_noexcept
Definition: span.h:550
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
Definition: span.h:597
constexpr XGBOOST_DEVICE index_type size() const __span_noexcept
Definition: span.h:555
A tensor view with static type and dimension. It implements indexing and slicing.
Definition: linalg.h:294
The input data structure of xgboost.
Feature map data structure to help visualization and model dump.
Linear algebra related utilities.
Defines the abstract interface for different components in XGBoost.
Core data structure for multi-target trees.
Definition: base.h:89
std::vector< std::pair< std::string, std::string > > Args
Definition: base.h:316
std::int32_t bst_node_t
Type for tree node index.
Definition: base.h:111
FeatureType
Definition: data.h:41
std::uint32_t bst_target_t
Type for indexing into output targets.
Definition: base.h:119
std::uint32_t bst_feature_t
Type for data column (feature) index.
Definition: base.h:99
float bst_float
float type, used for storing statistics
Definition: base.h:95
StringView MTNotImplemented()
Definition: tree_model.h:825
node statistics used in regression tree
Definition: tree_model.h:95
RTreeNodeStat(float loss_chg, float sum_hess, float weight)
Definition: tree_model.h:106
bst_float loss_chg
loss change caused by current split
Definition: tree_model.h:97
int leaf_child_cnt
number of children that are leaf nodes, known up to now
Definition: tree_model.h:103
bst_float sum_hess
sum of hessian values, used to measure coverage of data
Definition: tree_model.h:99
bool operator==(const RTreeNodeStat &b) const
Definition: tree_model.h:108
RTreeNodeStat ByteSwap() const
Definition: tree_model.h:114
bst_float base_weight
weight of current node
Definition: tree_model.h:101
Definition: tree_model.h:655
std::size_t size
Definition: tree_model.h:657
std::size_t beg
Definition: tree_model.h:656
CSR-like matrix for categorical splits.
Definition: tree_model.h:654
common::Span< uint32_t const > categories
Definition: tree_model.h:660
common::Span< Segment const > node_ptr
Definition: tree_model.h:661
common::Span< FeatureType const > split_type
Definition: tree_model.h:659
dense feature vector that can be taken by RegTree and can be constructed from a sparse feature vector.
Definition: tree_model.h:554
void HasMissing(bool has_missing)
Definition: tree_model.h:589
void Drop()
drop the trace after fill, must be called after fill.
Definition: tree_model.h:810
bool HasMissing() const
Definition: tree_model.h:822
bool IsMissing(size_t i) const
check whether i-th entry is missing
Definition: tree_model.h:820
size_t Size() const
returns the size of the feature vector
Definition: tree_model.h:812
void Init(size_t size)
initialize the vector with the given size
Definition: tree_model.h:793
common::Span< float > Data()
Definition: tree_model.h:591
void Fill(SparsePage::Inst const &inst)
fill the vector with sparse vector
Definition: tree_model.h:799
bst_float GetFvalue(size_t i) const
get ith value
Definition: tree_model.h:816
Definition: string_view.h:16
meta parameters of the tree
Definition: tree_model.h:34
bst_feature_t num_feature
number of features used for tree construction
Definition: tree_model.h:44
int num_nodes
total number of nodes
Definition: tree_model.h:38
int num_deleted
number of deleted nodes
Definition: tree_model.h:40
bool operator==(const TreeParam &b) const
Definition: tree_model.h:88
int reserved[31]
reserved part, make sure alignment works for 64bit
Definition: tree_model.h:51
TreeParam ByteSwap() const
Definition: tree_model.h:61
TreeParam()
constructor
Definition: tree_model.h:53
DMLC_DECLARE_PARAMETER(TreeParam)
Definition: tree_model.h:74
bst_target_t size_leaf_vector
leaf vector size, used for vector trees, which store more than one-dimensional information in the tree
Definition: tree_model.h:49
int deprecated_num_roots
(Deprecated) number of start root
Definition: tree_model.h:36
int deprecated_max_depth
maximum depth; this is a statistic of the tree
Definition: tree_model.h:42