xgboost
predictor.h
Go to the documentation of this file.
1 
7 #pragma once
8 #include <xgboost/base.h>
9 #include <xgboost/data.h>
12 
13 #include <functional>
14 #include <memory>
15 #include <string>
16 #include <unordered_map>
17 #include <utility>
18 #include <vector>
19 #include <mutex>
20 
// Forward declarations, so this header can name these types without pulling in
// their full definitions.
namespace xgboost {
class TreeUpdater;

namespace gbm {
struct GBTreeModel;
}  // namespace gbm
}  // namespace xgboost
28 
29 namespace xgboost {
36  // A storage for caching prediction values
38  // The version of current cache, corresponding number of layers of trees
39  uint32_t version { 0 };
40  // A weak pointer for checking whether the DMatrix object has expired.
41  std::weak_ptr< DMatrix > ref;
42 
43  PredictionCacheEntry() = default;
44  /* \brief Update the cache entry by number of versions.
45  *
46  * \param v Added versions.
47  */
48  void Update(uint32_t v) {
49  version += v;
50  }
51 };
52 
53 /*! \brief A container for managed prediction caches.
 *
 * Maps the address of each cached DMatrix to its PredictionCacheEntry.
54  */
// NOTE(review): source line 55 (the header of class PredictionContainer) is
// missing from this listing.
  // Cached entries, keyed by the address of the DMatrix they belong to.
56  std::unordered_map<DMatrix *, PredictionCacheEntry> container_;
  // Drops entries whose DMatrix has expired (per the Cache() documentation below,
  // expiry is detected through the entry's std::weak_ptr).
57  void ClearExpiredEntries();
  // Presumably guards container_; the locking itself is not visible in this header.
58  std::mutex cache_lock_;
59 
60  public:
61  PredictionContainer() = default;
62  /*! \brief Add a new DMatrix to the cache; at the same time, this function clears out
63  * all expired caches by checking the `std::weak_ptr`. Caching an existing
64  * DMatrix won't renew it.
65  *
66  * Passing in a `shared_ptr` is critical here. First, this shared pointer is needed to
67  * create the `weak_ptr` stored inside the entry. More importantly, the lifetime of this
68  * cache is tied to the shared pointer.
69  *
70  * Another way to make a safe cache is to create a proxy to this entry, with another shared
71  * pointer defined inside, and pass this proxy around instead of the real entry. But
72  * that seems too messy. In XGBoost, functions like `UpdateOneIter` will have
73  * (memory) safe access to the DMatrix as long as it's passed in as a `shared_ptr`.
74  *
75  * \param m shared pointer to the DMatrix that needs to be cached.
76  * \param device Which device the cache should be allocated on. Pass
77  * GenericParameter::kCpuId for CPU, or a non-negative GPU ordinal for GPU.
78  *
79  * \return the cache entry for the passed-in DMatrix, either an existing cache or a newly
80  * created one.
81  */
82  PredictionCacheEntry& Cache(std::shared_ptr<DMatrix> m, int32_t device);
83  /*! \brief Get a prediction cache entry. This entry must already have been allocated by the `Cache`
84  * method; otherwise a dmlc::Error is thrown.
85  *
86  * \param m pointer to the DMatrix.
87  * \return The prediction cache for the passed-in DMatrix.
88  */
  // NOTE(review): source line 89, the declaration documented by the comment above,
  // is missing from this listing.
