xgboost
multi_target_tree_model.h
Go to the documentation of this file.
1 
6 #ifndef XGBOOST_MULTI_TARGET_TREE_MODEL_H_
7 #define XGBOOST_MULTI_TARGET_TREE_MODEL_H_
8 
9 #include <xgboost/base.h> // for bst_node_t, bst_target_t, bst_feature_t
10 #include <xgboost/context.h> // for Context
11 #include <xgboost/host_device_vector.h> // for HostDeviceVector
12 #include <xgboost/linalg.h> // for VectorView, MatrixView
13 #include <xgboost/model.h> // for Model
14 #include <xgboost/span.h> // for Span
15 
16 #include <cstddef> // for size_t
17 #include <cstdint> // for uint8_t
18 #include <vector> // for vector
19 
20 namespace xgboost {
21 namespace tree {
22 struct MultiTargetTreeView;
23 }
24 struct TreeParam;
25 
38 class MultiTargetTree : public Model {
39  public:
40  static bst_node_t constexpr InvalidNodeId() { return -1; }
42 
43  private:
44  TreeParam const* param_;
45  // Mapping from node index to its left child. -1 for a leaf node.
47  // Mapping from node index to its right child. Maps to leaf weight for a leaf node.
49  // Mapping from node index to its parent.
51  // Feature index for node split.
53  // Whether the left child is the default node when split feature is missing.
54  HostDeviceVector<std::uint8_t> default_left_;
55  // Threshold for splitting a node.
56  HostDeviceVector<float> split_conds_;
57  // Internal base weights.
58  HostDeviceVector<float> weights_;
59  // Output weights.
60  HostDeviceVector<float> leaf_weights_;
61  // Loss change for each node.
62  HostDeviceVector<float> loss_chg_;
63  // Sum of hessians for each node (coverage).
64  HostDeviceVector<float> sum_hess_;
65 
66  [[nodiscard]] linalg::VectorView<float const> NodeWeight(bst_node_t nidx) const {
67  auto beg = nidx * this->NumSplitTargets();
68  auto v = this->weights_.ConstHostSpan().subspan(beg, this->NumSplitTargets());
69  return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size());
70  }
71  // Unlike the const version, `NumSplitTargets` is not reliable if the tree can change.
72  [[nodiscard]] linalg::VectorView<float> NodeWeight(bst_node_t nidx,
73  bst_target_t n_split_targets) {
74  auto beg = nidx * n_split_targets;
75  auto v = this->weights_.HostSpan().subspan(beg, n_split_targets);
76  return linalg::MakeTensorView(DeviceOrd::CPU(), v, v.size());
77  }
78  [[nodiscard]] bst_node_t LeafIdx(bst_node_t nidx) const { return this->RightChild(nidx); }
79 
80  public:
81  explicit MultiTargetTree(TreeParam const* param);
83  MultiTargetTree& operator=(MultiTargetTree const& that) = delete;
84  MultiTargetTree(MultiTargetTree&& that) = delete;
86 
93  void SetRoot(linalg::VectorView<float const> weight, float sum_hess);
97  void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left,
100  linalg::VectorView<float const> right_weight, float loss_chg, float sum_hess,
101  float left_sum, float right_sum);
103  void SetLeaves(std::vector<bst_node_t> leaves, common::Span<float const> weights);
105  void SetLeaves();
106 
107  [[nodiscard]] bool IsLeaf(bst_node_t nidx) const {
108  return left_.ConstHostVector()[nidx] == InvalidNodeId();
109  }
110  [[nodiscard]] bst_node_t LeftChild(bst_node_t nidx) const {
111  return left_.ConstHostVector().at(nidx);
112  }
113  [[nodiscard]] bst_node_t RightChild(bst_node_t nidx) const {
114  return right_.ConstHostVector().at(nidx);
115  }
119  [[nodiscard]] bst_target_t NumTargets() const;
123  [[nodiscard]] bst_target_t NumSplitTargets() const;
124  [[nodiscard]] auto NumLeaves() const { return this->leaf_weights_.Size() / this->NumTargets(); }
125 
126  [[nodiscard]] std::size_t Size() const;
127  [[nodiscard]] MultiTargetTree* Copy(TreeParam const* param) const;
128 
130  if (device.IsCPU()) {
131  return this->leaf_weights_.ConstHostSpan();
132  }
133  this->leaf_weights_.SetDevice(device);
134  return this->leaf_weights_.ConstDeviceSpan();
135  }
136 
138  CHECK(IsLeaf(nidx));
139  auto n_targets = this->NumTargets();
140  auto h_leaf_mapping = this->right_.ConstHostSpan();
141  auto h_leaf_weights = this->leaf_weights_.ConstHostSpan();
142  auto lidx = h_leaf_mapping[nidx];
143  CHECK_NE(lidx, InvalidNodeId());
144  auto weight = h_leaf_weights.subspan(lidx * n_targets, n_targets);
145  return linalg::MakeVec(DeviceOrd::CPU(), weight);
146  }
147 
148  void LoadModel(Json const& in) override;
149  void SaveModel(Json* out) const override;
150 
151  [[nodiscard]] std::size_t MemCostBytes() const;
152 };
153 } // namespace xgboost
154 #endif // XGBOOST_MULTI_TARGET_TREE_MODEL_H_
Defines configuration macros and basic types for xgboost.
std::size_t Size() const
common::Span< T const > ConstHostSpan() const
Definition: host_device_vector.h:118
common::Span< const T > ConstDeviceSpan() const
const std::vector< T > & ConstHostVector() const
common::Span< T > HostSpan()
Definition: host_device_vector.h:116
void SetDevice(DeviceOrd device) const
Data structure representing JSON format.
Definition: json.h:396
Tree structure for multi-target model.
Definition: multi_target_tree_model.h:38
void SetLeaves()
Copy base weight into leaf weight for a non-reduced multi-target tree.
bst_target_t NumSplitTargets() const
Number of reduced targets.
bool IsLeaf(bst_node_t nidx) const
Definition: multi_target_tree_model.h:107
std::size_t Size() const
bst_node_t RightChild(bst_node_t nidx) const
Definition: multi_target_tree_model.h:113
MultiTargetTree * Copy(TreeParam const *param) const
MultiTargetTree & operator=(MultiTargetTree &&that)=delete
MultiTargetTree(MultiTargetTree const &that)
bst_target_t NumTargets() const
Number of targets (size of a leaf).
MultiTargetTree & operator=(MultiTargetTree const &that)=delete
void SaveModel(Json *out) const override
saves the model config to a JSON object
void SetRoot(linalg::VectorView< float const > weight, float sum_hess)
Set the weight and statistics for the root.
static constexpr bst_node_t InvalidNodeId()
Definition: multi_target_tree_model.h:40
MultiTargetTree(TreeParam const *param)
bst_node_t LeftChild(bst_node_t nidx) const
Definition: multi_target_tree_model.h:110
auto NumLeaves() const
Definition: multi_target_tree_model.h:124
void SetLeaves(std::vector< bst_node_t > leaves, common::Span< float const > weights)
void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left, linalg::VectorView< float const > base_weight, linalg::VectorView< float const > left_weight, linalg::VectorView< float const > right_weight, float loss_chg, float sum_hess, float left_sum, float right_sum)
Expand a leaf into split node.
friend struct tree::MultiTargetTreeView
Definition: multi_target_tree_model.h:41
MultiTargetTree(MultiTargetTree &&that)=delete
linalg::VectorView< float const > LeafValue(bst_node_t nidx) const
Definition: multi_target_tree_model.h:137
void LoadModel(Json const &in) override
load the model from a JSON object
std::size_t MemCostBytes() const
common::Span< float const > LeafWeights(DeviceOrd device) const
Definition: multi_target_tree_model.h:129
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:435
XGBOOST_DEVICE auto subspan() const -> Span< element_type, detail::ExtentValue< Extent, Offset, Count >::value >
Definition: span.h:601
A tensor view with static type and dimension. It implements indexing and slicing.
Definition: linalg.h:278
A device-and-host vector abstraction layer.
Linear algebra related utilities.
Defines the abstract interface for different components in XGBoost.
auto MakeTensorView(Context const *ctx, Container &data, S &&...shape)
Constructor for automatic type deduction.
Definition: linalg.h:568
auto MakeVec(T *ptr, size_t s, DeviceOrd device=DeviceOrd::CPU())
Create a vector view from contiguous memory.
Definition: linalg.h:648
Learner interface that integrates objective, gbm and evaluation together. This is the user facing XGB...
Definition: base.h:89
std::int32_t bst_node_t
Type for tree node index and tree depth.
Definition: base.h:111
std::uint32_t bst_target_t
Type for indexing into output targets.
Definition: base.h:119
std::uint32_t bst_feature_t
Type for data column (feature) index.
Definition: base.h:99
A type for device ordinal. The type is packed into 32-bit for efficient use in viewing types like lin...
Definition: context.h:40
bool IsCPU() const
Definition: context.h:56
constexpr static auto CPU()
Constructor for CPU.
Definition: context.h:73
Definition: model.h:14
meta parameters of the tree
Definition: tree_model.h:37