7 #ifndef XGBOOST_TREE_MODEL_H_
8 #define XGBOOST_TREE_MODEL_H_
11 #include <dmlc/parameter.h>
16 #include <xgboost/logging.h>
35 struct TreeParam :
public dmlc::Parameter<TreeParam> {
56 static_assert(
sizeof(
TreeParam) == (31 + 6) *
sizeof(
int),
"TreeParam: 64 bit align");
78 DMLC_DECLARE_FIELD(
num_nodes).set_lower_bound(1).set_default(1);
81 .describe(
"Number of features used in tree construction.");
86 .describe(
"Size of leaf vector, reserved for vector tree");
128 template <
typename T>
130 std::unique_ptr<T> ptr_{
nullptr};
137 ptr_ = std::make_unique<T>(*that);
140 T*
get() const noexcept {
return ptr_.get(); }
148 explicit operator bool()
const {
return static_cast<bool>(ptr_); }
150 void reset(T* ptr) { ptr_.reset(ptr); }
170 static_assert(
sizeof(
Node) == 4 *
sizeof(
int) +
sizeof(Info),
171 "Node: 64 bit align");
173 Node(int32_t cleft, int32_t cright, int32_t parent,
174 uint32_t split_ind,
float split_cond,
bool default_left) :
175 parent_{parent}, cleft_{cleft}, cright_{cright} {
177 this->
SetSplit(split_ind, split_cond, default_left);
190 return sindex_ & ((1U << 31) - 1U);
229 bool default_left =
false) {
230 if (default_left) split_index |= (1U << 31);
231 this->sindex_ = split_index;
232 (this->info_).split_cond = split_cond;
241 (this->info_).leaf_value = value;
243 this->cright_ = right;
255 if (is_left_child) pidx |= (1U << 31);
256 this->parent_ = pidx;
259 return parent_ == b.parent_ && cleft_ == b.cleft_ &&
260 cright_ == b.cright_ && sindex_ == b.sindex_ &&
261 info_.leaf_value == b.info_.leaf_value;
266 dmlc::ByteSwap(&x.parent_,
sizeof(x.parent_), 1);
267 dmlc::ByteSwap(&x.cleft_,
sizeof(x.cleft_), 1);
268 dmlc::ByteSwap(&x.cright_,
sizeof(x.cright_), 1);
269 dmlc::ByteSwap(&x.sindex_,
sizeof(x.sindex_), 1);
270 dmlc::ByteSwap(&x.info_,
sizeof(x.info_), 1);
302 this->DeleteNode(nodes_[rid].
LeftChild());
304 nodes_[rid].SetLeaf(value);
312 if (nodes_[rid].
IsLeaf())
return;
327 split_categories_segments_.resize(param_.
num_nodes);
328 for (
int i = 0; i < param_.
num_nodes; i++) {
329 nodes_[i].SetLeaf(0.0f);
354 [[nodiscard]]
const std::vector<Node>&
GetNodes()
const {
return nodes_; }
357 [[nodiscard]]
const std::vector<RTreeNodeStat>&
GetStats()
const {
return stats_; }
377 void Save(dmlc::Stream* fo)
const;
383 return nodes_ == b.nodes_ && stats_ == b.stats_ &&
384 deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
391 template <
typename Func>
void WalkTree(Func func)
const {
392 std::stack<bst_node_t> nodes;
395 while (!nodes.empty()) {
396 auto nidx = nodes.top();
401 auto left =
self.LeftChild(nidx);
402 auto right =
self.RightChild(nidx);
437 bool default_left,
bst_float base_weight,
439 bst_float loss_change,
float sum_hess,
float left_sum,
469 float left_sum,
float right_sum);
477 [[nodiscard]]
bool IsMultiTarget()
const {
return static_cast<bool>(p_mt_tree_); }
487 return p_mt_tree_.get();
519 return this->p_mt_tree_->Depth(nid);
522 while (!nodes_[nid].
IsRoot()) {
524 nid = nodes_[nid].Parent();
533 return this->p_mt_tree_->SetLeaf(nidx, weight);
541 if (nodes_[nid].
IsLeaf())
return 0;
559 void Init(
size_t size);
575 [[nodiscard]]
size_t Size()
const;
587 [[nodiscard]]
bool IsMissing(
size_t i)
const;
600 std::vector<Entry> data_;
610 std::vector<float>* mean_values,
620 std::string format)
const;
634 return split_categories_;
642 auto segment = node_ptr[nidx];
643 auto node_cats = categories.
subspan(segment.beg, segment.size);
675 return this->p_mt_tree_->SplitIndex(nidx);
677 return (*
this)[nidx].SplitIndex();
681 return this->p_mt_tree_->SplitCond(nidx);
683 return (*
this)[nidx].SplitCond();
687 return this->p_mt_tree_->DefaultLeft(nidx);
689 return (*
this)[nidx].DefaultLeft();
696 return nidx ==
kRoot;
698 return (*
this)[nidx].IsRoot();
702 return this->p_mt_tree_->IsLeaf(nidx);
704 return (*
this)[nidx].IsLeaf();
708 return this->p_mt_tree_->Parent(nidx);
710 return (*
this)[nidx].Parent();
714 return this->p_mt_tree_->LeftChild(nidx);
716 return (*
this)[nidx].LeftChild();
720 return this->p_mt_tree_->RightChild(nidx);
722 return (*
this)[nidx].RightChild();
726 CHECK_NE(nidx,
kRoot);
727 auto p = this->p_mt_tree_->Parent(nidx);
728 return nidx == this->p_mt_tree_->LeftChild(p);
730 return (*
this)[nidx].IsLeftChild();
734 return this->p_mt_tree_->Size();
736 return this->nodes_.size();
740 template <
bool typed>
741 void LoadCategoricalSplit(
Json const& in);
742 void SaveCategoricalSplit(
Json* p_out)
const;
746 std::vector<Node> nodes_;
748 std::vector<int> deleted_nodes_;
750 std::vector<RTreeNodeStat> stats_;
751 std::vector<FeatureType> split_types_;
754 std::vector<uint32_t> split_categories_;
756 std::vector<CategoricalSplitMatrix::Segment> split_categories_segments_;
763 int nid = deleted_nodes_.back();
764 deleted_nodes_.pop_back();
770 CHECK_LT(param_.
num_nodes, std::numeric_limits<int>::max())
771 <<
"number of nodes in the tree exceed 2^31";
775 split_categories_segments_.resize(param_.
num_nodes);
779 void DeleteNode(
int nid) {
781 auto pid = (*this)[nid].Parent();
788 deleted_nodes_.push_back(nid);
789 nodes_[nid].MarkDelete();
795 Entry e; e.flag = -1;
797 std::fill(data_.begin(), data_.end(), e);
802 size_t feature_count = 0;
803 for (
auto const& entry : inst) {
804 if (entry.index >= data_.size()) {
807 data_[entry.index].fvalue = entry.fvalue;
810 has_missing_ = data_.size() != feature_count;
816 std::fill_n(data_.data(), data_.size(), e);
825 return data_[i].fvalue;
829 return data_[i].flag == -1;
838 return " support for multi-target tree is not yet implemented.";
Defines configuration macros and basic types for xgboost.
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:62
Helper for defining copyable data structure that contains unique pointers.
Definition: tree_model.h:129
T const * operator->() const noexcept
Definition: tree_model.h:146
T * get() const noexcept
Definition: tree_model.h:140
bool operator!() const
Definition: tree_model.h:149
CopyUniquePtr(CopyUniquePtr const &that)
Definition: tree_model.h:134
T * operator->() noexcept
Definition: tree_model.h:143
T & operator*()
Definition: tree_model.h:142
T const & operator*() const
Definition: tree_model.h:145
void reset(T *ptr)
Definition: tree_model.h:150
Feature map data structure to help text model dump. TODO(tqchen): consider making it even more lightweight...
Definition: feature_map.h:22
Data structure representing JSON format.
Definition: json.h:368
Tree structure for multi-target model.
Definition: multi_target_tree_model.h:23
static constexpr bst_node_t InvalidNodeId()
Definition: multi_target_tree_model.h:25
tree node
Definition: tree_model.h:166
XGBOOST_DEVICE int Parent() const
get parent of the node
Definition: tree_model.h:201
XGBOOST_DEVICE void MarkDelete()
mark that this node is deleted
Definition: tree_model.h:246
XGBOOST_DEVICE bool IsRoot() const
whether current node is root
Definition: tree_model.h:207
XGBOOST_DEVICE int RightChild() const
index of right child
Definition: tree_model.h:183
XGBOOST_DEVICE float LeafValue() const
Definition: tree_model.h:197
XGBOOST_DEVICE Node()
Definition: tree_model.h:168
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child=true)
Definition: tree_model.h:254
XGBOOST_DEVICE unsigned SplitIndex() const
feature index of split condition
Definition: tree_model.h:189
Node ByteSwap() const
Definition: tree_model.h:264
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=kInvalidNodeId)
set the leaf value of the node
Definition: tree_model.h:240
XGBOOST_DEVICE bool IsLeftChild() const
whether current node is left child
Definition: tree_model.h:203
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
set split condition of current node
Definition: tree_model.h:228
XGBOOST_DEVICE void SetLeftChild(int nid)
set the left child
Definition: tree_model.h:212
XGBOOST_DEVICE bool IsDeleted() const
whether this node is deleted
Definition: tree_model.h:205
XGBOOST_DEVICE bool IsLeaf() const
whether current node is leaf node
Definition: tree_model.h:195
bool operator==(const Node &b) const
Definition: tree_model.h:258
Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond, bool default_left)
Definition: tree_model.h:173
XGBOOST_DEVICE void Reuse()
Reuse this deleted node.
Definition: tree_model.h:250
XGBOOST_DEVICE void SetRightChild(int nid)
set the right child
Definition: tree_model.h:219
XGBOOST_DEVICE bool DefaultLeft() const
when the feature is unknown (missing), whether to go to the left child
Definition: tree_model.h:193
XGBOOST_DEVICE int LeftChild() const
index of left child
Definition: tree_model.h:181
XGBOOST_DEVICE int DefaultChild() const
index of default child when feature is missing
Definition: tree_model.h:185
XGBOOST_DEVICE SplitCondT SplitCond() const
Definition: tree_model.h:199
Defines the regression tree, the most common tree model.
Definition: tree_model.h:158
int MaxDepth(int nid) const
get maximum depth
Definition: tree_model.h:540
void SaveModel(Json *out) const override
saves the model config to a JSON object
bst_target_t NumTargets() const
The size of leaf weight.
Definition: tree_model.h:481
void WalkTree(Func func) const
Definition: tree_model.h:391
void Save(dmlc::Stream *fo) const
save model to stream
bool IsLeaf(bst_node_t nidx) const
Definition: tree_model.h:700
bool operator==(const RegTree &b) const
Definition: tree_model.h:382
const RTreeNodeStat & Stat(int nid) const
get node statistics given nid
Definition: tree_model.h:364
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight)
Expands a leaf node into two additional leaf nodes for a multi-target tree.
bst_node_t Parent(bst_node_t nidx) const
Definition: tree_model.h:706
bst_node_t NumNodes() const noexcept
Get the total number of nodes including deleted ones in this tree.
Definition: tree_model.h:496
const Node & operator[](int nid) const
get node given nid
Definition: tree_model.h:349
bst_node_t DefaultChild(bst_node_t nidx) const
Definition: tree_model.h:691
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
Expands a leaf node into two additional leaf nodes.
Node & operator[](int nid)
get node given nid
Definition: tree_model.h:345
RegTree()
Definition: tree_model.h:322
static constexpr bst_node_t kInvalidNodeId
Definition: tree_model.h:161
bst_feature_t SplitIndex(bst_node_t nidx) const
Definition: tree_model.h:673
bool IsRoot(bst_node_t nidx) const
Definition: tree_model.h:694
static constexpr uint32_t kDeletedNodeMarker
Definition: tree_model.h:162
bool IsMultiTarget() const
Whether this is a multi-target tree.
Definition: tree_model.h:477
bst_node_t NumExtraNodes() const noexcept
number of extra nodes besides the root
Definition: tree_model.h:506
bool DefaultLeft(bst_node_t nidx) const
Definition: tree_model.h:685
auto GetMultiTargetTree() const
Get the underlying implementation of the multi-target tree.
Definition: tree_model.h:485
void Load(dmlc::Stream *fi)
load model from stream
bst_node_t LeftChild(bst_node_t nidx) const
Definition: tree_model.h:712
bst_node_t GetNumLeaves() const
RegTree(bst_target_t n_targets, bst_feature_t n_features)
Constructor that initializes the tree model with shape.
Definition: tree_model.h:336
bst_node_t RightChild(bst_node_t nidx) const
Definition: tree_model.h:718
common::Span< uint32_t const > NodeCats(bst_node_t nidx) const
Get the bit storage for categories.
Definition: tree_model.h:639
bool IsLeftChild(bst_node_t nidx) const
Definition: tree_model.h:724
CategoricalSplitMatrix GetCategoriesMatrix() const
Definition: tree_model.h:665
RTreeNodeStat & Stat(int nid)
get node statistics given nid
Definition: tree_model.h:360
bst_float SplitCondT
Definition: tree_model.h:160
void ExpandCategorical(bst_node_t nid, bst_feature_t split_index, common::Span< const uint32_t > split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum)
Expands a leaf node with categories.
bool Equal(const RegTree &b) const
Compares whether 2 trees are equal from a user's perspective. The equality compares only non-deleted ...
std::vector< FeatureType > const & GetSplitTypes() const
Get split types for all nodes.
Definition: tree_model.h:630
void CollapseToLeaf(int rid, bst_float value)
collapse a non-leaf node to a leaf node, deleting its children
Definition: tree_model.h:311
bst_node_t NumValidNodes() const noexcept
Get the total number of valid nodes in this tree.
Definition: tree_model.h:500
void ChangeToLeaf(int rid, bst_float value)
change a non-leaf node to a leaf node, deleting its children
Definition: tree_model.h:299
const std::vector< RTreeNodeStat > & GetStats() const
get const reference to stats
Definition: tree_model.h:357
void SetLeaf(bst_node_t nidx, linalg::VectorView< float const > weight)
Set the leaf weight for a multi-target tree.
Definition: tree_model.h:531
const std::vector< Node > & GetNodes() const
get const reference to nodes
Definition: tree_model.h:354
void CalculateContributionsApprox(const RegTree::FVec &feat, std::vector< float > *mean_values, bst_float *out_contribs) const
calculate the approximate feature contributions for the given root
void LoadModel(Json const &in) override
load the model from a JSON object
std::string DumpModel(const FeatureMap &fmap, bool with_stats, std::string format) const
dump the model in the requested format as a text string
FeatureType NodeSplitType(bst_node_t nidx) const
Get split type for a node.
Definition: tree_model.h:626
bst_feature_t NumFeatures() const noexcept
Get the number of features.
Definition: tree_model.h:492
common::Span< uint32_t const > GetSplitCategories() const
Definition: tree_model.h:633
bool HasCategoricalSplit() const
Whether this tree has categorical split.
Definition: tree_model.h:473
std::int32_t GetDepth(bst_node_t nid) const
get current depth
Definition: tree_model.h:517
static constexpr bst_node_t kRoot
Definition: tree_model.h:163
bst_node_t GetNumSplitNodes() const
auto const & GetSplitCategoriesPtr() const
Definition: tree_model.h:646
float SplitCond(bst_node_t nidx) const
Definition: tree_model.h:679
int MaxDepth()
get maximum depth
Definition: tree_model.h:548
bst_node_t Size() const
Definition: tree_model.h:732
span class implementation, based on ISO C++20 std::span<T>. The interface should be the same.
Definition: span.h:422
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
Definition: span.h:594
A tensor view with static type and dimension. It implements indexing and slicing.
Definition: linalg.h:294
The input data structure of xgboost.
Feature map data structure to help visualization and model dump.
Linear algebra related utilities.
Defines the abstract interface for different components in XGBoost.
Core data structure for multi-target trees.
Definition: base.h:87
std::vector< std::pair< std::string, std::string > > Args
Definition: base.h:310
std::int32_t bst_node_t
Type for tree node index.
Definition: base.h:107
FeatureType
Definition: data.h:40
std::uint32_t bst_target_t
Type for indexing into output targets.
Definition: base.h:113
std::uint32_t bst_feature_t
Type for data column (feature) index.
Definition: base.h:97
float bst_float
float type, used for storing statistics
Definition: base.h:93
StringView MTNotImplemented()
Definition: tree_model.h:837
node statistics used in regression tree
Definition: tree_model.h:96
RTreeNodeStat(float loss_chg, float sum_hess, float weight)
Definition: tree_model.h:107
bst_float loss_chg
loss change caused by current split
Definition: tree_model.h:98
int leaf_child_cnt
number of children that are leaf nodes known so far
Definition: tree_model.h:104
bst_float sum_hess
sum of hessian values, used to measure coverage of data
Definition: tree_model.h:100
bool operator==(const RTreeNodeStat &b) const
Definition: tree_model.h:109
RTreeNodeStat ByteSwap() const
Definition: tree_model.h:115
bst_float base_weight
weight of current node
Definition: tree_model.h:102
Definition: tree_model.h:656
std::size_t size
Definition: tree_model.h:658
std::size_t beg
Definition: tree_model.h:657
CSR-like matrix for categorical splits.
Definition: tree_model.h:655
common::Span< uint32_t const > categories
Definition: tree_model.h:661
common::Span< Segment const > node_ptr
Definition: tree_model.h:662
common::Span< FeatureType const > split_type
Definition: tree_model.h:660
dense feature vector that can be consumed by RegTree and can be constructed from a sparse feature vector.
Definition: tree_model.h:554
void Drop()
drop the trace after Fill; must be called after Fill.
Definition: tree_model.h:813
bool HasMissing() const
Definition: tree_model.h:832
void Fill(const SparsePage::Inst &inst)
fill the vector with sparse vector
Definition: tree_model.h:801
bool IsMissing(size_t i) const
check whether i-th entry is missing
Definition: tree_model.h:828
size_t Size() const
returns the size of the feature vector
Definition: tree_model.h:820
bst_float GetFvalue(size_t i) const
get ith value
Definition: tree_model.h:824
void Init(size_t size)
initialize the vector to the given size
Definition: tree_model.h:794
Definition: string_view.h:16
meta parameters of the tree
Definition: tree_model.h:35
bst_feature_t num_feature
number of features used for tree construction
Definition: tree_model.h:45
int num_nodes
total number of nodes
Definition: tree_model.h:39
int num_deleted
number of deleted nodes
Definition: tree_model.h:41
bool operator==(const TreeParam &b) const
Definition: tree_model.h:89
int reserved[31]
reserved part, to make sure alignment works on 64-bit platforms
Definition: tree_model.h:52
TreeParam ByteSwap() const
Definition: tree_model.h:62
TreeParam()
constructor
Definition: tree_model.h:54
DMLC_DECLARE_PARAMETER(TreeParam)
Definition: tree_model.h:75
bst_target_t size_leaf_vector
leaf vector size, used by vector trees to store multi-dimensional information in the tree
Definition: tree_model.h:50
int deprecated_num_roots
(Deprecated) number of start root
Definition: tree_model.h:37
int deprecated_max_depth
maximum depth; this is a statistic of the tree
Definition: tree_model.h:43