xgboost
multi_target_tree_model.h
Go to the documentation of this file.
1 
6 #ifndef XGBOOST_MULTI_TARGET_TREE_MODEL_H_
7 #define XGBOOST_MULTI_TARGET_TREE_MODEL_H_
8 #include <xgboost/base.h> // for bst_node_t, bst_target_t, bst_feature_t
9 #include <xgboost/context.h> // for Context
10 #include <xgboost/linalg.h> // for VectorView
11 #include <xgboost/model.h> // for Model
12 #include <xgboost/span.h> // for Span
13 
14 #include <cinttypes> // for uint8_t
15 #include <cstddef> // for size_t
16 #include <vector> // for vector
17 
18 namespace xgboost {
19 struct TreeParam;
23 class MultiTargetTree : public Model {
24  public:
25  static bst_node_t constexpr InvalidNodeId() { return -1; }
26 
27  private:
28  TreeParam const* param_;
29  std::vector<bst_node_t> left_;
30  std::vector<bst_node_t> right_;
31  std::vector<bst_node_t> parent_;
32  std::vector<bst_feature_t> split_index_;
33  std::vector<std::uint8_t> default_left_;
34  std::vector<float> split_conds_;
35  std::vector<float> weights_;
36 
37  [[nodiscard]] linalg::VectorView<float const> NodeWeight(bst_node_t nidx) const {
38  auto beg = nidx * this->NumTarget();
39  auto v = common::Span<float const>{weights_}.subspan(beg, this->NumTarget());
40  return linalg::MakeTensorView(Context::kCpuId, v, v.size());
41  }
42  [[nodiscard]] linalg::VectorView<float> NodeWeight(bst_node_t nidx) {
43  auto beg = nidx * this->NumTarget();
44  auto v = common::Span<float>{weights_}.subspan(beg, this->NumTarget());
45  return linalg::MakeTensorView(Context::kCpuId, v, v.size());
46  }
47 
48  public:
49  explicit MultiTargetTree(TreeParam const* param);
57  void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left,
60  linalg::VectorView<float const> right_weight);
61 
62  [[nodiscard]] bool IsLeaf(bst_node_t nidx) const { return left_[nidx] == InvalidNodeId(); }
63  [[nodiscard]] bst_node_t Parent(bst_node_t nidx) const { return parent_.at(nidx); }
64  [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const { return left_.at(nidx); }
65  [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const { return right_.at(nidx); }
66 
67  [[nodiscard]] bst_feature_t SplitIndex(bst_node_t nidx) const { return split_index_[nidx]; }
68  [[nodiscard]] float SplitCond(bst_node_t nidx) const { return split_conds_[nidx]; }
69  [[nodiscard]] bool DefaultLeft(bst_node_t nidx) const { return default_left_[nidx]; }
70  [[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
71  return this->DefaultLeft(nidx) ? this->LeftChild(nidx) : this->RightChild(nidx);
72  }
73 
74  [[nodiscard]] bst_target_t NumTarget() const;
75 
76  [[nodiscard]] std::size_t Size() const;
77 
78  [[nodiscard]] bst_node_t Depth(bst_node_t nidx) const {
79  bst_node_t depth{0};
80  while (Parent(nidx) != InvalidNodeId()) {
81  ++depth;
82  nidx = Parent(nidx);
83  }
84  return depth;
85  }
86 
88  CHECK(IsLeaf(nidx));
89  return this->NodeWeight(nidx);
90  }
91 
92  void LoadModel(Json const& in) override;
93  void SaveModel(Json* out) const override;
94 };
95 } // namespace xgboost
96 #endif // XGBOOST_MULTI_TARGET_TREE_MODEL_H_
Defines configuration macros and basic types for xgboost.
Data structure representing JSON format.
Definition: json.h:357
Tree structure for multi-target model.
Definition: multi_target_tree_model.h:23
bool IsLeaf(bst_node_t nidx) const
Definition: multi_target_tree_model.h:62
bst_feature_t SplitIndex(bst_node_t nidx) const
Definition: multi_target_tree_model.h:67
bst_node_t Parent(bst_node_t nidx) const
Definition: multi_target_tree_model.h:63
std::size_t Size() const
bst_node_t RightChild(bst_node_t nidx) const
Definition: multi_target_tree_model.h:65
bst_target_t NumTarget() const
void SaveModel(Json *out) const override
saves the model config to a JSON object
static constexpr bst_node_t InvalidNodeId()
Definition: multi_target_tree_model.h:25
MultiTargetTree(TreeParam const *param)
bool DefaultLeft(bst_node_t nidx) const
Definition: multi_target_tree_model.h:69
bst_node_t LeftChild(bst_node_t nidx) const
Definition: multi_target_tree_model.h:64
bst_node_t DefaultChild(bst_node_t nidx) const
Definition: multi_target_tree_model.h:70
void SetLeaf(bst_node_t nidx, linalg::VectorView< float const > weight)
Set the weight for a leaf.
float SplitCond(bst_node_t nidx) const
Definition: multi_target_tree_model.h:68
linalg::VectorView< float const > LeafValue(bst_node_t nidx) const
Definition: multi_target_tree_model.h:87
bst_node_t Depth(bst_node_t nidx) const
Definition: multi_target_tree_model.h:78
void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight)
Expand a leaf into a split node.
void LoadModel(Json const &in) override
load the model from a JSON object
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:424
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
Definition: span.h:596
A tensor view with static type and dimension. It implements indexing and slicing.
Definition: linalg.h:293
Linear algebra related utilities.
Defines the abstract interface for different components in XGBoost.
auto MakeTensorView(Context const *ctx, Container &data, S &&...shape)
Constructor for automatic type deduction.
Definition: linalg.h:576
namespace of xgboost
Definition: base.h:90
uint32_t bst_feature_t
Type for data column (feature) index.
Definition: base.h:101
std::int32_t bst_node_t
Type for tree node index.
Definition: base.h:112
std::uint32_t bst_target_t
Type for indexing into output targets.
Definition: base.h:118
static constexpr bst_d_ordinal_t kCpuId
Definition: context.h:93
Definition: model.h:17
meta parameters of the tree
Definition: tree_model.h:35