xgboost
xgboost::RegTree Class Reference

Defines the regression tree, the most common tree model.

#include <tree_model.h>


Classes

struct  CategoricalSplitMatrix
 CSR-like matrix for categorical splits.
 
struct  FVec
 dense feature vector that can be taken by RegTree and can be constructed from a sparse feature vector.
 
class  Node
 tree node
 

Public Types

using SplitCondT = bst_float
 

Public Member Functions

void ChangeToLeaf (int rid, bst_float value)
 change a non-leaf node to a leaf node, delete its children
 
void CollapseToLeaf (int rid, bst_float value)
 collapse a non-leaf node to a leaf node, delete its children
 
 RegTree ()
 
 RegTree (bst_target_t n_targets, bst_feature_t n_features)
 Constructor that initializes the tree model with shape.
 
Node & operator[] (int nid)
 get node given nid
 
const Node & operator[] (int nid) const
 get node given nid
 
const std::vector< Node > & GetNodes () const
 get const reference to nodes
 
const std::vector< RTreeNodeStat > & GetStats () const
 get const reference to stats
 
RTreeNodeStat & Stat (int nid)
 get node statistics given nid
 
const RTreeNodeStat & Stat (int nid) const
 get node statistics given nid
 
void Load (dmlc::Stream *fi)
 load model from stream
 
void Save (dmlc::Stream *fo) const
 save model to stream
 
void LoadModel (Json const &in) override
 load the model from a JSON object
 
void SaveModel (Json *out) const override
 saves the model config to a JSON object
 
bool operator== (const RegTree &b) const
 
template<typename Func >
void WalkTree (Func func) const
 
bool Equal (const RegTree &b) const
 Checks whether two trees are equal from a user's perspective. Only non-deleted nodes are compared.
 
void ExpandNode (bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum, bst_node_t leaf_right_child=kInvalidNodeId)
 Expands a leaf node into two additional leaf nodes.
 
void ExpandNode (bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight)
 Expands a leaf node into two additional leaf nodes for a multi-target tree.
 
void ExpandCategorical (bst_node_t nid, bst_feature_t split_index, common::Span< const uint32_t > split_cat, bool default_left, bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum)
 Expands a leaf node with categories.
 
bool HasCategoricalSplit () const
 Whether this tree has a categorical split.
 
bool IsMultiTarget () const
 Whether this is a multi-target tree.
 
bst_target_t NumTargets () const
 The size of leaf weight.
 
auto GetMultiTargetTree () const
 Get the underlying implementation of multi-target tree.
 
bst_feature_t NumFeatures () const noexcept
 Get the number of features.
 
bst_node_t NumNodes () const noexcept
 Get the total number of nodes including deleted ones in this tree.
 
bst_node_t NumValidNodes () const noexcept
 Get the total number of valid nodes in this tree.
 
bst_node_t NumExtraNodes () const noexcept
 number of extra nodes besides the root
 
bst_node_t GetNumLeaves () const
 
bst_node_t GetNumSplitNodes () const
 
std::int32_t GetDepth (bst_node_t nid) const
 get current depth
 
void SetLeaf (bst_node_t nidx, linalg::VectorView< float const > weight)
 Set the leaf weight for a multi-target tree.
 
int MaxDepth (int nid) const
 get maximum depth
 
int MaxDepth ()
 get maximum depth
 
void CalculateContributionsApprox (const RegTree::FVec &feat, std::vector< float > *mean_values, bst_float *out_contribs) const
 calculate the approximate feature contributions for the given root
 
std::string DumpModel (const FeatureMap &fmap, bool with_stats, std::string format) const
 dump the model in the requested format as a text string
 
FeatureType NodeSplitType (bst_node_t nidx) const
 Get split type for a node.
 
std::vector< FeatureType > const & GetSplitTypes () const
 Get split types for all nodes.
 
common::Span< uint32_t const > GetSplitCategories () const
 
common::Span< uint32_t const > NodeCats (bst_node_t nidx) const
 Get the bit storage for categories.
 
auto const & GetSplitCategoriesPtr () const
 
CategoricalSplitMatrix GetCategoriesMatrix () const
 
bst_feature_t SplitIndex (bst_node_t nidx) const
 
float SplitCond (bst_node_t nidx) const
 
bool DefaultLeft (bst_node_t nidx) const
 
bool IsRoot (bst_node_t nidx) const
 
bool IsLeaf (bst_node_t nidx) const
 
bst_node_t Parent (bst_node_t nidx) const
 
bst_node_t LeftChild (bst_node_t nidx) const
 
bst_node_t RightChild (bst_node_t nidx) const
 
bool IsLeftChild (bst_node_t nidx) const
 
bst_node_t Size () const
 
Public Member Functions inherited from xgboost::Model
virtual ~Model ()=default
 

Static Public Attributes

static constexpr bst_node_t kInvalidNodeId {MultiTargetTree::InvalidNodeId()}
 
static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max()
 
static constexpr bst_node_t kRoot {0}
 

Detailed Description

Defines the regression tree, the most common tree model.

This is the data structure used in xgboost's major tree models.

Member Typedef Documentation

◆ SplitCondT

using xgboost::RegTree::SplitCondT = bst_float

Constructor & Destructor Documentation

◆ RegTree() [1/2]

xgboost::RegTree::RegTree ( )
inline

◆ RegTree() [2/2]

xgboost::RegTree::RegTree ( bst_target_t  n_targets,
bst_feature_t  n_features 
)
inline explicit

Constructor that initializes the tree model with shape.

Member Function Documentation

◆ CalculateContributionsApprox()

void xgboost::RegTree::CalculateContributionsApprox ( const RegTree::FVec &  feat,
std::vector< float > *  mean_values,
bst_float *  out_contribs 
) const

calculate the approximate feature contributions for the given root

Parameters
feat  dense feature vector; if a feature is missing, its field is set to NaN
out_contribs  output vector to hold the contributions
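
The page does not spell out how the approximation works. One common scheme, assumed here, walks the decision path for a sample and credits each split feature with the change in expected value between a node and the child that is taken. The sketch below is a hypothetical toy: `PathNode`, `ApproxContribs`, the "less than" split rule, and the bias slot at the end of the output are illustrative choices, not the xgboost API. It also shows a missing feature (NaN) following the node's default direction.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical toy node; NOT xgboost::RegTree::Node.
struct PathNode {
  int left = -1;            // -1 marks a leaf
  int right = -1;
  unsigned feature = 0;     // feature index used by the split
  float threshold = 0.0f;   // assumed rule: go left when value < threshold
  bool default_left = true; // direction taken when the feature is NaN
  float value = 0.0f;       // leaf value, or mean value of an internal node
};

// Adds per-feature contributions for one sample into *contribs.
// Assumed layout: contribs has feat.size() + 1 entries, the last is the bias.
inline void ApproxContribs(const std::vector<PathNode>& tree,
                           const std::vector<float>& feat,
                           std::vector<float>* contribs) {
  assert(!tree.empty());
  assert(contribs->size() == feat.size() + 1);
  int nid = 0;
  float node_value = tree[nid].value;
  contribs->back() += node_value;  // the root mean acts as the bias term
  while (tree[nid].left != -1) {
    const PathNode& n = tree[nid];
    const float fv = feat[n.feature];
    const bool go_left = std::isnan(fv) ? n.default_left : fv < n.threshold;
    nid = go_left ? n.left : n.right;
    const float child_value = tree[nid].value;
    // Credit the split feature with the change in expected value.
    (*contribs)[n.feature] += child_value - node_value;
    node_value = child_value;
  }
}
```

With this scheme the contributions plus the bias sum to the value of the leaf the sample reaches.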

◆ ChangeToLeaf()

void xgboost::RegTree::ChangeToLeaf ( int  rid,
bst_float  value 
)
inline

change a non-leaf node to a leaf node, delete its children

Parameters
rid  node id of the node
value  new leaf value

◆ CollapseToLeaf()

void xgboost::RegTree::CollapseToLeaf ( int  rid,
bst_float  value 
)
inline

collapse a non-leaf node to a leaf node, delete its children

Parameters
rid  node id of the node
value  new leaf value

◆ DefaultLeft()

bool xgboost::RegTree::DefaultLeft ( bst_node_t  nidx) const
inline

◆ DumpModel()

std::string xgboost::RegTree::DumpModel ( const FeatureMap &  fmap,
bool  with_stats,
std::string  format 
) const

dump the model in the requested format as a text string

Parameters
fmap  feature map that may help interpret the features
with_stats  whether to dump statistics as well
format  the format to dump the model in
Returns
the string of dumped model

◆ Equal()

bool xgboost::RegTree::Equal ( const RegTree &  b) const

Checks whether two trees are equal from a user's perspective. Only non-deleted nodes are compared.

Parameters
b  The other tree.
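
One way to compare only non-deleted nodes is to visit just the nodes reachable from the root, since a deleted node is no longer linked from any parent. The sketch below assumes that reading. `WalkNode`, `Walk`, and `EqualReachable` are hypothetical names for a toy structure and do not reproduce the RegTree implementation.

```cpp
#include <cassert>
#include <stack>
#include <vector>

// Hypothetical toy node; NOT xgboost::RegTree::Node.
struct WalkNode {
  int left = -1;     // -1 marks a leaf
  int right = -1;
  float value = 0.0f;  // split condition or leaf value
  bool operator==(const WalkNode& o) const {
    return left == o.left && right == o.right && value == o.value;
  }
};

// Depth-first walk over the nodes reachable from node 0.
// func(nid) returns false to stop the walk early.
template <typename Func>
void Walk(const std::vector<WalkNode>& nodes, Func func) {
  assert(!nodes.empty());
  std::stack<int> todo;
  todo.push(0);
  while (!todo.empty()) {
    const int nid = todo.top();
    todo.pop();
    if (!func(nid)) return;
    if (nodes[nid].left != -1) {
      todo.push(nodes[nid].left);
      todo.push(nodes[nid].right);
    }
  }
}

// Two trees are treated as equal when every reachable node matches.
// Unreachable (deleted) slots in either vector are never inspected.
inline bool EqualReachable(const std::vector<WalkNode>& a,
                           const std::vector<WalkNode>& b) {
  bool same = true;
  Walk(a, [&](int nid) -> bool {
    if (nid >= static_cast<int>(b.size()) || !(a[nid] == b[nid])) {
      same = false;
      return false;  // stop at the first mismatch
    }
    return true;
  });
  return same;
}
```

Because child indices are part of the per-node comparison, matching every node reachable in `a` also guarantees that `b` reaches the same set of nodes.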

◆ ExpandCategorical()

void xgboost::RegTree::ExpandCategorical ( bst_node_t  nid,
bst_feature_t  split_index,
common::Span< const uint32_t >  split_cat,
bool  default_left,
bst_float  base_weight,
bst_float  left_leaf_weight,
bst_float  right_leaf_weight,
bst_float  loss_change,
float  sum_hess,
float  left_sum,
float  right_sum 
)

Expands a leaf node with categories.

Parameters
nid  The node index to expand.
split_index  Feature index of the split.
split_cat  The bitset containing categories.
default_left  True to default left.
base_weight  The base weight, before learning rate.
left_leaf_weight  The left leaf weight for prediction, modified by learning rate.
right_leaf_weight  The right leaf weight for prediction, modified by learning rate.
loss_change  The loss change.
sum_hess  The sum of Hessians.
left_sum  The sum of Hessians of the left leaf.
right_sum  The sum of Hessians of the right leaf.
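
Since split_cat is described as a bitset carried in a span of uint32_t, the following sketch shows one way such a bitset can be packed and queried. The helper names `MakeCatBits` and `HasCat` are hypothetical, and the bit order (least significant bit first within each word) is an assumption made for illustration; the ordering xgboost uses internally may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Pack a list of category ids into 32-bit words, one bit per category.
// Assumption: category c lives in word c / 32, at bit c % 32 (LSB first).
inline std::vector<std::uint32_t> MakeCatBits(
    const std::vector<std::uint32_t>& cats, std::uint32_t n_cats) {
  std::vector<std::uint32_t> bits((n_cats + 31u) / 32u, 0u);
  for (std::uint32_t c : cats) {
    assert(c < n_cats);
    bits[c / 32u] |= (std::uint32_t{1} << (c % 32u));
  }
  return bits;
}

// Test whether a category is present; ids past the end count as absent.
inline bool HasCat(const std::vector<std::uint32_t>& bits, std::uint32_t cat) {
  const std::size_t word = cat / 32u;
  if (word >= bits.size()) return false;
  return ((bits[word] >> (cat % 32u)) & 1u) != 0u;
}
```

For example, `MakeCatBits({1, 33}, 40)` needs two words because 40 categories do not fit in 32 bits.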

◆ ExpandNode() [1/2]

void xgboost::RegTree::ExpandNode ( bst_node_t  nid,
unsigned  split_index,
bst_float  split_value,
bool  default_left,
bst_float  base_weight,
bst_float  left_leaf_weight,
bst_float  right_leaf_weight,
bst_float  loss_change,
float  sum_hess,
float  left_sum,
float  right_sum,
bst_node_t  leaf_right_child = kInvalidNodeId 
)

Expands a leaf node into two additional leaf nodes.

Parameters
nid  The node index to expand.
split_index  Feature index of the split.
split_value  The split condition.
default_left  True to default left.
base_weight  The base weight, before learning rate.
left_leaf_weight  The left leaf weight for prediction, modified by learning rate.
right_leaf_weight  The right leaf weight for prediction, modified by learning rate.
loss_change  The loss change.
sum_hess  The sum of Hessians.
left_sum  The sum of Hessians of the left leaf.
right_sum  The sum of Hessians of the right leaf.
leaf_right_child  The right child index of the leaf; defaults to kInvalidNodeId. Some updaters use the right child index of a leaf as a marker.
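
As a rough picture of what expanding a leaf means, the sketch below turns a leaf into a split node and appends two new leaves to a flat node array. `GrowNode` and `GrowTree` are hypothetical toys with a reduced parameter list; statistics such as loss_change and the Hessian sums are left out, and none of this is the RegTree implementation.

```cpp
#include <cassert>
#include <vector>

// Hypothetical toy node; NOT xgboost::RegTree::Node.
struct GrowNode {
  int parent = -1;
  int left = -1;   // -1 marks a leaf
  int right = -1;
  unsigned split_index = 0;
  float split_value = 0.0f;
  bool default_left = false;
  float leaf_value = 0.0f;
};

struct GrowTree {
  std::vector<GrowNode> nodes;
  GrowTree() : nodes(1) {}  // starts as a single root leaf

  bool IsLeaf(int nid) const { return nodes[nid].left == -1; }

  // Turn leaf `nid` into a split node and append two new leaves.
  void Expand(int nid, unsigned split_index, float split_value,
              bool default_left, float left_leaf_weight,
              float right_leaf_weight) {
    assert(IsLeaf(nid));
    const int left = static_cast<int>(nodes.size());
    const int right = left + 1;
    nodes.resize(nodes.size() + 2);  // may reallocate, so index by id below
    nodes[nid].left = left;
    nodes[nid].right = right;
    nodes[nid].split_index = split_index;
    nodes[nid].split_value = split_value;
    nodes[nid].default_left = default_left;
    nodes[left].parent = nid;
    nodes[left].leaf_value = left_leaf_weight;
    nodes[right].parent = nid;
    nodes[right].leaf_value = right_leaf_weight;
  }
};
```

Appending children in pairs keeps sibling ids adjacent, which is a design choice of this toy rather than a documented property of RegTree.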

◆ ExpandNode() [2/2]

void xgboost::RegTree::ExpandNode ( bst_node_t  nidx,
bst_feature_t  split_index,
float  split_cond,
bool  default_left,
linalg::VectorView< float const >  base_weight,
linalg::VectorView< float const >  left_weight,
linalg::VectorView< float const >  right_weight 
)

Expands a leaf node into two additional leaf nodes for a multi-target tree.

◆ GetCategoriesMatrix()

CategoricalSplitMatrix xgboost::RegTree::GetCategoriesMatrix ( ) const
inline

◆ GetDepth()

std::int32_t xgboost::RegTree::GetDepth ( bst_node_t  nid) const
inline

get current depth

Parameters
nid  node id

◆ GetMultiTargetTree()

auto xgboost::RegTree::GetMultiTargetTree ( ) const
inline

Get the underlying implementation of multi-target tree.

◆ GetNodes()

const std::vector<Node>& xgboost::RegTree::GetNodes ( ) const
inline

get const reference to nodes

◆ GetNumLeaves()

bst_node_t xgboost::RegTree::GetNumLeaves ( ) const

◆ GetNumSplitNodes()

bst_node_t xgboost::RegTree::GetNumSplitNodes ( ) const

◆ GetSplitCategories()

common::Span<uint32_t const> xgboost::RegTree::GetSplitCategories ( ) const
inline

◆ GetSplitCategoriesPtr()

auto const& xgboost::RegTree::GetSplitCategoriesPtr ( ) const
inline

◆ GetSplitTypes()

std::vector<FeatureType> const& xgboost::RegTree::GetSplitTypes ( ) const
inline

Get split types for all nodes.

◆ GetStats()

const std::vector<RTreeNodeStat>& xgboost::RegTree::GetStats ( ) const
inline

get const reference to stats

◆ HasCategoricalSplit()

bool xgboost::RegTree::HasCategoricalSplit ( ) const
inline

Whether this tree has categorical split.

◆ IsLeaf()

bool xgboost::RegTree::IsLeaf ( bst_node_t  nidx) const
inline

◆ IsLeftChild()

bool xgboost::RegTree::IsLeftChild ( bst_node_t  nidx) const
inline

◆ IsMultiTarget()

bool xgboost::RegTree::IsMultiTarget ( ) const
inline

Whether this is a multi-target tree.

◆ IsRoot()

bool xgboost::RegTree::IsRoot ( bst_node_t  nidx) const
inline

◆ LeftChild()

bst_node_t xgboost::RegTree::LeftChild ( bst_node_t  nidx) const
inline

◆ Load()

void xgboost::RegTree::Load ( dmlc::Stream *  fi)

load model from stream

Parameters
fi  input stream

◆ LoadModel()

void xgboost::RegTree::LoadModel ( Json const &  in)
override virtual

load the model from a JSON object

Parameters
in  JSON object to load the model from

Implements xgboost::Model.

◆ MaxDepth() [1/2]

int xgboost::RegTree::MaxDepth ( )
inline

get maximum depth

◆ MaxDepth() [2/2]

int xgboost::RegTree::MaxDepth ( int  nid) const
inline

get maximum depth

Parameters
nid  node id
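
The two depth queries can be pictured on a toy tree: the current depth of a node counts the edges up to the root, and the maximum depth counts the longest downward path from a node to a leaf. `DepthNode`, `Depth`, and `MaxDepthFrom` are hypothetical names for this sketch, and the convention that the root has depth 0 is an assumption.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical toy node; NOT xgboost::RegTree::Node.
struct DepthNode {
  int parent = -1;  // -1 marks the root
  int left = -1;    // -1 marks a leaf
  int right = -1;
};

// Depth of `nid`: number of parent links followed to reach the root.
inline int Depth(const std::vector<DepthNode>& nodes, int nid) {
  int depth = 0;
  while (nodes[nid].parent != -1) {
    ++depth;
    nid = nodes[nid].parent;
  }
  return depth;
}

// Longest path, in edges, from `nid` down to any leaf below it.
inline int MaxDepthFrom(const std::vector<DepthNode>& nodes, int nid) {
  if (nodes[nid].left == -1) return 0;
  return 1 + std::max(MaxDepthFrom(nodes, nodes[nid].left),
                      MaxDepthFrom(nodes, nodes[nid].right));
}
```

Calling `MaxDepthFrom(nodes, 0)` gives the depth of the whole tree, which mirrors the relationship between the two MaxDepth overloads listed on this page.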

◆ NodeCats()

common::Span<uint32_t const> xgboost::RegTree::NodeCats ( bst_node_t  nidx) const
inline

Get the bit storage for categories.

◆ NodeSplitType()

FeatureType xgboost::RegTree::NodeSplitType ( bst_node_t  nidx) const
inline

Get split type for a node.

Parameters
nidx  Index of node.
Returns
The type of this split. For a leaf node it is always kNumerical.

◆ NumExtraNodes()

bst_node_t xgboost::RegTree::NumExtraNodes ( ) const
inline noexcept

number of extra nodes besides the root

◆ NumFeatures()

bst_feature_t xgboost::RegTree::NumFeatures ( ) const
inline noexcept

Get the number of features.

◆ NumNodes()

bst_node_t xgboost::RegTree::NumNodes ( ) const
inline noexcept

Get the total number of nodes including deleted ones in this tree.

◆ NumTargets()

bst_target_t xgboost::RegTree::NumTargets ( ) const
inline

The size of leaf weight.

◆ NumValidNodes()

bst_node_t xgboost::RegTree::NumValidNodes ( ) const
inline noexcept

Get the total number of valid nodes in this tree.
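
The three node counts differ only in how deleted nodes and the root are treated. The sketch below assumes that deleted nodes keep their slot in the node array and are flagged with a sentinel equal to the maximum uint32_t value, in the spirit of kDeletedNodeMarker. `ToyCountTree` and `kToyDeleted` are hypothetical, and the relation "extra nodes = valid nodes minus the root" is an assumption of this toy.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Sentinel for a deleted slot; same value as the maximum uint32_t.
constexpr std::uint32_t kToyDeleted =
    std::numeric_limits<std::uint32_t>::max();

// Hypothetical toy; NOT xgboost::RegTree.
struct ToyCountTree {
  // One entry per node id; kToyDeleted marks a deleted node.
  std::vector<std::uint32_t> split_index;

  // All slots, including deleted ones.
  int NumNodes() const { return static_cast<int>(split_index.size()); }

  // Slots that do not carry the deleted marker.
  int NumValidNodes() const {
    const auto deleted =
        std::count(split_index.begin(), split_index.end(), kToyDeleted);
    return NumNodes() - static_cast<int>(deleted);
  }

  // Valid nodes other than the root (assumed relation).
  int NumExtraNodes() const { return NumValidNodes() - 1; }
};
```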

◆ operator==()

bool xgboost::RegTree::operator== ( const RegTree &  b) const
inline

◆ operator[]() [1/2]

Node& xgboost::RegTree::operator[] ( int  nid)
inline

get node given nid

◆ operator[]() [2/2]

const Node& xgboost::RegTree::operator[] ( int  nid) const
inline

get node given nid

◆ Parent()

bst_node_t xgboost::RegTree::Parent ( bst_node_t  nidx) const
inline

◆ RightChild()

bst_node_t xgboost::RegTree::RightChild ( bst_node_t  nidx) const
inline

◆ Save()

void xgboost::RegTree::Save ( dmlc::Stream *  fo) const

save model to stream

Parameters
fo  output stream

◆ SaveModel()

void xgboost::RegTree::SaveModel ( Json *  out) const
override virtual

saves the model config to a JSON object

Parameters
out  JSON container to save the model to

Implements xgboost::Model.

◆ SetLeaf()

void xgboost::RegTree::SetLeaf ( bst_node_t  nidx,
linalg::VectorView< float const >  weight 
)
inline

Set the leaf weight for a multi-target tree.
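
In a multi-target tree each leaf holds a vector of weights, one per target, rather than a single scalar. The sketch below shows one possible storage scheme, a flat node-major array with n_targets entries per node. The layout and the names `ToyMultiTargetLeaves`, `SetLeaf`, and `LeafWeight` are assumptions for illustration; the page does not describe how RegTree stores these weights.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy storage; NOT the xgboost multi-target tree.
struct ToyMultiTargetLeaves {
  std::size_t n_targets;
  std::vector<float> weights;  // n_nodes * n_targets values, node-major

  ToyMultiTargetLeaves(std::size_t n_targets_, std::size_t n_nodes)
      : n_targets(n_targets_), weights(n_targets_ * n_nodes, 0.0f) {}

  // Copy one weight per target into the slot of node `nidx`.
  void SetLeaf(std::size_t nidx, const std::vector<float>& weight) {
    assert(weight.size() == n_targets);
    assert((nidx + 1) * n_targets <= weights.size());
    const auto offset = static_cast<std::ptrdiff_t>(nidx * n_targets);
    std::copy(weight.begin(), weight.end(), weights.begin() + offset);
  }

  float LeafWeight(std::size_t nidx, std::size_t target) const {
    return weights[nidx * n_targets + target];
  }
};
```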

◆ Size()

bst_node_t xgboost::RegTree::Size ( ) const
inline

◆ SplitCond()

float xgboost::RegTree::SplitCond ( bst_node_t  nidx) const
inline

◆ SplitIndex()

bst_feature_t xgboost::RegTree::SplitIndex ( bst_node_t  nidx) const
inline

◆ Stat() [1/2]

RTreeNodeStat& xgboost::RegTree::Stat ( int  nid)
inline

get node statistics given nid

◆ Stat() [2/2]

const RTreeNodeStat& xgboost::RegTree::Stat ( int  nid) const
inline

get node statistics given nid

◆ WalkTree()

template<typename Func >
void xgboost::RegTree::WalkTree ( Func  func) const
inline

Member Data Documentation

◆ kDeletedNodeMarker

constexpr uint32_t xgboost::RegTree::kDeletedNodeMarker = std::numeric_limits<uint32_t>::max()
static constexpr

◆ kInvalidNodeId

constexpr bst_node_t xgboost::RegTree::kInvalidNodeId {MultiTargetTree::InvalidNodeId()}
static constexpr

◆ kRoot

constexpr bst_node_t xgboost::RegTree::kRoot {0}
static constexpr

The documentation for this class was generated from the following file: