xgboost
tree_updater.h
Go to the documentation of this file.
1 
8 #ifndef XGBOOST_TREE_UPDATER_H_
9 #define XGBOOST_TREE_UPDATER_H_
10 
11 #include <dmlc/registry.h>
12 #include <xgboost/base.h>
13 #include <xgboost/data.h>
16 #include <xgboost/linalg.h>
17 #include <xgboost/model.h>
18 #include <xgboost/task.h>
19 #include <xgboost/tree_model.h>
20 
21 #include <functional>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 namespace xgboost {
27 
28 class Json;
29 
/*!
 * \brief Interface of the tree update module, which performs updates of a tree.
 *
 * Abstract base class: Configure(), Update() and Name() are pure virtual and
 * must be supplied by every concrete updater. CanModifyTree(),
 * HasNodePosition() and UpdatePredictionCache() have default implementations
 * that all return false.
 */
33 class TreeUpdater : public Configurable {
34  protected:
  /*!
   * \brief Pointer handed to the constructor and stored unchanged.
   *        The in-class initializer makes it nullptr until a constructor sets it.
   *        NOTE(review): nothing here takes ownership, so the pointee presumably
   *        has to outlive the updater -- confirm against callers.
   */
35  GenericParameter const* ctx_ = nullptr;
36 
37  public:
  /*!
   * \brief Construct an updater that remembers \p ctx in ctx_.
   * \param ctx pointer stored as-is; it is not dereferenced here.
   */
38  explicit TreeUpdater(const GenericParameter* ctx) : ctx_(ctx) {}
  /*! \brief virtual destructor (overrides the one inherited from Configurable) */
40  ~TreeUpdater() override = default;
  /*!
   * \brief Initialize the updater with given arguments.
   * \param args arguments to the updater.
   */
45  virtual void Configure(const Args& args) = 0;
  /*!
   * \brief Whether this updater can be used for updating existing trees.
   * \return false unless a derived class overrides it.
   */
52  virtual bool CanModifyTree() const { return false; }
  /*!
   * \brief Whether the out_position in Update is valid. This determines whether
   *        adaptive tree can be used.
   * \return false unless a derived class overrides it.
   */
57  virtual bool HasNodePosition() const { return false; }
  /*!
   * \brief perform update to the tree models
   * \param gpair     gradient pairs consumed by the update
   * \param data      the DMatrix the update is computed on
   * \param out_trees the trees handed to the updater
   *
   * NOTE(review): listing line 70 is absent from this rendering. The member
   * summary further down this page shows an additional parameter
   * `common::Span<HostDeviceVector<bst_node_t>> out_position` between \p data
   * and \p out_trees; confirm against the real header before relying on the
   * three-parameter form shown here.
   */
69  virtual void Update(HostDeviceVector<GradientPair>* gpair, DMatrix* data,
71  const std::vector<RegTree*>& out_trees) = 0;
72 
  /*!
   * \brief Determines whether the updater has enough knowledge about a given
   *        dataset to quickly update the prediction cache.
   *
   * The default implementation ignores both (unnamed) arguments and reports
   * that no update was performed.
   *
   * \return false in this base implementation.
   */
83  virtual bool UpdatePredictionCache(const DMatrix * /*data*/,
84  linalg::VectorView<float> /*out_preds*/) {
85  return false;
86  }
87 
  /*! \brief Name of the updater as a C string; every concrete updater must provide it. */
88  virtual char const* Name() const = 0;
89 
  /*!
   * \brief Create a tree updater given name.
   * \param name   name of the tree updater to create.
   * \param tparam parameter pointer forwarded to the created updater.
   * \param task   objective task information.
   * \return raw pointer to the created updater.
   *         NOTE(review): presumably the caller owns the result -- confirm
   *         against the definition, which is not part of this header.
   */
95  static TreeUpdater* Create(const std::string& name, GenericParameter const* tparam, ObjInfo task);
96 };
97 
/*!
 * \brief Registry entry for tree updater.
 *
 * The entry derives from dmlc::FunctionRegEntryBase, parameterized on itself
 * and on the factory signature
 * `TreeUpdater*(GenericParameter const* tparam, ObjInfo task)`; the body is empty.
 *
 * NOTE(review): the declarator line (listing line 101, presumably
 * `struct TreeUpdaterReg`) is missing from this rendering -- restore it from
 * the real header.
 */
