xgboost
tree_updater.h
Go to the documentation of this file.
1 
8 #ifndef XGBOOST_TREE_UPDATER_H_
9 #define XGBOOST_TREE_UPDATER_H_
10 
11 #include <dmlc/registry.h>
12 #include <xgboost/base.h> // for Args, GradientPair
13 #include <xgboost/data.h> // DMatrix
14 #include <xgboost/host_device_vector.h> // for HostDeviceVector
15 #include <xgboost/linalg.h> // for VectorView
16 #include <xgboost/model.h> // for Configurable
17 #include <xgboost/span.h> // for Span
18 #include <xgboost/tree_model.h> // for RegTree
19 
20 #include <functional> // for function
21 #include <string> // for string
22 #include <vector> // for vector
23 
24 namespace xgboost {
25 namespace tree {
26 struct TrainParam;
27 }
28 
29 class Json;
30 struct Context;
31 struct ObjInfo;
32 
36 class TreeUpdater : public Configurable {
37  protected:
38  Context const* ctx_ = nullptr;
39 
40  public:
41  explicit TreeUpdater(const Context* ctx) : ctx_(ctx) {}
43  ~TreeUpdater() override = default;
48  virtual void Configure(const Args& args) = 0;
55  [[nodiscard]] virtual bool CanModifyTree() const { return false; }
60  [[nodiscard]] virtual bool HasNodePosition() const { return false; }
74  virtual void Update(tree::TrainParam const* param, linalg::Matrix<GradientPair>* gpair,
76  const std::vector<RegTree*>& out_trees) = 0;
77 
88  virtual bool UpdatePredictionCache(const DMatrix* /*data*/,
89  linalg::MatrixView<float> /*out_preds*/) {
90  return false;
91  }
92 
93  [[nodiscard]] virtual char const* Name() const = 0;
94 
101  static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo const* task);
102 };
103 
108  : public dmlc::FunctionRegEntryBase<
109  TreeUpdaterReg, std::function<TreeUpdater*(Context const* ctx, ObjInfo const* task)>> {};
110 
123 #define XGBOOST_REGISTER_TREE_UPDATER(UniqueId, Name) \
124  static DMLC_ATTRIBUTE_UNUSED ::xgboost::TreeUpdaterReg& \
125  __make_ ## TreeUpdaterReg ## _ ## UniqueId ## __ = \
126  ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->__REGISTER__(Name)
127 
128 } // namespace xgboost
129 #endif // XGBOOST_TREE_UPDATER_H_
Defines configuration macros and basic types for xgboost.
Internal data structure used by XGBoost to hold all external data.
Definition: data.h:573
Data structure representing JSON format.
Definition: json.h:392
interface of tree update module, that performs update of a tree.
Definition: tree_updater.h:36
Context const * ctx_
Definition: tree_updater.h:38
~TreeUpdater() override=default
virtual destructor
virtual bool UpdatePredictionCache(const DMatrix *, linalg::MatrixView< float >)
Determines whether the updater has enough knowledge about a given dataset to quickly update the prediction cache.
Definition: tree_updater.h:88
static TreeUpdater * Create(const std::string &name, Context const *ctx, ObjInfo const *task)
Create a tree updater given name.
virtual bool CanModifyTree() const
Whether this updater can be used for updating existing trees.
Definition: tree_updater.h:55
virtual bool HasNodePosition() const
Whether the out_position argument of Update is valid. This determines whether adaptive trees can be used.
Definition: tree_updater.h:60
TreeUpdater(const Context *ctx)
Definition: tree_updater.h:41
virtual void Update(tree::TrainParam const *param, linalg::Matrix< GradientPair > *gpair, DMatrix *data, common::Span< HostDeviceVector< bst_node_t >> out_position, const std::vector< RegTree * > &out_trees)=0
perform update to the tree models
virtual void Configure(const Args &args)=0
Initialize the updater with given arguments.
virtual char const * Name() const =0
Span class implementation, based on the ISO C++20 std::span<T>. The interface should be the same.
Definition: span.h:431
A tensor view with static type and dimension. It implements indexing and slicing.
Definition: linalg.h:277
A tensor storage. To use it for other functionality like slicing one needs to obtain a view first....
Definition: linalg.h:745
The input data structure of xgboost.
A device-and-host vector abstraction layer.
Linear algebra related utilities.
Defines the abstract interface for different components in XGBoost.
Learner interface that integrates objective, gbm and evaluation together. This is the user facing XGB...
Definition: base.h:97
std::vector< std::pair< std::string, std::string > > Args
Definition: base.h:324
Definition: model.h:28
Runtime context for XGBoost. Contains information like threads and device.
Definition: context.h:133
A struct returned by the objective, which determines the task at hand. The struct is not used by any algorithm...
Definition: task.h:24
Registry entry for tree updater.
Definition: tree_updater.h:109