xgboost
predictor.h
Go to the documentation of this file.
1 
7 #pragma once
8 #include <xgboost/base.h>
9 #include <xgboost/cache.h> // for DMatrixCache
10 #include <xgboost/context.h> // for Context
11 #include <xgboost/context.h>
12 #include <xgboost/data.h>
14 
15 #include <functional> // for function
16 #include <memory> // for shared_ptr
17 #include <string>
18 #include <utility> // for make_pair
19 #include <vector>
20 
21 // Forward declarations
22 namespace xgboost::gbm {
23 struct GBTreeModel;
24 } // namespace xgboost::gbm
25 
26 namespace xgboost {
// Cached prediction values for one cached DMatrix, tagged with a version
// counter. (Header line 30 and member line 32 were missing from the listing.)
30 struct PredictionCacheEntry {
31  // A storage for caching prediction values
32  HostDeviceVector<bst_float> predictions;
33  // The version of current cache, corresponding number of layers of trees
34  std::uint32_t version{0};
35 
36  PredictionCacheEntry() = default;
    // Update the cache entry by number of versions: adds `v` to `version`.
42  void Update(std::uint32_t v) { version += v; }
    // Sets `version` back to 0; `predictions` is not touched here.
43  void Reset() { version = 0; }
44 };
45 
// A container for managed prediction caches: a DMatrixCache specialised to
// hold one PredictionCacheEntry per cached DMatrix.
49 class PredictionContainer : public DMatrixCache<PredictionCacheEntry> {
50  // We cache up to 64 DMatrix for all threads
51  std::size_t static constexpr DefaultSize() { return 64; }
52 
53  public:
    // NOTE(review): line 54 was missing from the listing; it is restored on the
    // assumption that the base class takes the cache capacity as its only
    // constructor argument -- confirm against cache.h.
54  PredictionContainer() : DMatrixCache<PredictionCacheEntry>{DefaultSize()} {}
    // Returns the cache entry for `m`, obtained through CacheItem (which caches
    // a new DMatrix if it is not in the cache already). When `device` is a CUDA
    // device, SetDevice(device) is called on the entry's prediction vector
    // before the entry is returned.
55  PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, DeviceOrd device) {
56  auto p_cache = this->CacheItem(m);
57  if (device.IsCUDA()) {
58  p_cache->predictions.SetDevice(device);
59  }
60  return *p_cache;
61  }
62 };
63 
// Performs prediction on individual training instances or batches of instances
// for GBTree. The prediction entry points are pure virtual; Configure and
// InitOutPredictions are virtual and only declared here.
72 class Predictor {
73  protected:
    // Runtime context handed to the constructor; kept as a raw pointer and not
    // released by this class (the destructor is defaulted).
74  Context const* ctx_;
75 
76  public:
    // Stores `ctx` without copying the Context it points to.
77  explicit Predictor(Context const* ctx) : ctx_{ctx} {}
78 
79  virtual ~Predictor() = default;
80 
    // Configure and register input matrices in prediction cache.
86  virtual void Configure(Args const&);
87 
    // Initialize output prediction.
95  virtual void InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_float>* out_predt,
96  const gbm::GBTreeModel& model) const;
97 
    // Generate batch predictions for a given feature matrix. May use cached
    // predictions if available.
    // NOTE(review): tree_end == 0 presumably means "up to the last tree" --
    // confirm against the implementations.
108  virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds,
109  const gbm::GBTreeModel& model, uint32_t tree_begin,
110  uint32_t tree_end = 0) const = 0;
111 
    // Inplace prediction. Returns bool; the meaning of the return value is not
    // visible in this listing.
125  virtual bool InplacePredict(std::shared_ptr<DMatrix> p_fmat, const gbm::GBTreeModel& model,
126  float missing, PredictionCacheEntry* out_preds,
127  uint32_t tree_begin = 0, uint32_t tree_end = 0) const = 0;
    // Online prediction function: predict score for one instance at a time.
141  virtual void PredictInstance(const SparsePage::Inst& inst,
142  std::vector<bst_float>* out_preds,
143  const gbm::GBTreeModel& model,
144  unsigned tree_end = 0,
145  bool is_column_split = false) const = 0;
146 
    // Predict the leaf index of each tree; the output will be an
    // nsample * ntree vector.
157  virtual void PredictLeaf(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
158  const gbm::GBTreeModel& model,
159  unsigned tree_end = 0) const = 0;
160 
    // Feature contributions to individual predictions, written to
    // `out_contribs`. (Line 177, which names the function, was missing from the
    // listing and is restored from the documented signature.)
176  virtual void
177  PredictContribution(DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
178  const gbm::GBTreeModel &model, unsigned tree_end = 0,
179  std::vector<bst_float> const *tree_weights = nullptr,
180  bool approximate = false, int condition = 0,
181  unsigned condition_feature = 0) const = 0;
182 
    // Interaction-contribution counterpart of PredictContribution, without the
    // `condition` / `condition_feature` parameters. (Line 183 was missing from
    // the listing and is restored from the documented signature.)
183  virtual void PredictInteractionContributions(
184  DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
185  const gbm::GBTreeModel &model, unsigned tree_end = 0,
186  std::vector<bst_float> const *tree_weights = nullptr,
187  bool approximate = false) const = 0;
188 
    // Creates a new Predictor*, selected by `name` and given `ctx`.
195  static Predictor* Create(std::string const& name, Context const* ctx);
196 };
197 
// Registry entry for predictor. Each entry holds a factory of type
// std::function<Predictor*(Context const*)>. (Line 201, the struct head, was
// missing from the listing.)
201 struct PredictorReg
202  : public dmlc::FunctionRegEntryBase<PredictorReg, std::function<Predictor*(Context const*)>> {};
203 
// Registers a predictor under `Name`: expands to the declaration of a static
// ::xgboost::PredictorReg reference, initialised with the entry returned by
// ::dmlc::Registry<::xgboost::PredictorReg>::Get()->__REGISTER__(Name).
// `UniqueId` is token-pasted into the variable's identifier. The expansion has
// no trailing semicolon, so the call site can continue the expression.
204 #define XGBOOST_REGISTER_PREDICTOR(UniqueId, Name) \
205  static DMLC_ATTRIBUTE_UNUSED ::xgboost::PredictorReg& \
206  __make_##PredictorReg##_##UniqueId##__ = \
207  ::dmlc::Registry<::xgboost::PredictorReg>::Get()->__REGISTER__(Name)
208 } // namespace xgboost
Defines configuration macros and basic types for xgboost.
Thread-aware FIFO cache for DMatrix related data.
Definition: cache.h:26
std::shared_ptr< PredictionCacheEntry > CacheItem(std::shared_ptr< DMatrix > m, Args const &... args)
Cache a new DMatrix if it's not in the cache already.
Definition: cache.h:145
Internal data structure used by XGBoost during training.
Definition: data.h:505
Meta information about dataset, always sits in memory.
Definition: data.h:47
A container for managed prediction caches.
Definition: predictor.h:49
PredictionContainer()
Definition: predictor.h:54
PredictionCacheEntry & Cache(std::shared_ptr< DMatrix > m, DeviceOrd device)
Definition: predictor.h:55
Performs prediction on individual training instances or batches of instances for GBTree....
Definition: predictor.h:72
virtual void PredictInteractionContributions(DMatrix *dmat, HostDeviceVector< bst_float > *out_contribs, const gbm::GBTreeModel &model, unsigned tree_end=0, std::vector< bst_float > const *tree_weights=nullptr, bool approximate=false) const =0
virtual void Configure(Args const &)
Configure and register input matrices in prediction cache.
virtual void InitOutPredictions(const MetaInfo &info, HostDeviceVector< bst_float > *out_predt, const gbm::GBTreeModel &model) const
Initialize output prediction.
virtual void PredictContribution(DMatrix *dmat, HostDeviceVector< bst_float > *out_contribs, const gbm::GBTreeModel &model, unsigned tree_end=0, std::vector< bst_float > const *tree_weights=nullptr, bool approximate=false, int condition=0, unsigned condition_feature=0) const =0
feature contributions to individual predictions; the output will be a vector of length (nfeats + 1) *...
Predictor(Context const *ctx)
Definition: predictor.h:77
virtual void PredictLeaf(DMatrix *dmat, HostDeviceVector< bst_float > *out_preds, const gbm::GBTreeModel &model, unsigned tree_end=0) const =0
predict the leaf index of each tree, the output will be nsample * ntree vector this is only valid in ...
Context const * ctx_
Definition: predictor.h:74
static Predictor * Create(std::string const &name, Context const *ctx)
Creates a new Predictor*.
virtual ~Predictor()=default
virtual bool InplacePredict(std::shared_ptr< DMatrix > p_fmat, const gbm::GBTreeModel &model, float missing, PredictionCacheEntry *out_preds, uint32_t tree_begin=0, uint32_t tree_end=0) const =0
Inplace prediction.
virtual void PredictInstance(const SparsePage::Inst &inst, std::vector< bst_float > *out_preds, const gbm::GBTreeModel &model, unsigned tree_end=0, bool is_column_split=false) const =0
online prediction function, predict score for one instance at a time NOTE: use the batch prediction i...
virtual void PredictBatch(DMatrix *dmat, PredictionCacheEntry *out_preds, const gbm::GBTreeModel &model, uint32_t tree_begin, uint32_t tree_end=0) const =0
Generate batch predictions for a given feature matrix. May use cached predictions if available instea...
span class implementation, based on ISO C++20 span<T>. The interface should be the same.
Definition: span.h:422
The input data structure of xgboost.
A device-and-host vector abstraction layer.
Definition: linear_updater.h:23
Core data structure for multi-target trees.
Definition: base.h:87
std::vector< std::pair< std::string, std::string > > Args
Definition: base.h:310
Runtime context for XGBoost. Contains information like threads and device.
Definition: context.h:133
A type for device ordinal. The type is packed into 32-bit for efficient use in viewing types like lin...
Definition: context.h:34
bool IsCUDA() const
Definition: context.h:44
Contains pointer to input matrix and associated cached predictions.
Definition: predictor.h:30
std::uint32_t version
Definition: predictor.h:34
HostDeviceVector< bst_float > predictions
Definition: predictor.h:32
void Reset()
Definition: predictor.h:43
void Update(std::uint32_t v)
Update the cache entry by number of versions.
Definition: predictor.h:42
Registry entry for predictor.
Definition: predictor.h:202