6 #ifndef XGBOOST_MULTI_TARGET_TREE_MODEL_H_
7 #define XGBOOST_MULTI_TARGET_TREE_MODEL_H_
83 mutable std::mutex tree_view_lock_;
90 [[nodiscard]] linalg::VectorView<float> NodeWeight(
bst_node_t nidx) {
143 [[nodiscard]] std::size_t
Size()
const;
156 return this->NodeWeight(nidx);
Defines configuration macros and basic types for xgboost.
#define XGBOOST_DEVICE
Tag function as usable by device.
Definition: base.h:64
common::Span< T const > ConstHostSpan() const
Definition: host_device_vector.h:116
const std::vector< T > & ConstHostVector() const
common::Span< T > HostSpan()
Definition: host_device_vector.h:114
Data structure representing JSON format.
Definition: json.h:392
Tree structure for multi-target model.
Definition: multi_target_tree_model.h:69
bool IsLeaf(bst_node_t nidx) const
Definition: multi_target_tree_model.h:115
bst_feature_t SplitIndex(bst_node_t nidx) const
Definition: multi_target_tree_model.h:128
bst_node_t Parent(bst_node_t nidx) const
Definition: multi_target_tree_model.h:118
bst_node_t RightChild(bst_node_t nidx) const
Definition: multi_target_tree_model.h:124
MultiTargetTree & operator=(MultiTargetTree &&that)=delete
MultiTargetTree(MultiTargetTree const &that)
bst_target_t NumTargets() const
MultiTargetTree & operator=(MultiTargetTree const &that)=delete
void SaveModel(Json *out) const override
saves the model config to a JSON object
static constexpr bst_node_t InvalidNodeId()
Definition: multi_target_tree_model.h:71
MultiTargetTree(TreeParam const *param)
bool DefaultLeft(bst_node_t nidx) const
Definition: multi_target_tree_model.h:134
bst_node_t LeftChild(bst_node_t nidx) const
Definition: multi_target_tree_model.h:121
bst_node_t DefaultChild(bst_node_t nidx) const
Definition: multi_target_tree_model.h:137
MultiTargetTreeView View(Context const *ctx) const
Get a view to the tree.
MultiTargetTree(MultiTargetTree &&that)=delete
void SetLeaf(bst_node_t nidx, linalg::VectorView< float const > weight)
Set the weight for a leaf.
float SplitCond(bst_node_t nidx) const
Definition: multi_target_tree_model.h:131
linalg::VectorView< float const > LeafValue(bst_node_t nidx) const
Definition: multi_target_tree_model.h:154
bst_node_t Depth(bst_node_t nidx) const
Definition: multi_target_tree_model.h:145
void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight)
Expand a leaf into split node.
void LoadModel(Json const &in) override
load the model from a JSON object
std::size_t MemCostBytes() const
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
Definition: span.h:597
A tensor view with static type and dimension. It implements indexing and slicing.
Definition: linalg.h:277
LINALG_HD auto Shape() const
Definition: linalg.h:506
LINALG_HD auto Slice(S &&...slices) const
Slice the tensor. The returned tensor has inferred dim and shape. Scalar result is not supported.
Definition: linalg.h:493
A device-and-host vector abstraction layer.
Linear algebra related utilities.
Defines the abstract interface for different components in XGBoost.
auto MakeTensorView(Context const *ctx, Container &data, S &&...shape)
Constructor for automatic type deduction.
Definition: linalg.h:564
constexpr detail::AllTag All()
Specify all elements in the axis for slicing.
Definition: linalg.h:249
Learner interface that integrates objective, gbm and evaluation together. This is the user facing XGB...
Definition: base.h:97
std::int32_t bst_node_t
Type for tree node index.
Definition: base.h:119
std::uint32_t bst_target_t
Type for indexing into output targets.
Definition: base.h:127
std::uint32_t bst_feature_t
Type for data column (feature) index.
Definition: base.h:107
Runtime context for XGBoost. Contains information like threads and device.
Definition: context.h:133
constexpr static auto CPU()
Constructor for CPU.
Definition: context.h:64
A view to the @MultiTargetTree suitable for both host and device.
Definition: multi_target_tree_model.h:26
bst_node_t Size() const
Definition: multi_target_tree_model.h:63
XGBOOST_DEVICE float SplitCond(bst_node_t nidx) const
Definition: multi_target_tree_model.h:51
std::size_t n
Definition: multi_target_tree_model.h:38
linalg::MatrixView< float const > weights
Definition: multi_target_tree_model.h:40
bst_target_t NumTargets() const
Definition: multi_target_tree_model.h:62
bst_node_t const * parent
Definition: multi_target_tree_model.h:31
XGBOOST_DEVICE bst_node_t RightChild(bst_node_t nidx) const
Definition: multi_target_tree_model.h:47
bst_node_t const * left
Definition: multi_target_tree_model.h:29
static constexpr bst_node_t InvalidNodeId()
Definition: multi_target_tree_model.h:27
XGBOOST_DEVICE bst_node_t DefaultChild(bst_node_t nidx) const
Definition: multi_target_tree_model.h:55
XGBOOST_DEVICE bool IsLeaf(bst_node_t nidx) const
Definition: multi_target_tree_model.h:42
XGBOOST_DEVICE bst_node_t LeftChild(bst_node_t nidx) const
Definition: multi_target_tree_model.h:46
XGBOOST_DEVICE linalg::VectorView< float const > LeafValue(bst_node_t nidx) const
Definition: multi_target_tree_model.h:58
XGBOOST_DEVICE bst_feature_t SplitIndex(bst_node_t nidx) const
Definition: multi_target_tree_model.h:48
std::uint8_t const * default_left
Definition: multi_target_tree_model.h:34
float const * split_conds
Definition: multi_target_tree_model.h:35
bst_feature_t const * split_index
Definition: multi_target_tree_model.h:33
XGBOOST_DEVICE bool DefaultLeft(bst_node_t nidx) const
Definition: multi_target_tree_model.h:52
bst_node_t const * right
Definition: multi_target_tree_model.h:30
meta parameters of the tree
Definition: tree_model.h:30