xgboost
tree_updater.h
Go to the documentation of this file.
1 
8 #ifndef XGBOOST_TREE_UPDATER_H_
9 #define XGBOOST_TREE_UPDATER_H_
10 
11 #include <dmlc/registry.h>
12 #include <xgboost/base.h> // for Args, GradientPair
13 #include <xgboost/data.h> // for DMatrix
14 #include <xgboost/gradient.h> // for GradientContainer
15 #include <xgboost/host_device_vector.h> // for HostDeviceVector
16 #include <xgboost/linalg.h> // for VectorView
17 #include <xgboost/model.h> // for Configurable
18 #include <xgboost/span.h> // for Span
19 #include <xgboost/tree_model.h> // for RegTree
20 
21 #include <functional> // for function
22 #include <string> // for string
23 #include <vector> // for vector
24 
25 namespace xgboost {
26 namespace tree {
27 struct TrainParam;
28 }
29 
30 class Json;
31 struct Context;
32 struct ObjInfo;
33 
37 class TreeUpdater : public Configurable {
38  protected:
39  Context const* ctx_ = nullptr;
40 
41  public:
42  explicit TreeUpdater(const Context* ctx) : ctx_(ctx) {}
43  ~TreeUpdater() override = default;
48  virtual void Configure(const Args& args) = 0;
56  [[nodiscard]] virtual bool CanModifyTree() const { return false; }
61  [[nodiscard]] virtual bool HasNodePosition() const { return false; }
76  virtual void Update(tree::TrainParam const* param, GradientContainer* gpair, DMatrix* p_fmat,
78  std::vector<RegTree*> const& out_trees) = 0;
79 
92  virtual bool UpdatePredictionCache(DMatrix const* /*data*/,
94  linalg::MatrixView<float> /*out_preds*/) {
95  return false;
96  }
97 
98  [[nodiscard]] virtual char const* Name() const = 0;
99 
107  static TreeUpdater* Create(const std::string& name, Context const* ctx, ObjInfo const* task);
108 };
109 
114  : public dmlc::FunctionRegEntryBase<
115  TreeUpdaterReg, std::function<TreeUpdater*(Context const* ctx, ObjInfo const* task)>> {};
116 
129 #define XGBOOST_REGISTER_TREE_UPDATER(UniqueId, Name) \
130  static DMLC_ATTRIBUTE_UNUSED ::xgboost::TreeUpdaterReg& \
131  __make_ ## TreeUpdaterReg ## _ ## UniqueId ## __ = \
132  ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->__REGISTER__(Name)
133 
134 } // namespace xgboost
135 #endif // XGBOOST_TREE_UPDATER_H_
Defines configuration macros and basic types for xgboost.
Internal data structure used by XGBoost to hold all external data.
Definition: data.h:577
Interface of the tree update module, which performs updates of a tree.
Definition: tree_updater.h:37
Context const * ctx_
Definition: tree_updater.h:39
~TreeUpdater() override=default
static TreeUpdater * Create(const std::string &name, Context const *ctx, ObjInfo const *task)
Create a tree updater given name.
virtual bool CanModifyTree() const
Whether this updater can be used for updating existing trees.
Definition: tree_updater.h:56
virtual void Update(tree::TrainParam const *param, GradientContainer *gpair, DMatrix *p_fmat, common::Span< HostDeviceVector< bst_node_t >> out_position, std::vector< RegTree * > const &out_trees)=0
perform update to the tree models
virtual bool HasNodePosition() const
Whether the out_position in Update is valid. This determines whether adaptive tree can be used.
Definition: tree_updater.h:61
TreeUpdater(const Context *ctx)
Definition: tree_updater.h:42
virtual bool UpdatePredictionCache(DMatrix const *, common::Span< HostDeviceVector< bst_node_t >>, linalg::MatrixView< float >)
Determines whether the updater has enough knowledge about a given dataset to quickly update the prediction cache, and performs the update if possible.
Definition: tree_updater.h:92
virtual void Configure(const Args &args)=0
Initialize the updater with given arguments.
virtual char const * Name() const =0
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:435
A tensor view with static type and dimension. It implements indexing and slicing.
Definition: linalg.h:278
The input data structure of xgboost.
A device-and-host vector abstraction layer.
Linear algebra related utilities.
Defines the abstract interface for different components in XGBoost.
Learner interface that integrates objective, gbm and evaluation together. This is the user facing XGB...
Definition: base.h:89
std::vector< std::pair< std::string, std::string > > Args
Definition: base.h:306
Definition: model.h:28
Runtime context for XGBoost. Contains information like threads and device.
Definition: context.h:133
Container for gradient produced by objective.
Definition: gradient.h:16
A struct returned by objective, which determines task at hand. The struct is not used by any algorith...
Definition: task.h:24
Registry entry for tree updater.
Definition: tree_updater.h:115