7 #ifndef XGBOOST_TREE_MODEL_H_
8 #define XGBOOST_TREE_MODEL_H_
15 #include <xgboost/logging.h>
24 #include <type_traits>
30 struct ScalarTreeView;
31 struct MultiTargetTreeView;
93 static_assert(
sizeof(
Node) == 4 *
sizeof(
int) +
sizeof(Info),
"Node: 64 bit align");
95 Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind,
float split_cond,
97 : parent_{parent}, cleft_{cleft}, cright_{cright} {
99 this->
SetSplit(split_ind, split_cond, default_left);
112 static_assert(!std::is_signed_v<bst_feature_t>);
113 return sindex_ & ((1U << 31) - 1U);
152 bool default_left =
false) {
153 if (default_left) split_index |= (1U << 31);
154 this->sindex_ = split_index;
155 (this->info_).split_cond = split_cond;
164 (this->info_).leaf_value = value;
166 this->cright_ = right;
178 if (is_left_child) pidx |= (1U << 31);
179 this->parent_ = pidx;
182 return parent_ == b.parent_ && cleft_ == b.cleft_ &&
183 cright_ == b.cright_ && sindex_ == b.sindex_ &&
184 info_.leaf_value == b.info_.leaf_value;
214 auto& h_nodes = nodes_.HostVector();
215 CHECK(h_nodes[h_nodes[nidx].
LeftChild()].IsLeaf());
216 CHECK(h_nodes[h_nodes[nidx].
RightChild()].IsLeaf());
217 this->DeleteNode(h_nodes[nidx].
LeftChild());
219 h_nodes[nidx].SetLeaf(value);
228 auto& h_nodes = nodes_.HostVector();
229 if (h_nodes[nidx].IsLeaf())
return;
230 if (!h_nodes[h_nodes[nidx].
LeftChild()].IsLeaf()) {
233 if (!h_nodes[h_nodes[nidx].
RightChild()].IsLeaf()) {
240 nodes_.HostVector().resize(param_.
num_nodes);
241 stats_.HostVector().resize(param_.
num_nodes);
243 split_categories_segments_.HostVector().resize(param_.
num_nodes);
244 auto& h_nodes = nodes_.HostVector();
245 for (
int i = 0; i < param_.
num_nodes; i++) {
246 h_nodes[i].SetLeaf(0.0f);
268 return device.
IsCPU() ? nodes_.ConstHostSpan()
269 : (nodes_.SetDevice(device), nodes_.ConstDeviceSpan());
275 return device.
IsCPU() ? stats_.ConstHostSpan()
276 : (stats_.SetDevice(device), stats_.ConstDeviceSpan());
281 return stats_.HostVector()[nid];
288 return nodes_.ConstHostVector() == b.nodes_.ConstHostVector() &&
289 stats_.ConstHostVector() == b.stats_.ConstHostVector() &&
290 deleted_nodes_ == b.deleted_nodes_ && param_ == b.param_;
318 bool default_left,
bst_float base_weight,
320 bst_float loss_change,
float sum_hess,
float left_sum,
335 float left_sum,
float right_sum);
367 float left_sum,
float right_sum);
375 [[nodiscard]]
bool IsMultiTarget()
const {
return static_cast<bool>(p_mt_tree_); }
385 return p_mt_tree_.get();
423 return this->p_mt_tree_->SetRoot(weight, sum_hess);
/**
 * @brief Initialize the feature vector so that it holds `size` entries.
 *
 * NOTE(review): the out-of-line definition (only partially visible here)
 * resets the stored values to quiet NaN, i.e. every entry starts as missing.
 */
void Init(std::size_t size);
455 [[nodiscard]]
size_t Size()
const;
467 [[nodiscard]]
bool IsMissing(
size_t i)
const;
469 void HasMissing(
bool has_missing) { this->has_missing_ = has_missing; }
479 std::vector<float> data_;
491 std::string format)
const;
500 return device.
IsCPU()
505 return split_categories_segments_.ConstHostVector();
529 if (device.
IsCPU()) {
530 view.
node_ptr = split_categories_segments_.ConstHostSpan();
532 split_categories_segments_.SetDevice(device);
533 view.
node_ptr = split_categories_segments_.ConstDeviceSpan();
540 return this->p_mt_tree_->LeftChild(nidx);
542 return nodes_.ConstHostVector()[nidx].LeftChild();
546 return this->p_mt_tree_->RightChild(nidx);
548 return nodes_.ConstHostVector()[nidx].RightChild();
552 return this->p_mt_tree_->Size();
554 return this->nodes_.Size();
562 template <
bool typed>
563 void LoadCategoricalSplit(
Json const& in);
564 void SaveCategoricalSplit(
Json* p_out)
const;
570 std::vector<int> deleted_nodes_;
580 std::unique_ptr<MultiTargetTree> p_mt_tree_;
585 int nid = deleted_nodes_.back();
586 deleted_nodes_.pop_back();
592 CHECK_LT(param_.
num_nodes, std::numeric_limits<int>::max())
593 <<
"number of nodes in the tree exceed 2^31";
601 void DeleteNode(
int nid) {
603 auto pid = (*this)[nid].Parent();
610 deleted_nodes_.push_back(nid);
618 std::fill(data_.begin(), data_.end(), std::numeric_limits<float>::quiet_NaN());
623 auto p_data = inst.
data();
624 auto p_out = data_.data();
626 for (std::size_t i = 0, n = inst.
size(); i < n; ++i) {
627 auto const& entry = p_data[i];
628 p_out[entry.index] = entry.fvalue;
630 has_missing_ = data_.size() != inst.
size();
649 return " support for multi-target tree is not yet implemented.";
Defines configuration macros and basic types for xgboost.
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:57
Feature map data structure to help text model dump. TODO(tqchen) consider making it even more lightweight.
Definition: feature_map.h:22
Definition: host_device_vector.h:87
bool Empty() const
Definition: host_device_vector.h:102
common::Span< T const > ConstHostSpan() const
Definition: host_device_vector.h:116
std::vector< T > & HostVector()
common::Span< const T > ConstDeviceSpan() const
void SetDevice(DeviceOrd device) const
Data structure representing JSON format.
Definition: json.h:396
Tree structure for multi-target model.
Definition: multi_target_tree_model.h:38
static constexpr bst_node_t InvalidNodeId()
Definition: multi_target_tree_model.h:40
tree node
Definition: tree_model.h:89
XGBOOST_DEVICE int Parent() const
get parent of the node
Definition: tree_model.h:124
XGBOOST_DEVICE void MarkDelete()
mark that this node is deleted
Definition: tree_model.h:169
XGBOOST_DEVICE bool IsRoot() const
whether current node is root
Definition: tree_model.h:130
XGBOOST_DEVICE int RightChild() const
index of right child
Definition: tree_model.h:105
XGBOOST_DEVICE float LeafValue() const
Definition: tree_model.h:120
XGBOOST_DEVICE Node()
Definition: tree_model.h:91
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child=true)
Definition: tree_model.h:177
XGBOOST_DEVICE void SetLeaf(bst_float value, int right=kInvalidNodeId)
set the leaf value of the node
Definition: tree_model.h:163
XGBOOST_DEVICE bool IsLeftChild() const
whether current node is left child
Definition: tree_model.h:126
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond, bool default_left=false)
set split condition of current node
Definition: tree_model.h:151
XGBOOST_DEVICE void SetLeftChild(int nid)
set the left child
Definition: tree_model.h:135
XGBOOST_DEVICE bst_feature_t SplitIndex() const
feature index of split condition
Definition: tree_model.h:111
XGBOOST_DEVICE bool IsDeleted() const
whether this node is deleted
Definition: tree_model.h:128
XGBOOST_DEVICE bool IsLeaf() const
whether current node is leaf node
Definition: tree_model.h:118
bool operator==(const Node &b) const
Definition: tree_model.h:181
Node(int32_t cleft, int32_t cright, int32_t parent, uint32_t split_ind, float split_cond, bool default_left)
Definition: tree_model.h:95
XGBOOST_DEVICE void Reuse()
Reuse this deleted node.
Definition: tree_model.h:173
XGBOOST_DEVICE void SetRightChild(int nid)
set the right child
Definition: tree_model.h:142
XGBOOST_DEVICE bool DefaultLeft() const
when the feature is unknown, whether to go to the left child
Definition: tree_model.h:116
XGBOOST_DEVICE int LeftChild() const
index of left child
Definition: tree_model.h:103
XGBOOST_DEVICE int DefaultChild() const
index of default child when feature is missing
Definition: tree_model.h:107
XGBOOST_DEVICE SplitCondT SplitCond() const
Definition: tree_model.h:122
define regression tree to be the most common tree model.
Definition: tree_model.h:81
void SaveModel(Json *out) const override
saves the model config to a JSON object
tree::MultiTargetTreeView HostMtView() const
void ChangeToLeaf(bst_node_t nidx, float value)
Change a non-leaf node to a leaf node, deleting its children.
Definition: tree_model.h:213
bst_target_t NumTargets() const
The size of leaf weight.
Definition: tree_model.h:379
bool operator==(const RegTree &b) const
Definition: tree_model.h:287
bst_node_t NumNodes() const noexcept
Get the total number of nodes including deleted ones in this tree.
Definition: tree_model.h:394
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
Expands a leaf node into two additional leaf nodes.
RegTree()
Definition: tree_model.h:239
static constexpr bst_node_t kInvalidNodeId
Definition: tree_model.h:84
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight, float loss_chg, float sum_hess, float left_sum, float right_sum)
Expands a leaf node into two additional leaf nodes for a multi-target tree.
Node & operator[](bst_node_t nidx)
get node given nid
Definition: tree_model.h:262
static constexpr uint32_t kDeletedNodeMarker
Definition: tree_model.h:85
bool IsMultiTarget() const
Whether this is a multi-target tree.
Definition: tree_model.h:375
bst_node_t NumExtraNodes() const noexcept
number of extra nodes besides the root
Definition: tree_model.h:404
bst_node_t MaxDepth() const
Get the maximum depth.
auto GetMultiTargetTree() const
Get the underlying implementation of the multi-target tree.
Definition: tree_model.h:383
bst_node_t LeftChild(bst_node_t nidx) const
Definition: tree_model.h:538
common::Span< RTreeNodeStat const > GetStats(DeviceOrd device) const
Get const reference to stats.
Definition: tree_model.h:273
void SetLeaves(std::vector< bst_node_t > leaves, common::Span< float const > weights)
Set all leaf weights for a multi-target tree.
void CollapseToLeaf(bst_node_t nidx, float value)
Collapse a non-leaf node to a leaf node, deleting its children.
Definition: tree_model.h:227
bst_node_t GetNumLeaves() const
common::Span< Node const > GetNodes(DeviceOrd device) const
Get const reference to nodes.
Definition: tree_model.h:266
RegTree(bst_target_t n_targets, bst_feature_t n_features)
Constructor that initializes the tree model with shape.
Definition: tree_model.h:253
bst_node_t RightChild(bst_node_t nidx) const
Definition: tree_model.h:544
common::Span< FeatureType const > GetSplitTypes(DeviceOrd device) const
Get split types for all nodes.
Definition: tree_model.h:495
RTreeNodeStat & Stat(int nid)
get node statistics given nid
Definition: tree_model.h:280
void SetRoot(linalg::VectorView< float const > weight, float sum_hess)
Set the root weight and statistics for a multi-target tree.
Definition: tree_model.h:421
void ExpandCategorical(bst_node_t nid, bst_feature_t split_index, common::Span< const uint32_t > split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum)
Expands a leaf node with categories.
bool Equal(const RegTree &b) const
Compares whether 2 trees are equal from a user's perspective. The equality compares only non-deleted ...
bst_node_t NumValidNodes() const noexcept
Get the total number of valid nodes in this tree.
Definition: tree_model.h:398
float SplitCondT
Definition: tree_model.h:83
CategoricalSplitMatrix GetCategoriesMatrix(DeviceOrd device) const
Definition: tree_model.h:525
common::Span< uint32_t const > GetSplitCategories(DeviceOrd device) const
Definition: tree_model.h:499
void LoadModel(Json const &in) override
load the model from a JSON object
std::string DumpModel(const FeatureMap &fmap, bool with_stats, std::string format) const
dump the model in the requested format as a text string
bst_feature_t NumFeatures() const noexcept
Get the number of features.
Definition: tree_model.h:390
tree::ScalarTreeView HostScView() const
bool HasCategoricalSplit() const
Whether this tree has categorical split.
Definition: tree_model.h:371
static constexpr bst_node_t kRoot
Definition: tree_model.h:86
bst_node_t GetNumSplitNodes() const
auto const & GetSplitCategoriesPtr() const
Definition: tree_model.h:504
bst_node_t GetDepth(bst_node_t nidx) const
Get the depth of a node.
bst_node_t Size() const
Definition: tree_model.h:550
Span class implementation, based on ISO C++20 std::span<T>. The interface should be the same.
Definition: span.h:435
constexpr XGBOOST_DEVICE pointer data() const __span_noexcept
Definition: span.h:554
constexpr XGBOOST_DEVICE index_type size() const __span_noexcept
Definition: span.h:559
A tensor view with static type and dimension. It implements indexing and slicing.
Definition: linalg.h:278
The input data structure of xgboost.
Feature map data structure to help visualization and model dump.
A device-and-host vector abstraction layer.
Linear algebra related utilities.
Defines the abstract interface for different components in XGBoost.
Learner interface that integrates objective, gbm and evaluation together. This is the user facing XGB...
Definition: base.h:89
std::int32_t bst_node_t
Type for tree node index and tree depth.
Definition: base.h:111
std::uint32_t bst_target_t
Type for indexing into output targets.
Definition: base.h:119
std::uint32_t bst_feature_t
Type for data column (feature) index.
Definition: base.h:99
float bst_float
float type, used for storing statistics
Definition: base.h:95
StringView MTNotImplemented()
Definition: tree_model.h:648
A type for device ordinal. The type is packed into 32-bit for efficient use in viewing types like lin...
Definition: context.h:34
bool IsCPU() const
Definition: context.h:45
node statistics used in regression tree
Definition: tree_model.h:57
RTreeNodeStat(float loss_chg, float sum_hess, float weight)
Definition: tree_model.h:68
float sum_hess
sum of hessian values, used to measure coverage of data
Definition: tree_model.h:61
int leaf_child_cnt
number of children that are leaf nodes, as known so far
Definition: tree_model.h:65
float loss_chg
loss change caused by current split
Definition: tree_model.h:59
float base_weight
weight of current node
Definition: tree_model.h:63
bool operator==(const RTreeNodeStat &b) const
Definition: tree_model.h:70
Definition: tree_model.h:516
std::size_t size
Definition: tree_model.h:518
std::size_t beg
Definition: tree_model.h:517
CSR-like matrix for categorical splits.
Definition: tree_model.h:515
common::Span< uint32_t const > categories
Definition: tree_model.h:521
common::Span< Segment const > node_ptr
Definition: tree_model.h:522
common::Span< FeatureType const > split_type
Definition: tree_model.h:520
Dense feature vector that can be consumed by RegTree and can be constructed from a sparse feature vector.
Definition: tree_model.h:434
void HasMissing(bool has_missing)
Definition: tree_model.h:469
void Drop()
Drop the trace after Fill; must be called after Fill.
Definition: tree_model.h:633
bool HasMissing() const
Definition: tree_model.h:645
bool IsMissing(size_t i) const
check whether i-th entry is missing
Definition: tree_model.h:643
size_t Size() const
returns the size of the feature vector
Definition: tree_model.h:635
void Init(size_t size)
initialize the vector with the given size
Definition: tree_model.h:616
common::Span< float > Data()
Definition: tree_model.h:471
void Fill(SparsePage::Inst const &inst)
fill the vector with sparse vector
Definition: tree_model.h:622
bst_float GetFvalue(size_t i) const
get the i-th value
Definition: tree_model.h:639
Definition: string_view.h:16
meta parameters of the tree
Definition: tree_model.h:37
bst_node_t num_deleted
The number of deleted nodes.
Definition: tree_model.h:41
bst_feature_t num_feature
The number of features used for tree construction.
Definition: tree_model.h:43
bool operator==(const TreeParam &b) const
Definition: tree_model.h:47
bst_node_t num_nodes
The number of nodes.
Definition: tree_model.h:39
void ToJson(Json *p_out) const
void FromJson(Json const &in)
bst_target_t size_leaf_vector
leaf vector size. Used by the vector leaf.
Definition: tree_model.h:45