90  /*! \brief Get a const reference to the underlying hash map. Clears expired caches before
91  * returning.
 *
 * NOTE(review): the returned reference outlives this call, so any locking done
 * inside Container() cannot cover the caller's later reads of the map; confirm
 * that callers never use it concurrently with Cache().
92  */
93  decltype(container_) const& Container();
94 };
95 
104 class Predictor {
105  protected:
106  /*
107  * \brief Runtime parameters.
108  */
110 
111  public:
112  explicit Predictor(GenericParameter const* generic_param) :
113  generic_param_{generic_param} {}
114  virtual ~Predictor() = default;
115 
121  virtual void Configure(const std::vector<std::pair<std::string, std::string>>& cfg);
122 
134  virtual void PredictBatch(DMatrix* dmat, PredictionCacheEntry* out_preds,
135  const gbm::GBTreeModel& model, int tree_begin,
136  uint32_t const ntree_limit = 0) = 0;
137 
147  virtual void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
148  float missing, PredictionCacheEntry *out_preds,
149  uint32_t tree_begin = 0, uint32_t tree_end = 0) const = 0;
162  virtual void PredictInstance(const SparsePage::Inst& inst,
163  std::vector<bst_float>* out_preds,
164  const gbm::GBTreeModel& model,
165  unsigned ntree_limit = 0) = 0;
166 
181  virtual void PredictLeaf(DMatrix* dmat, std::vector<bst_float>* out_preds,
182  const gbm::GBTreeModel& model,
183  unsigned ntree_limit = 0) = 0;
184 
204  virtual void PredictContribution(DMatrix* dmat,
205  std::vector<bst_float>* out_contribs,
206  const gbm::GBTreeModel& model,
207  unsigned ntree_limit = 0,
208  std::vector<bst_float>* tree_weights = nullptr,
209  bool approximate = false,
210  int condition = 0,
211  unsigned condition_feature = 0) = 0;
212 
213  virtual void PredictInteractionContributions(DMatrix* dmat,
214  std::vector<bst_float>* out_contribs,
215  const gbm::GBTreeModel& model,
216  unsigned ntree_limit = 0,
217  std::vector<bst_float>* tree_weights = nullptr,
218  bool approximate = false) = 0;
219 
220 
227  static Predictor* Create(
228  std::string const& name, GenericParameter const* generic_param);
229 };
230 
235  : public dmlc::FunctionRegEntryBase<
236  PredictorReg, std::function<Predictor*(GenericParameter const*)>> {};
237 
// Registers a predictor named `Name` in the dmlc registry of PredictorReg.
// The macro defines a file-local static reference bound to the entry returned by
// __REGISTER__(Name). `UniqueId` only makes that variable's name unique.
// The expansion ends without a semicolon, so the use site can chain calls on the
// returned entry (presumably dmlc's describe()/set_body()) before the final `;`.
// NOTE(review): the generated identifier contains a double underscore, which C++
// reserves for the implementation; consider renaming it.
238 #define XGBOOST_REGISTER_PREDICTOR(UniqueId, Name) \
239  static DMLC_ATTRIBUTE_UNUSED ::xgboost::PredictorReg& \
240  __make_##PredictorReg##_##UniqueId##__ = \
241  ::dmlc::Registry<::xgboost::PredictorReg>::Get()->__REGISTER__(Name)
242 } // namespace xgboost
Performs prediction on individual training instances or batches of instances for GBTree. Prediction functions all take a GBTreeModel and a DMatrix as input and output a vector of predictions. The predictor does not modify any state of the model itself.
Definition: predictor.h:104
The input data structure of xgboost.
Definition: generic_parameters.h:14
Internal data structure used by XGBoost during training. There are two ways to create a customized D...
Definition: data.h:451
Registry entry for predictor.
Definition: predictor.h:234
A device-and-host vector abstraction layer.
GenericParameter const * generic_param_
Definition: predictor.h:109
span class implementation, based on the ISO C++20 std::span<T>. The interface should be the same.
Definition: span.h:126
Predictor(GenericParameter const *generic_param)
Definition: predictor.h:112
std::weak_ptr< DMatrix > ref
Definition: predictor.h:41
HostDeviceVector< bst_float > predictions
Definition: predictor.h:37
Definition: predictor.h:55
namespace of xgboost
Definition: base.h:102
defines configuration macros of xgboost.
void Update(uint32_t v)
Definition: predictor.h:48
Element from a sparse vector.
Definition: data.h:167
Contains pointer to input matrix and associated cached predictions.
Definition: predictor.h:35