xgboost
xgboost::RegTree Class Reference

Defines the regression tree, the most common tree model.

#include <tree_model.h>


Classes

struct  CategoricalSplitMatrix
 CSR-like matrix for categorical splits.
 
struct  FVec
 Dense feature vector that can be consumed by RegTree and constructed from a sparse feature vector.
 
class  Node
 Tree node.
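To illustrate what a dense feature vector built from sparse input involves, here is a minimal C++ sketch. The class `DenseFeatureVec`, its method names, and the NaN encoding of missing values are assumptions for illustration, not the actual layout of `FVec`.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

// Hypothetical dense feature vector filled from sparse (index, value) pairs.
// Missing features are stored as NaN here; that encoding is an assumption for
// illustration and not necessarily how xgboost's FVec stores them.
class DenseFeatureVec {
 public:
  using SparseRow = std::vector<std::pair<std::size_t, float>>;

  explicit DenseFeatureVec(std::size_t n_features)
      : values_(n_features, std::numeric_limits<float>::quiet_NaN()) {}

  // Scatter the entries of one sparse row into the dense buffer.
  void Fill(const SparseRow& row) {
    for (const auto& entry : row) {
      assert(entry.first < values_.size());
      values_[entry.first] = entry.second;
    }
  }

  // Reset only the entries touched by Fill so the buffer can be reused.
  void Drop(const SparseRow& row) {
    for (const auto& entry : row) {
      assert(entry.first < values_.size());
      values_[entry.first] = std::numeric_limits<float>::quiet_NaN();
    }
  }

  bool IsMissing(std::size_t i) const {
    assert(i < values_.size());
    return std::isnan(values_[i]);
  }

  float GetFvalue(std::size_t i) const {
    assert(i < values_.size());
    return values_[i];
  }

  std::size_t Size() const { return values_.size(); }

 private:
  std::vector<float> values_;
};
```

In this sketch, calling `Drop` with the same row that was passed to `Fill` lets one buffer be reused across many rows without reallocating a full dense vector each time.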
 

Public Types

using SplitCondT = float
 

Public Member Functions

void ChangeToLeaf (bst_node_t nidx, float value)
 Change a non-leaf node to a leaf node and delete its children.
 
void CollapseToLeaf (bst_node_t nidx, float value)
 Collapse a non-leaf node to a leaf node and delete its children.
 
 RegTree ()
 
 RegTree (bst_target_t n_targets, bst_feature_t n_features)
 Constructor that initializes the tree model with the given shape (number of targets and features).
 
Node & operator[] (bst_node_t nidx)
 Get the node with index nidx.
 
common::Span< Node const > GetNodes (DeviceOrd device) const
 Get a const reference to the nodes.
 
common::Span< RTreeNodeStat const > GetStats (DeviceOrd device) const
 Get a const reference to the node statistics.
 
RTreeNodeStat & Stat (int nid)
 Get the statistics of the node with index nid.
 
void LoadModel (Json const &in) override
 Load the model from a JSON object.
 
void SaveModel (Json *out) const override
 Save the model to a JSON object.
 
bool operator== (const RegTree &b) const
 
bool Equal (const RegTree &b) const
 Checks whether two trees are equal from a user's perspective. Only non-deleted nodes are compared.
 
void ExpandNode (bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
 Expands a leaf node into two additional leaf nodes.
 
void ExpandNode (bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight, float loss_chg, float sum_hess, float left_sum, float right_sum)
 Expands a leaf node into two additional leaf nodes for a multi-target tree.
 
void SetLeaves (std::vector< bst_node_t > leaves, common::Span< float const > weights)
 Set all leaf weights for a multi-target tree.
 
void ExpandCategorical (bst_node_t nid, bst_feature_t split_index, common::Span< const uint32_t > split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum)
 Expands a leaf node with categories.
 
bool HasCategoricalSplit () const
 Whether this tree has any categorical split.
 
bool IsMultiTarget () const
 Whether this is a multi-target tree.
 
bst_target_t NumTargets () const
 The size of the leaf weight.
 
auto GetMultiTargetTree () const
 Get the underlying implementation of the multi-target tree.
 
bst_feature_t NumFeatures () const noexcept
 Get the number of features.
 
bst_node_t NumNodes () const noexcept
 Get the total number of nodes in this tree, including deleted ones.
 
bst_node_t NumValidNodes () const noexcept
 Get the total number of valid nodes in this tree.
 
bst_node_t NumExtraNodes () const noexcept
 Number of extra nodes besides the root.
 
bst_node_t GetNumLeaves () const
 
bst_node_t GetNumSplitNodes () const
 
bst_node_t GetDepth (bst_node_t nidx) const
 Get the depth of a node.
 
void SetRoot (linalg::VectorView< float const > weight, float sum_hess)
 Set the root weight and statistics for a multi-target tree.
 
bst_node_t MaxDepth () const
 Get the maximum depth.
 
std::string DumpModel (const FeatureMap &fmap, bool with_stats, std::string format) const
 Dump the model in the requested format as a text string.
 
common::Span< FeatureType const > GetSplitTypes (DeviceOrd device) const
 Get the split types for all nodes.
 
common::Span< uint32_t const > GetSplitCategories (DeviceOrd device) const
 
auto const & GetSplitCategoriesPtr () const
 
CategoricalSplitMatrix GetCategoriesMatrix (DeviceOrd device) const
 
bst_node_t LeftChild (bst_node_t nidx) const
 
bst_node_t RightChild (bst_node_t nidx) const
 
bst_node_t Size () const
 
RegTree * Copy () const
 
tree::ScalarTreeView HostScView () const
 
tree::MultiTargetTreeView HostMtView () const
 
- Public Member Functions inherited from xgboost::Model
virtual ~Model ()=default
 

Static Public Attributes

static constexpr bst_node_t kInvalidNodeId {MultiTargetTree::InvalidNodeId()}
 
static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max()
 
static constexpr bst_node_t kRoot {0}
 

Detailed Description

Defines the regression tree, the most common tree model.

This is the data structure used by xgboost's main tree models.
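The accessors on this page (a span of nodes, integer node ids, kRoot = 0, LeftChild and RightChild) suggest a flat array of nodes indexed by id. The following is a simplified, hypothetical model of that idea: `ToyNode`, `Predict`, the value -1 for "no child", and the "go left when the feature is below the split condition" rule are assumptions for illustration, and missing-value handling (default_left) is omitted.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical, simplified model of a regression tree stored as a flat array.
// Names and layout are illustrative assumptions, not xgboost's real Node class.
using NodeId = std::int32_t;
constexpr NodeId kRootId = 0;    // mirrors RegTree::kRoot {0}
constexpr NodeId kNoChild = -1;  // assumed value for "no child"

struct ToyNode {
  NodeId left = kNoChild;
  NodeId right = kNoChild;
  std::uint32_t split_feature = 0;
  float split_cond = 0.0f;  // go left when feature < split_cond (assumed rule)
  float leaf_value = 0.0f;  // only meaningful for leaves
  bool IsLeaf() const { return left == kNoChild; }
};

// Walk from the root to a leaf and return that leaf's value.
float Predict(const std::vector<ToyNode>& nodes, const std::vector<float>& row) {
  NodeId nid = kRootId;
  while (!nodes[nid].IsLeaf()) {
    const ToyNode& n = nodes[nid];
    assert(n.split_feature < row.size());
    nid = row[n.split_feature] < n.split_cond ? n.left : n.right;
  }
  return nodes[nid].leaf_value;
}
```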

Member Typedef Documentation

◆ SplitCondT

using xgboost::RegTree::SplitCondT = float

Constructor & Destructor Documentation

◆ RegTree() [1/2]

xgboost::RegTree::RegTree ( )
inline

◆ RegTree() [2/2]

xgboost::RegTree::RegTree ( bst_target_t  n_targets,
bst_feature_t  n_features 
)
inline explicit

Constructor that initializes the tree model with the given shape (number of targets and features).

Member Function Documentation

◆ ChangeToLeaf()

void xgboost::RegTree::ChangeToLeaf ( bst_node_t  nidx,
float  value 
)
inline

Change a non-leaf node to a leaf node and delete its children.

Parameters
nidx: Node id.
value: The new leaf value.

◆ CollapseToLeaf()

void xgboost::RegTree::CollapseToLeaf ( bst_node_t  nidx,
float  value 
)
inline

Collapse a non-leaf node to a leaf node and delete its children.

Parameters
nidx: Node id.
value: The new leaf value.
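The descriptions of ChangeToLeaf and CollapseToLeaf are identical. One plausible reading, assumed here rather than confirmed by this page, is that ChangeToLeaf expects both children to already be leaves, while CollapseToLeaf first reduces deeper subtrees recursively. A hypothetical toy tree showing that difference (`ToyNode` and `ToyTree` are illustrative names):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical toy tree illustrating two pruning operations. The precondition
// in ChangeToLeaf (both children must already be leaves) and the recursion in
// CollapseToLeaf are assumptions for illustration.
struct ToyNode {
  std::int32_t left = -1, right = -1;
  float value = 0.0f;
  bool deleted = false;
  bool IsLeaf() const { return left == -1; }
};

struct ToyTree {
  std::vector<ToyNode> nodes;

  void ChangeToLeaf(std::int32_t nidx, float value) {
    ToyNode& n = nodes[nidx];
    assert(!n.IsLeaf());
    assert(nodes[n.left].IsLeaf() && nodes[n.right].IsLeaf());
    nodes[n.left].deleted = true;  // children are marked, not erased
    nodes[n.right].deleted = true;
    n.left = n.right = -1;
    n.value = value;
  }

  void CollapseToLeaf(std::int32_t nidx, float value) {
    ToyNode& n = nodes[nidx];
    if (n.IsLeaf()) return;
    // Reduce both subtrees to single leaves first, then prune this node.
    if (!nodes[n.left].IsLeaf()) CollapseToLeaf(n.left, 0.0f);
    if (!nodes[n.right].IsLeaf()) CollapseToLeaf(n.right, 0.0f);
    ChangeToLeaf(nidx, value);
  }
};
```

Marking nodes as deleted instead of erasing them keeps every other node id stable, which matters when nodes refer to each other by index.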

◆ Copy()

RegTree* xgboost::RegTree::Copy ( ) const

◆ DumpModel()

std::string xgboost::RegTree::DumpModel ( const FeatureMap &  fmap,
bool  with_stats,
std::string  format 
) const

Dump the model in the requested format as a text string.

Parameters
fmap: Feature map that may help interpret the features.
with_stats: Whether to also dump node statistics.
format: The format to dump the model in.
Returns
The dumped model as a string.
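As an illustration of what a text dump can look like, the following hypothetical dumper writes one line per node, indented by depth, with split nodes listing their yes/no/missing children. The line format is only loosely modelled on xgboost's text output; `fmap` and `with_stats` are not modelled, and `ToyNode`, `DumpNode`, and `Dump` are illustrative names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical toy dumper: one line per node, indented by depth.
// Not the library's implementation.
struct ToyNode {
  std::int32_t left = -1, right = -1;
  std::uint32_t feature = 0;
  float cond = 0.0f;
  bool default_left = true;
  float leaf = 0.0f;
  bool IsLeaf() const { return left == -1; }
};

void DumpNode(const std::vector<ToyNode>& nodes, std::int32_t nid, int depth,
              std::ostringstream& os) {
  assert(nid >= 0 && static_cast<std::size_t>(nid) < nodes.size());
  const ToyNode& n = nodes[nid];
  os << std::string(static_cast<std::size_t>(depth), '\t') << nid << ':';
  if (n.IsLeaf()) {
    os << "leaf=" << n.leaf << '\n';
    return;
  }
  // Missing values follow the default direction of the split.
  os << "[f" << n.feature << '<' << n.cond << "] yes=" << n.left
     << ",no=" << n.right << ",missing=" << (n.default_left ? n.left : n.right) << '\n';
  DumpNode(nodes, n.left, depth + 1, os);
  DumpNode(nodes, n.right, depth + 1, os);
}

std::string Dump(const std::vector<ToyNode>& nodes) {
  std::ostringstream os;
  DumpNode(nodes, 0, 0, os);
  return os.str();
}
```

For a root split on feature 0 at 0.5 with leaves -1 and 2, `Dump` returns `0:[f0<0.5] yes=1,no=2,missing=1` followed by two tab-indented leaf lines.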

◆ Equal()

bool xgboost::RegTree::Equal ( const RegTree &  b) const

Checks whether two trees are equal from a user's perspective. Only non-deleted nodes are compared.

Parameters
b: The other tree.
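A minimal sketch of an equality check that ignores deleted slots: it walks the first tree from the root and compares the node at each reachable index, so slots that are never reached (such as deleted nodes) do not affect the result. `ToyNode` and `TreesEqual` are hypothetical, and comparing nodes at matching indices is an assumption about how such a check can work.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch: two toy trees compare equal when every node reachable
// from the root matches at the same index; deleted slots are never visited.
struct ToyNode {
  std::int32_t left = -1, right = -1;
  float value = 0.0f;
  bool deleted = false;
  bool operator==(const ToyNode& o) const {
    return left == o.left && right == o.right && value == o.value;
  }
};

bool TreesEqual(const std::vector<ToyNode>& a, const std::vector<ToyNode>& b) {
  std::vector<std::int32_t> stack{0};  // depth-first walk starting at the root
  while (!stack.empty()) {
    std::int32_t nid = stack.back();
    stack.pop_back();
    if (static_cast<std::size_t>(nid) >= b.size() || !(a[nid] == b[nid])) return false;
    assert(!a[nid].deleted);  // reachable nodes are never deleted
    if (a[nid].left != -1) {
      stack.push_back(a[nid].left);
      stack.push_back(a[nid].right);
    }
  }
  return true;
}
```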

◆ ExpandCategorical()

void xgboost::RegTree::ExpandCategorical ( bst_node_t  nid,
bst_feature_t  split_index,
common::Span< const uint32_t >  split_cat,
bool  default_left,
bst_float  base_weight,
bst_float  left_leaf_weight,
bst_float  right_leaf_weight,
bst_float  loss_change,
float  sum_hess,
float  left_sum,
float  right_sum 
)

Expands a leaf node with categories.

Parameters
nid: The node index to expand.
split_index: Feature index of the split.
split_cat: The bitset containing the categories of the split.
default_left: True if missing values go to the left child by default.
base_weight: The base weight, before applying the learning rate.
left_leaf_weight: The left leaf weight used for prediction, scaled by the learning rate.
right_leaf_weight: The right leaf weight used for prediction, scaled by the learning rate.
loss_change: The loss change (gain) from this split.
sum_hess: The sum of hessians for the node being expanded (coverage).
left_sum: The sum of hessians for the left leaf.
right_sum: The sum of hessians for the right leaf.
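The split_cat parameter is a span of 32-bit words used as a bitset. A hypothetical pair of helpers for building and querying such a bitset follows. The bit order within each word (bit `c % 32` of word `c / 32`) is an assumption and may differ from xgboost's own layout, and which child a matching category is sent to is not stated on this page.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helpers for a category bitset stored as 32-bit words, the same
// element type as the split_cat span. Bit layout is an illustrative assumption.
std::vector<std::uint32_t> MakeCategoryBitset(const std::vector<std::uint32_t>& cats,
                                              std::uint32_t n_categories) {
  std::vector<std::uint32_t> words((n_categories + 31) / 32, 0u);
  for (std::uint32_t c : cats) {
    assert(c < n_categories);
    words[c / 32] |= (1u << (c % 32));
  }
  return words;
}

bool ContainsCategory(const std::vector<std::uint32_t>& words, std::uint32_t c) {
  std::size_t w = c / 32;
  if (w >= words.size()) return false;  // categories beyond the bitset are not in the set
  return ((words[w] >> (c % 32)) & 1u) != 0u;
}
```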

◆ ExpandNode() [1/2]

void xgboost::RegTree::ExpandNode ( bst_node_t  nid,
unsigned  split_index,
bst_float  split_value,
bool  default_left,
bst_float  base_weight,
bst_float  left_leaf_weight,
bst_float  right_leaf_weight,
bst_float  loss_change,
float  sum_hess,
float  left_sum,
float  right_sum,
bst_node_t  leaf_right_child = kInvalidNodeId 
)

Expands a leaf node into two additional leaf nodes.

Parameters
nid: The node index to expand.
split_index: Feature index of the split.
split_value: The split condition.
default_left: True if missing values go to the left child by default.
base_weight: The base weight, before applying the learning rate.
left_leaf_weight: The left leaf weight used for prediction, scaled by the learning rate.
right_leaf_weight: The right leaf weight used for prediction, scaled by the learning rate.
loss_change: The loss change (gain) from this split.
sum_hess: The sum of hessians for the node being expanded (coverage).
left_sum: The sum of hessians for the left leaf.
right_sum: The sum of hessians for the right leaf.
leaf_right_child: The right child index recorded for the new leaves; defaults to kInvalidNodeId. Some updaters use the right child index of a leaf as a marker.
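The effect of expanding a leaf on a flat node array can be sketched as follows. `ToyNode` and `ExpandLeaf` are hypothetical; the sketch only records the split and appends two leaves, leaving out the loss change, hessian sums, and leaf_right_child parameters.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical sketch of expanding a leaf in a flat node array: two new leaf
// nodes are appended and the parent becomes a split node.
struct ToyNode {
  std::int32_t parent = -1, left = -1, right = -1;
  std::uint32_t split_index = 0;
  float split_cond = 0.0f;
  bool default_left = false;
  float value = 0.0f;  // base weight for splits, leaf weight for leaves
  bool IsLeaf() const { return left == -1; }
};

void ExpandLeaf(std::vector<ToyNode>& nodes, std::int32_t nid, std::uint32_t split_index,
                float split_cond, bool default_left, float base_weight,
                float left_leaf_weight, float right_leaf_weight) {
  assert(nid >= 0 && static_cast<std::size_t>(nid) < nodes.size());
  assert(nodes[nid].IsLeaf());
  const auto left = static_cast<std::int32_t>(nodes.size());
  const std::int32_t right = left + 1;
  nodes.resize(nodes.size() + 2);  // may reallocate, so use indices, not references
  nodes[left].parent = nid;
  nodes[left].value = left_leaf_weight;
  nodes[right].parent = nid;
  nodes[right].value = right_leaf_weight;
  ToyNode& p = nodes[nid];
  p.left = left;
  p.right = right;
  p.split_index = split_index;
  p.split_cond = split_cond;
  p.default_left = default_left;
  p.value = base_weight;
}
```

Each expansion grows the array by exactly two nodes, and the parent keeps its base weight while the new leaves carry the prediction weights.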

◆ ExpandNode() [2/2]

void xgboost::RegTree::ExpandNode ( bst_node_t  nidx,
bst_feature_t  split_index,
float  split_cond,
bool  default_left,
linalg::VectorView< float const >  base_weight,
linalg::VectorView< float const >  left_weight,
linalg::VectorView< float const >  right_weight,
float  loss_chg,
float  sum_hess,
float  left_sum,
float  right_sum 
)

Expands a leaf node into two additional leaf nodes for a multi-target tree.

Parameters
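nidx: The node index to expand.
split_index: Feature index of the split.
split_cond: The split condition.
default_left: True if missing values go to the left child by default.
base_weight: The base weight vector of the node being expanded.
left_weight: The weight vector of the new left child.
right_weight: The weight vector of the new right child.
loss_chg: The gain (loss change) from this split.
sum_hess: The sum of hessians for the parent node (coverage).
left_sum: The sum of hessians for the left child (coverage).
right_sum: The sum of hessians for the right child (coverage).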

◆ GetCategoriesMatrix()

CategoricalSplitMatrix xgboost::RegTree::GetCategoriesMatrix ( DeviceOrd  device) const
inline
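GetCategoriesMatrix returns a CategoricalSplitMatrix, which the class list describes as a CSR-like matrix. A hypothetical sketch of that kind of layout: every node's category bitset is concatenated into one flat array, and a per-node segment records where its slice begins and how long it is. `Segment`, `ToyCategoriesMatrix`, and their fields are assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CSR-like layout: all nodes' category bitsets are concatenated
// into one flat array; each node records the (begin, size) of its slice.
// Nodes with numerical splits get an empty slice.
struct Segment {
  std::size_t beg;
  std::size_t size;
};

struct ToyCategoriesMatrix {
  std::vector<std::uint32_t> categories;  // concatenated bitset words
  std::vector<Segment> node_ptr;          // one entry per node

  // Append the bitset words for the next node (empty for numerical splits).
  void AddNode(const std::vector<std::uint32_t>& words) {
    node_ptr.push_back(Segment{categories.size(), words.size()});
    categories.insert(categories.end(), words.begin(), words.end());
  }

  // Return a copy of the slice belonging to node nidx.
  std::vector<std::uint32_t> NodeCategories(std::size_t nidx) const {
    assert(nidx < node_ptr.size());
    const Segment& s = node_ptr[nidx];
    auto first = categories.begin() + static_cast<std::ptrdiff_t>(s.beg);
    return std::vector<std::uint32_t>(first, first + static_cast<std::ptrdiff_t>(s.size));
  }
};
```

Storing one flat array plus offsets avoids a separate allocation per node, which is the usual motivation for CSR-style layouts.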

◆ GetDepth()

bst_node_t xgboost::RegTree::GetDepth ( bst_node_t  nidx) const

Get the depth of a node.
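One way to compute the depth of a node, assuming each node records its parent index with -1 for the root, is to count the parent hops back to the root. `NodeDepth` is a hypothetical helper, not part of the RegTree API.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical sketch: a node's depth is the number of parent hops needed to
// reach the root (whose parent is -1). The root has depth 0.
int NodeDepth(const std::vector<std::int32_t>& parent, std::int32_t nidx) {
  int depth = 0;
  while (parent[nidx] != -1) {
    nidx = parent[nidx];
    ++depth;
    assert(depth <= static_cast<int>(parent.size()));  // guards against cycles
  }
  return depth;
}
```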

◆ GetMultiTargetTree()

auto xgboost::RegTree::GetMultiTargetTree ( ) const
inline

Get the underlying implementation of the multi-target tree.

◆ GetNodes()

common::Span<Node const> xgboost::RegTree::GetNodes ( DeviceOrd  device) const
inline

Get const reference to nodes.

◆ GetNumLeaves()

bst_node_t xgboost::RegTree::GetNumLeaves ( ) const

◆ GetNumSplitNodes()

bst_node_t xgboost::RegTree::GetNumSplitNodes ( ) const

◆ GetSplitCategories()

common::Span<uint32_t const> xgboost::RegTree::GetSplitCategories ( DeviceOrd  device) const
inline

◆ GetSplitCategoriesPtr()

auto const& xgboost::RegTree::GetSplitCategoriesPtr ( ) const
inline

◆ GetSplitTypes()

common::Span<FeatureType const> xgboost::RegTree::GetSplitTypes ( DeviceOrd  device) const
inline

Get split types for all nodes.

◆ GetStats()

common::Span<RTreeNodeStat const> xgboost::RegTree::GetStats ( DeviceOrd  device) const
inline

Get const reference to stats.

◆ HasCategoricalSplit()

bool xgboost::RegTree::HasCategoricalSplit ( ) const
inline

Whether this tree has categorical split.

◆ HostMtView()

tree::MultiTargetTreeView xgboost::RegTree::HostMtView ( ) const

◆ HostScView()

tree::ScalarTreeView xgboost::RegTree::HostScView ( ) const

◆ IsMultiTarget()

bool xgboost::RegTree::IsMultiTarget ( ) const
inline

Whether this is a multi-target tree.

◆ LeftChild()

bst_node_t xgboost::RegTree::LeftChild ( bst_node_t  nidx) const
inline

◆ LoadModel()

void xgboost::RegTree::LoadModel ( Json const &  in)
override virtual

Load the model from a JSON object.

Parameters
in: JSON object to load the model from.

Implements xgboost::Model.

◆ MaxDepth()

bst_node_t xgboost::RegTree::MaxDepth ( ) const

Get the maximum depth.
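A hypothetical recursive sketch of the maximum depth, using separate left/right child arrays with -1 for "no child". In this sketch a tree consisting only of a root leaf has depth 0; that convention is an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical sketch: maximum depth of the subtree rooted at nid, where
// -1 in the child arrays marks "no child". A lone leaf has depth 0.
int MaxDepthFrom(const std::vector<std::int32_t>& left,
                 const std::vector<std::int32_t>& right, std::int32_t nid) {
  assert(left.size() == right.size());
  if (left[nid] == -1) return 0;  // leaf
  return 1 + std::max(MaxDepthFrom(left, right, left[nid]),
                      MaxDepthFrom(left, right, right[nid]));
}
```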

◆ NumExtraNodes()

bst_node_t xgboost::RegTree::NumExtraNodes ( ) const
inline noexcept

Number of extra nodes besides the root.

◆ NumFeatures()

bst_feature_t xgboost::RegTree::NumFeatures ( ) const
inline noexcept

Get the number of features.

◆ NumNodes()

bst_node_t xgboost::RegTree::NumNodes ( ) const
inline noexcept

Get the total number of nodes including deleted ones in this tree.

◆ NumTargets()

bst_target_t xgboost::RegTree::NumTargets ( ) const
inline

The size of leaf weight.

◆ NumValidNodes()

bst_node_t xgboost::RegTree::NumValidNodes ( ) const
inline noexcept

Get the total number of valid nodes in this tree.
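NumNodes counts deleted nodes while NumValidNodes does not. The following hypothetical sketch models deletion as a sentinel value equal to kDeletedNodeMarker (the maximum uint32_t); storing that sentinel in a split-index field is an assumption made for illustration, and `ToyNode`, `NodeCounts`, and `CountNodes` are illustrative names.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical sketch: deleted nodes stay in the array but are tagged with a
// sentinel, modelled here on kDeletedNodeMarker (max uint32_t).
constexpr std::uint32_t kDeleted = std::numeric_limits<std::uint32_t>::max();

struct ToyNode {
  std::uint32_t split_index = 0;
  bool IsDeleted() const { return split_index == kDeleted; }
};

struct NodeCounts {
  int total;  // analogous to NumNodes(): includes deleted slots
  int valid;  // analogous to NumValidNodes(): excludes them
};

NodeCounts CountNodes(const std::vector<ToyNode>& nodes) {
  int valid = 0;
  for (const ToyNode& n : nodes) {
    if (!n.IsDeleted()) ++valid;
  }
  assert(valid <= static_cast<int>(nodes.size()));
  return {static_cast<int>(nodes.size()), valid};
}
```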

◆ operator==()

bool xgboost::RegTree::operator== ( const RegTree &  b) const
inline

◆ operator[]()

Node& xgboost::RegTree::operator[] ( bst_node_t  nidx)
inline

Get the node with index nidx.

◆ RightChild()

bst_node_t xgboost::RegTree::RightChild ( bst_node_t  nidx) const
inline

◆ SaveModel()

void xgboost::RegTree::SaveModel ( Json *  out) const
override virtual

Save the model to a JSON object.

Parameters
out: JSON container to save the model to.

Implements xgboost::Model.

◆ SetLeaves()

void xgboost::RegTree::SetLeaves ( std::vector< bst_node_t >  leaves,
common::Span< float const >  weights 
)

Set all leaf weights for a multi-target tree.

The leaf weight can differ from the internal weight stored by ExpandNode. This function sets the leaf weights at the end of tree construction.

Parameters
leaves: The node indices of all leaves. This must contain every leaf in the tree.
weights: Row-major matrix of leaf weights; each row holds the weights of the leaf at the same position in leaves.
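Because the weights argument is row-major, row i holds the weights of leaf leaves[i]. A hypothetical helper that unpacks such a matrix into per-leaf vectors, assuming each row has n_targets entries (`LeafWeightsByNode` and the map-based result are illustrative choices):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <vector>

// Hypothetical sketch of reading a row-major leaf-weight matrix: row i holds
// the n_targets weights of leaves[i].
std::map<std::int32_t, std::vector<float>> LeafWeightsByNode(
    const std::vector<std::int32_t>& leaves, const std::vector<float>& weights,
    std::size_t n_targets) {
  assert(weights.size() == leaves.size() * n_targets);
  std::map<std::int32_t, std::vector<float>> out;
  for (std::size_t i = 0; i < leaves.size(); ++i) {
    // Row i starts at offset i * n_targets in the flat buffer.
    auto row_begin = weights.begin() + static_cast<std::ptrdiff_t>(i * n_targets);
    out[leaves[i]] = std::vector<float>(row_begin, row_begin + static_cast<std::ptrdiff_t>(n_targets));
  }
  return out;
}
```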

◆ SetRoot()

void xgboost::RegTree::SetRoot ( linalg::VectorView< float const >  weight,
float  sum_hess 
)
inline

Set the root weight and statistics for a multi-target tree.

Parameters
weight: Internal split weight; its size equals the number of reduced targets.
sum_hess: The sum of hessians for the root node (coverage).

◆ Size()

bst_node_t xgboost::RegTree::Size ( ) const
inline

◆ Stat()

RTreeNodeStat& xgboost::RegTree::Stat ( int  nid)
inline

Get the statistics of the node with index nid.

Member Data Documentation

◆ kDeletedNodeMarker

constexpr uint32_t xgboost::RegTree::kDeletedNodeMarker = std::numeric_limits<uint32_t>::max()
static constexpr

◆ kInvalidNodeId

constexpr bst_node_t xgboost::RegTree::kInvalidNodeId {MultiTargetTree::InvalidNodeId()}
static constexpr

◆ kRoot

constexpr bst_node_t xgboost::RegTree::kRoot {0}
static constexpr

The documentation for this class was generated from the following file:
tree_model.h