6 #ifndef XGBOOST_MULTI_TARGET_TREE_MODEL_H_
7 #define XGBOOST_MULTI_TARGET_TREE_MODEL_H_
22 struct MultiTargetTreeView;
74 auto beg = nidx * n_split_targets;
75 auto v = this->weights_.
HostSpan().subspan(beg, n_split_targets);
101 float left_sum,
float right_sum);
126 [[nodiscard]] std::size_t
Size()
const;
130 if (device.
IsCPU()) {
142 auto lidx = h_leaf_mapping[nidx];
144 auto weight = h_leaf_weights.subspan(lidx * n_targets, n_targets);
Defines configuration macros and basic types for xgboost.
common::Span< T const > ConstHostSpan() const
Definition: host_device_vector.h:118
common::Span< const T > ConstDeviceSpan() const
const std::vector< T > & ConstHostVector() const
common::Span< T > HostSpan()
Definition: host_device_vector.h:116
void SetDevice(DeviceOrd device) const
Data structure representing JSON format.
Definition: json.h:396
Tree structure for multi-target model.
Definition: multi_target_tree_model.h:38
void SetLeaves()
Copy base weight into leaf weight for a non-reduced multi-target tree.
bst_target_t NumSplitTargets() const
Number of reduced targets.
bool IsLeaf(bst_node_t nidx) const
Definition: multi_target_tree_model.h:107
bst_node_t RightChild(bst_node_t nidx) const
Definition: multi_target_tree_model.h:113
MultiTargetTree * Copy(TreeParam const *param) const
MultiTargetTree & operator=(MultiTargetTree &&that)=delete
MultiTargetTree(MultiTargetTree const &that)
bst_target_t NumTargets() const
Number of targets (size of a leaf).
MultiTargetTree & operator=(MultiTargetTree const &that)=delete
void SaveModel(Json *out) const override
Saves the model config to a JSON object.
void SetRoot(linalg::VectorView< float const > weight, float sum_hess)
Set the weight and statistics for the root.
static constexpr bst_node_t InvalidNodeId()
Definition: multi_target_tree_model.h:40
MultiTargetTree(TreeParam const *param)
bst_node_t LeftChild(bst_node_t nidx) const
Definition: multi_target_tree_model.h:110
auto NumLeaves() const
Definition: multi_target_tree_model.h:124
void SetLeaves(std::vector< bst_node_t > leaves, common::Span< float const > weights)
void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight, float loss_chg, float sum_hess, float left_sum, float right_sum)
Expand a leaf into split node.
friend struct tree::MultiTargetTreeView
Definition: multi_target_tree_model.h:41
MultiTargetTree(MultiTargetTree &&that)=delete
linalg::VectorView< float const > LeafValue(bst_node_t nidx) const
Definition: multi_target_tree_model.h:137
void LoadModel(Json const &in) override
Loads the model from a JSON object.
std::size_t MemCostBytes() const
common::Span< float const > LeafWeights(DeviceOrd device) const
Definition: multi_target_tree_model.h:129
Span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:435
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
Definition: span.h:601
A tensor view with static type and dimension. It implements indexing and slicing.
Definition: linalg.h:278
A device-and-host vector abstraction layer.
Linear algebra related utilities.
Defines the abstract interface for different components in XGBoost.
auto MakeTensorView(Context const *ctx, Container &data, S &&...shape)
Constructor for automatic type deduction.
Definition: linalg.h:568
auto MakeVec(T *ptr, size_t s, DeviceOrd device=DeviceOrd::CPU())
Create a vector view from contiguous memory.
Definition: linalg.h:648
Learner interface that integrates objective, gbm and evaluation together. This is the user facing XGB...
Definition: base.h:89
std::int32_t bst_node_t
Type for tree node index and tree depth.
Definition: base.h:111
std::uint32_t bst_target_t
Type for indexing into output targets.
Definition: base.h:119
std::uint32_t bst_feature_t
Type for data column (feature) index.
Definition: base.h:99
A type for device ordinal. The type is packed into 32-bit for efficient use in viewing types like lin...
Definition: context.h:40
bool IsCPU() const
Definition: context.h:56
constexpr static auto CPU()
Constructor for CPU.
Definition: context.h:73
meta parameters of the tree
Definition: tree_model.h:37