xgboost
xgboost::MultiTargetTree Class Reference

Tree structure for multi-target model.

#include <multi_target_tree_model.h>


Public Member Functions

 MultiTargetTree (TreeParam const *param)
 
 MultiTargetTree (MultiTargetTree const &that)
 
MultiTargetTree & operator= (MultiTargetTree const &that)=delete
 
 MultiTargetTree (MultiTargetTree &&that)=delete
 
MultiTargetTree & operator= (MultiTargetTree &&that)=delete
 
void SetRoot (linalg::VectorView< float const > weight, float sum_hess)
 Set the weight and statistics for the root.
 
void Expand (bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight, float loss_chg, float sum_hess, float left_sum, float right_sum)
 Expand a leaf into a split node.
 
void SetLeaves (std::vector< bst_node_t > leaves, common::Span< float const > weights)
 
void SetLeaves ()
 Copy base weight into leaf weight for a non-reduced multi-target tree.
 
bool IsLeaf (bst_node_t nidx) const
 
bst_node_t LeftChild (bst_node_t nidx) const
 
bst_node_t RightChild (bst_node_t nidx) const
 
bst_target_t NumTargets () const
 Number of targets (size of a leaf).
 
bst_target_t NumSplitTargets () const
 Number of reduced targets.
 
auto NumLeaves () const
 
std::size_t Size () const
 
MultiTargetTree * Copy (TreeParam const *param) const
 
common::Span< float const > LeafWeights (DeviceOrd device) const
 
linalg::VectorView< float const > LeafValue (bst_node_t nidx) const
 
void LoadModel (Json const &in) override
 Load the model from a JSON object.
 
void SaveModel (Json *out) const override
 Save the model config to a JSON object.
 
std::size_t MemCostBytes () const
 
- Public Member Functions inherited from xgboost::Model
virtual ~Model ()=default
 

Static Public Member Functions

static constexpr bst_node_t InvalidNodeId ()
 

Friends

struct tree::MultiTargetTreeView
 

Detailed Description

Tree structure for multi-target model.

In order to support reduced gradient, the internal storage distinguishes between base weights and leaf weights. Base weights are calculated from the split gradient; leaf weights are calculated from the value gradient and are used as outputs. Every node has a base weight, but only leaves have leaf weights.
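The two kinds of weight can be illustrated with the usual second-order Newton step w = -G / (H + λ), applied once to a reduced split gradient (giving a base weight) and once to the full value gradient (giving a leaf weight). This is a minimal sketch under stated assumptions, not xgboost's weight calculation: `NewtonWeight` is a hypothetical helper, only an L2 term `lambda` is modelled, and the real trainer also accounts for parameters such as `alpha` and `max_delta_step`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: one Newton step per target, w_t = -G_t / (H_t + lambda).
// Only L2 regularization is modelled; this is not xgboost's implementation.
std::vector<float> NewtonWeight(std::vector<float> const& grad_sum,
                                std::vector<float> const& hess_sum, float lambda) {
  assert(grad_sum.size() == hess_sum.size());
  std::vector<float> w(grad_sum.size());
  for (std::size_t t = 0; t < w.size(); ++t) {
    w[t] = -grad_sum[t] / (hess_sum[t] + lambda);
  }
  return w;
}
```

Under this reading, a base weight presumably has one entry per reduced target (`NumSplitTargets()`) and a leaf weight one entry per target (`NumTargets()`). Calling `NewtonWeight({-2.0f}, {3.0f}, 1.0f)` yields a one-element base weight of 0.5, while a three-column value gradient yields a three-element leaf weight. In a non-reduced tree the two sizes coincide.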

To access the leaf weights, we reuse the right child slot to store leaf indices: for split nodes, the right_ member stores the index of the right child node; for leaf nodes, it stores the index of the corresponding leaf weight.
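The reuse of the right slot can be pictured with a toy layout. This is a minimal sketch, not the xgboost class: `ToyMultiTargetTree`, `kInvalidNodeId`, and the member names are hypothetical; a node is assumed to be a leaf exactly when its left child is invalid; and leaf weights are assumed to be stored row-major with `n_targets` floats per leaf.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using NodeId = std::int32_t;
constexpr NodeId kInvalidNodeId = -1;

// Hypothetical toy layout. A node is treated as a leaf when it has no left child;
// for a leaf, right_[nidx] holds a row index into leaf_weights_ instead of a node id.
struct ToyMultiTargetTree {
  std::size_t n_targets;             // size of every leaf weight row
  std::vector<NodeId> left_;
  std::vector<NodeId> right_;
  std::vector<float> leaf_weights_;  // n_leaves * n_targets, one row per leaf

  bool IsLeaf(NodeId nidx) const { return left_[nidx] == kInvalidNodeId; }
  NodeId LeftChild(NodeId nidx) const { return left_[nidx]; }
  NodeId RightChild(NodeId nidx) const {
    // For a leaf, right_ is a leaf index, so it is not reported as a child here.
    return IsLeaf(nidx) ? kInvalidNodeId : right_[nidx];
  }
  // Returns a copy of the leaf's weight row (the real class returns a view).
  std::vector<float> LeafValue(NodeId nidx) const {
    assert(IsLeaf(nidx));
    auto row = static_cast<std::size_t>(right_[nidx]);
    auto first = leaf_weights_.begin() + static_cast<std::ptrdiff_t>(row * n_targets);
    return std::vector<float>(first, first + static_cast<std::ptrdiff_t>(n_targets));
  }
};
```

For example, a root 0 with leaf children 1 and 2 would have `left_ = {1, -1, -1}` and `right_ = {2, 0, 1}`: the root's right slot is a node id, while the leaves' right slots are rows 0 and 1 of `leaf_weights_`.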

Constructor & Destructor Documentation

◆ MultiTargetTree() [1/3]

xgboost::MultiTargetTree::MultiTargetTree ( TreeParam const *  param)
explicit

◆ MultiTargetTree() [2/3]

xgboost::MultiTargetTree::MultiTargetTree ( MultiTargetTree const &  that)

◆ MultiTargetTree() [3/3]

xgboost::MultiTargetTree::MultiTargetTree ( MultiTargetTree &&  that)
delete

Member Function Documentation

◆ Copy()

MultiTargetTree* xgboost::MultiTargetTree::Copy ( TreeParam const *  param) const

◆ Expand()

void xgboost::MultiTargetTree::Expand ( bst_node_t  nidx,
bst_feature_t  split_idx,
float  split_cond,
bool  default_left,
linalg::VectorView< float const >  base_weight,
linalg::VectorView< float const >  left_weight,
linalg::VectorView< float const >  right_weight,
float  loss_chg,
float  sum_hess,
float  left_sum,
float  right_sum 
)

Expand a leaf into a split node.
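Expanding a leaf can be modelled as appending two child nodes, filling in the parent's child links and split record, and writing three base-weight rows. This is a hedged sketch, not xgboost's code. `ToyTree`, its members, and the use of `std::vector<float>` in place of `linalg::VectorView<float const>` are assumptions, as is the reading that `base_weight` belongs to the expanded node, `left_weight`/`right_weight` to the new children, and `sum_hess`, `left_sum`, `right_sum` are the corresponding hessian sums.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using NodeId = std::int32_t;
constexpr NodeId kInvalid = -1;

// Hypothetical toy tree used only to illustrate Expand(); not xgboost's class.
struct ToyTree {
  std::size_t n_targets;                   // length of each base-weight row
  std::vector<NodeId> left, right, parent;
  std::vector<std::uint32_t> split_index;  // feature used by a split node
  std::vector<float> split_cond;
  std::vector<bool> default_left;
  std::vector<float> base_weights;         // n_nodes * n_targets, row-major
  std::vector<float> sum_hess, loss_chg;

  explicit ToyTree(std::size_t n) : n_targets{n} { AddNode(kInvalid); }  // node 0 is the root

  NodeId AddNode(NodeId p) {
    left.push_back(kInvalid);
    right.push_back(kInvalid);
    parent.push_back(p);
    split_index.push_back(0);
    split_cond.push_back(0.0f);
    default_left.push_back(false);
    base_weights.resize(base_weights.size() + n_targets, 0.0f);
    sum_hess.push_back(0.0f);
    loss_chg.push_back(0.0f);
    return static_cast<NodeId>(left.size() - 1);
  }

  void SetBaseWeight(NodeId nidx, std::vector<float> const& w) {
    assert(w.size() == n_targets);
    for (std::size_t t = 0; t < n_targets; ++t) {
      base_weights[static_cast<std::size_t>(nidx) * n_targets + t] = w[t];
    }
  }

  // Turn leaf `nidx` into a split node with two fresh leaf children.
  void Expand(NodeId nidx, std::uint32_t split_idx, float cond, bool dleft,
              std::vector<float> const& base_w, std::vector<float> const& left_w,
              std::vector<float> const& right_w, float chg, float hess,
              float left_hess, float right_hess) {
    assert(left[nidx] == kInvalid);  // only a leaf can be expanded
    NodeId l = AddNode(nidx);
    NodeId r = AddNode(nidx);
    left[nidx] = l;
    right[nidx] = r;
    split_index[nidx] = split_idx;
    split_cond[nidx] = cond;
    default_left[nidx] = dleft;
    loss_chg[nidx] = chg;
    SetBaseWeight(nidx, base_w);
    SetBaseWeight(l, left_w);
    SetBaseWeight(r, right_w);
    sum_hess[nidx] = hess;
    sum_hess[l] = left_hess;
    sum_hess[r] = right_hess;
  }
};
```

Because both children are appended at the end, node ids in this toy stay dense: a tree with k splits has 2k + 1 nodes.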

◆ InvalidNodeId()

static constexpr bst_node_t xgboost::MultiTargetTree::InvalidNodeId ( )
inline static constexpr

◆ IsLeaf()

bool xgboost::MultiTargetTree::IsLeaf ( bst_node_t  nidx) const
inline

◆ LeafValue()

linalg::VectorView<float const> xgboost::MultiTargetTree::LeafValue ( bst_node_t  nidx) const
inline

◆ LeafWeights()

common::Span<float const> xgboost::MultiTargetTree::LeafWeights ( DeviceOrd  device) const
inline

◆ LeftChild()

bst_node_t xgboost::MultiTargetTree::LeftChild ( bst_node_t  nidx) const
inline

◆ LoadModel()

void xgboost::MultiTargetTree::LoadModel ( Json const &  in)
override virtual

Load the model from a JSON object.

Parameters
in: JSON object from which to load the model.

Implements xgboost::Model.

◆ MemCostBytes()

std::size_t xgboost::MultiTargetTree::MemCostBytes ( ) const

◆ NumLeaves()

auto xgboost::MultiTargetTree::NumLeaves ( ) const
inline

◆ NumSplitTargets()

bst_target_t xgboost::MultiTargetTree::NumSplitTargets ( ) const

Number of reduced targets.

◆ NumTargets()

bst_target_t xgboost::MultiTargetTree::NumTargets ( ) const

Number of targets (size of a leaf).

◆ operator=() [1/2]

MultiTargetTree& xgboost::MultiTargetTree::operator= ( MultiTargetTree &&  that)
delete

◆ operator=() [2/2]

MultiTargetTree& xgboost::MultiTargetTree::operator= ( MultiTargetTree const &  that)
delete

◆ RightChild()

bst_node_t xgboost::MultiTargetTree::RightChild ( bst_node_t  nidx) const
inline

◆ SaveModel()

void xgboost::MultiTargetTree::SaveModel ( Json *  out) const
override virtual

Save the model config to a JSON object.

Parameters
out: JSON container into which the model is saved.

Implements xgboost::Model.

◆ SetLeaves() [1/2]

void xgboost::MultiTargetTree::SetLeaves ( )

Copy base weight into leaf weight for a non-reduced multi-target tree.
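In a non-reduced tree the split gradient and the value gradient coincide, so each leaf's output is simply its base weight. The sketch below also shows the right-slot reuse in action: each leaf's right entry is overwritten with its row in the leaf-weight buffer. It is a minimal illustration under assumptions, not xgboost's implementation; `CopyBaseToLeaf` and its parameters are hypothetical, a leaf is taken to be a node with no left child, and base weights are assumed to be stored row-major with `n_targets` floats per node.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using NodeId = std::int32_t;
constexpr NodeId kInvalid = -1;

// Hypothetical sketch of the non-reduced case: copy every leaf's base-weight row
// into a dense leaf-weight buffer and store the row index in the leaf's right slot.
void CopyBaseToLeaf(std::vector<NodeId> const& left, std::vector<NodeId>& right,
                    std::vector<float> const& base_weights, std::size_t n_targets,
                    std::vector<float>& leaf_weights) {
  leaf_weights.clear();
  NodeId leaf_idx = 0;
  for (std::size_t nidx = 0; nidx < left.size(); ++nidx) {
    if (left[nidx] != kInvalid) continue;  // split node: keep its right child id
    right[nidx] = leaf_idx++;              // leaf: right slot now holds the leaf row
    auto first = base_weights.begin() + static_cast<std::ptrdiff_t>(nidx * n_targets);
    leaf_weights.insert(leaf_weights.end(), first,
                        first + static_cast<std::ptrdiff_t>(n_targets));
  }
}
```

After the copy, the leaf weights form one dense NumLeaves() × NumTargets() block, which is consistent with `LeafWeights()` returning a single span.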

◆ SetLeaves() [2/2]

void xgboost::MultiTargetTree::SetLeaves ( std::vector< bst_node_t >  leaves,
common::Span< float const >  weights 
)

◆ SetRoot()

void xgboost::MultiTargetTree::SetRoot ( linalg::VectorView< float const >  weight,
float  sum_hess 
)

Set the weight and statistics for the root.

Parameters
weight: The weight vector for the root node.
sum_hess: The sum of hessians for the root node (coverage).

◆ Size()

std::size_t xgboost::MultiTargetTree::Size ( ) const

Friends And Related Function Documentation

◆ tree::MultiTargetTreeView

friend struct tree::MultiTargetTreeView
friend

The documentation for this class was generated from the following file: