6 #ifndef XGBOOST_MULTI_TARGET_TREE_MODEL_H_
7 #define XGBOOST_MULTI_TARGET_TREE_MODEL_H_
29 std::vector<bst_node_t> left_;
30 std::vector<bst_node_t> right_;
31 std::vector<bst_node_t> parent_;
32 std::vector<bst_feature_t> split_index_;
33 std::vector<std::uint8_t> default_left_;
34 std::vector<float> split_conds_;
35 std::vector<float> weights_;
42 [[nodiscard]] linalg::VectorView<float> NodeWeight(
bst_node_t nidx) {
44 auto v = common::Span<float>{weights_}.subspan(beg, this->
NumTarget());
76 [[nodiscard]] std::size_t
Size()
const;
89 return this->NodeWeight(nidx);
Defines configuration macros and basic types for xgboost.
Data structure representing JSON format.
Definition: json.h:357
Tree structure for multi-target model.
Definition: multi_target_tree_model.h:23
bool IsLeaf(bst_node_t nidx) const
Definition: multi_target_tree_model.h:62
bst_feature_t SplitIndex(bst_node_t nidx) const
Definition: multi_target_tree_model.h:67
bst_node_t Parent(bst_node_t nidx) const
Definition: multi_target_tree_model.h:63
bst_node_t RightChild(bst_node_t nidx) const
Definition: multi_target_tree_model.h:65
bst_target_t NumTarget() const
void SaveModel(Json *out) const override
saves the model config to a JSON object
static constexpr bst_node_t InvalidNodeId()
Definition: multi_target_tree_model.h:25
MultiTargetTree(TreeParam const *param)
bool DefaultLeft(bst_node_t nidx) const
Definition: multi_target_tree_model.h:69
bst_node_t LeftChild(bst_node_t nidx) const
Definition: multi_target_tree_model.h:64
bst_node_t DefaultChild(bst_node_t nidx) const
Definition: multi_target_tree_model.h:70
void SetLeaf(bst_node_t nidx, linalg::VectorView< float const > weight)
Set the weight for a leaf.
float SplitCond(bst_node_t nidx) const
Definition: multi_target_tree_model.h:68
linalg::VectorView< float const > LeafValue(bst_node_t nidx) const
Definition: multi_target_tree_model.h:87
bst_node_t Depth(bst_node_t nidx) const
Definition: multi_target_tree_model.h:78
void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight)
Expand a leaf into split node.
void LoadModel(Json const &in) override
load the model from a JSON object
span class implementation, based on ISO++20 span<T>. The interface should be the same.
Definition: span.h:424
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
Definition: span.h:596
A tensor view with static type and dimension. It implements indexing and slicing.
Definition: linalg.h:293
Linear algebra related utilities.
Defines the abstract interface for different components in XGBoost.
auto MakeTensorView(Context const *ctx, Container &data, S &&...shape)
Constructor for automatic type deduction.
Definition: linalg.h:576
namespace of xgboost
Definition: base.h:90
uint32_t bst_feature_t
Type for data column (feature) index.
Definition: base.h:101
std::int32_t bst_node_t
Type for tree node index.
Definition: base.h:112
std::uint32_t bst_target_t
Type for indexing into output targets.
Definition: base.h:118
static constexpr bst_d_ordinal_t kCpuId
Definition: context.h:93
meta parameters of the tree
Definition: tree_model.h:35