xgboost
learner.h
Go to the documentation of this file.
1 
8 #ifndef XGBOOST_LEARNER_H_
9 #define XGBOOST_LEARNER_H_
10 
11 #include <rabit/rabit.h>
12 #include <utility>
13 #include <map>
14 #include <memory>
15 #include <string>
16 #include <vector>
17 #include "./base.h"
18 #include "./gbm.h"
19 #include "./metric.h"
20 #include "./objective.h"
21 
22 namespace xgboost {
39 class Learner : public rabit::Serializable {
40  public:
42  ~Learner() override = default;
49  template<typename PairIter>
50  inline void Configure(PairIter begin, PairIter end);
57  virtual void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) = 0;
62  virtual void InitModel() = 0;
67  void Load(dmlc::Stream* fi) override = 0;
72  void Save(dmlc::Stream* fo) const override = 0;
79  virtual void UpdateOneIter(int iter, DMatrix* train) = 0;
87  virtual void BoostOneIter(int iter,
88  DMatrix* train,
89  HostDeviceVector<GradientPair>* in_gpair) = 0;
97  virtual std::string EvalOneIter(int iter,
98  const std::vector<DMatrix*>& data_sets,
99  const std::vector<std::string>& data_names) = 0;
112  virtual void Predict(DMatrix* data,
113  bool output_margin,
114  HostDeviceVector<bst_float> *out_preds,
115  unsigned ntree_limit = 0,
116  bool pred_leaf = false,
117  bool pred_contribs = false,
118  bool approx_contribs = false,
119  bool pred_interactions = false) const = 0;
120 
127  virtual void SetAttr(const std::string& key, const std::string& value) = 0;
135  virtual bool GetAttr(const std::string& key, std::string* out) const = 0;
141  virtual bool DelAttr(const std::string& key) = 0;
146  virtual std::vector<std::string> GetAttrNames() const = 0;
150  bool AllowLazyCheckPoint() const;
158  std::vector<std::string> DumpModel(const FeatureMap& fmap,
159  bool with_stats,
160  std::string format) const;
172  inline void Predict(const SparsePage::Inst &inst,
173  bool output_margin,
174  HostDeviceVector<bst_float> *out_preds,
175  unsigned ntree_limit = 0) const;
181  static Learner* Create(const std::vector<std::shared_ptr<DMatrix> >& cache_data);
182 
187  virtual const std::map<std::string, std::string>& GetConfigurationArguments() const = 0;
188 
189  protected:
193  std::unique_ptr<ObjFunction> obj_;
195  std::unique_ptr<GradientBooster> gbm_;
197  std::vector<std::unique_ptr<Metric> > metrics_;
198 };
199 
200 // implementation of inline functions.
201 inline void Learner::Predict(const SparsePage::Inst& inst,
202  bool output_margin,
203  HostDeviceVector<bst_float>* out_preds,
204  unsigned ntree_limit) const {
205  gbm_->PredictInstance(inst, &out_preds->HostVector(), ntree_limit);
206  if (!output_margin) {
207  obj_->PredTransform(out_preds);
208  }
209 }
210 
211 // implementing configure.
212 template<typename PairIter>
213 inline void Learner::Configure(PairIter begin, PairIter end) {
214  std::vector<std::pair<std::string, std::string> > vec(begin, end);
215  this->Configure(vec);
216 }
217 
218 } // namespace xgboost
219 #endif // XGBOOST_LEARNER_H_
virtual void BoostOneIter(int iter, DMatrix *train, HostDeviceVector< GradientPair > *in_gpair)=0
Do customized gradient boosting with in_gpair. in_gpair can be mutated after this call.
float bst_float
float type, used for storing statistics
Definition: base.h:89
std::vector< std::unique_ptr< Metric > > metrics_
The evaluation metrics used to evaluate the model.
Definition: learner.h:197
bst_float base_score_
internal base score of the model
Definition: learner.h:191
virtual bool DelAttr(const std::string &key)=0
Delete an attribute from the booster.
virtual void UpdateOneIter(int iter, DMatrix *train)=0
Update the model for one iteration with the specified objective function.
virtual void SetAttr(const std::string &key, const std::string &value)=0
Set additional attribute to the Booster. The property will be saved along the booster.
Definition: host_device_vector.h:200
Interface of gradient booster, that learns through gradient statistics.
std::vector< std::string > DumpModel(const FeatureMap &fmap, bool with_stats, std::string format) const
dump the model in the requested format
bool AllowLazyCheckPoint() const
virtual void InitModel()=0
Initialize the model using the specified configurations via Configure. A model has to be either Loa...
static Learner * Create(const std::vector< std::shared_ptr< DMatrix > > &cache_data)
Create a new instance of learner.
~Learner() override=default
virtual destructor
std::unique_ptr< GradientBooster > gbm_
The gradient booster used by the model.
Definition: learner.h:195
Internal data structure used by XGBoost during training. There are two ways to create a customized D...
Definition: data.h:406
Feature map data structure to help text model dump. TODO(tqchen) consider make it even more lightweig...
Definition: feature_map.h:20
span class implementation, based on ISO++20 span<T>. The interface should be the same.
Definition: span.h:109
void Save(dmlc::Stream *fo) const override=0
save model to stream.
interface of objective function used by xgboost.
virtual bool GetAttr(const std::string &key, std::string *out) const =0
Get attribute from the booster. The property will be saved along the booster.
void Configure(PairIter begin, PairIter end)
set configuration from pair iterators.
Definition: learner.h:213
std::vector< T > & HostVector()
std::unique_ptr< ObjFunction > obj_
objective function
Definition: learner.h:193
interface of evaluation metric function supported in xgboost.
namespace of xgboost
Definition: base.h:79
defines configuration macros of xgboost.
Learner class that does training and prediction. This is the user facing module of xgboost training...
Definition: learner.h:39
virtual std::vector< std::string > GetAttrNames() const =0
Get a vector of attribute names from the booster.
virtual std::string EvalOneIter(int iter, const std::vector< DMatrix *> &data_sets, const std::vector< std::string > &data_names)=0
evaluate the model for specific iteration using the configured metrics.
virtual void Predict(DMatrix *data, bool output_margin, HostDeviceVector< bst_float > *out_preds, unsigned ntree_limit=0, bool pred_leaf=false, bool pred_contribs=false, bool approx_contribs=false, bool pred_interactions=false) const =0
get prediction given the model.
virtual const std::map< std::string, std::string > & GetConfigurationArguments() const =0
Get configuration arguments currently stored by the learner.
void Load(dmlc::Stream *fi) override=0
load model from stream