102  : public dmlc::FunctionRegEntryBase<
103  TreeUpdaterReg,
104  std::function<TreeUpdater*(GenericParameter const* tparam, ObjInfo task)> > {};
105 
/*!
 * \brief Macro to register a tree updater.
 *
 * Expands to the definition of a static `::xgboost::TreeUpdaterReg&` whose name
 * is formed by token pasting as `__make_TreeUpdaterReg_<UniqueId>__`, bound to
 * the result of
 * `::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->__REGISTER__(Name)`.
 * The expansion carries no trailing semicolon, so the use site terminates the
 * statement (and may append further member calls on the returned entry first).
 *
 * \param UniqueId token pasted into the variable name; must be a valid
 *                 identifier fragment.
 * \param Name     argument forwarded to `__REGISTER__`.
 *
 * NOTE(review): DMLC_ATTRIBUTE_UNUSED presumably silences unused-variable
 * warnings for the generated static -- confirm in dmlc/base.h.
 */
118 #define XGBOOST_REGISTER_TREE_UPDATER(UniqueId, Name) \
119  static DMLC_ATTRIBUTE_UNUSED ::xgboost::TreeUpdaterReg& \
120  __make_ ## TreeUpdaterReg ## _ ## UniqueId ## __ = \
121  ::dmlc::Registry< ::xgboost::TreeUpdaterReg>::Get()->__REGISTER__(Name)
122 
123 } // namespace xgboost
124 #endif // XGBOOST_TREE_UPDATER_H_
defines configuration macros of xgboost.
Internal data structure used by XGBoost during training.
Definition: data.h:490
Definition: host_device_vector.h:86
interface of tree update module, that performs update of a tree.
Definition: tree_updater.h:33
~TreeUpdater() override=default
virtual destructor
static TreeUpdater * Create(const std::string &name, GenericParameter const *tparam, ObjInfo task)
Create a tree updater given name.
virtual bool UpdatePredictionCache(const DMatrix *, linalg::VectorView< float >)
determines whether updater has enough knowledge about a given dataset to quickly update prediction ca...
Definition: tree_updater.h:83
virtual bool CanModifyTree() const
Whether this updater can be used for updating existing trees.
Definition: tree_updater.h:52
virtual bool HasNodePosition() const
Whether the out_position in Update is valid. This determines whether adaptive tree can be used.
Definition: tree_updater.h:57
GenericParameter const * ctx_
Definition: tree_updater.h:35
TreeUpdater(const GenericParameter *ctx)
Definition: tree_updater.h:38
virtual void Configure(const Args &args)=0
Initialize the updater with given arguments.
virtual char const * Name() const =0
virtual void Update(HostDeviceVector< GradientPair > *gpair, DMatrix *data, common::Span< HostDeviceVector< bst_node_t >> out_position, const std::vector< RegTree * > &out_trees)=0
perform update to the tree models
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:423
A tensor view with static type and dimension. It implements indexing and slicing.
Definition: linalg.h:262
The input data structure of xgboost.
A device-and-host vector abstraction layer.
Linear algebra related utilities.
Defines the abstract interface for different components in XGBoost.
namespace of xgboost
Definition: base.h:110
std::vector< std::pair< std::string, std::string > > Args
Definition: base.h:318
Definition: model.h:31
Definition: generic_parameters.h:15
A struct returned by objective, which determines task at hand. The struct is not used by any algorith...
Definition: task.h:24
Registry entry for tree updater.
Definition: tree_updater.h:104
model structure for tree